Skip to content

Commit 2fe116d

Browse files
authored
Update dppl_logistic_regression.jl
1 parent 0db6027 commit 2fe116d

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

models/dppl_logistic_regression.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,15 @@
11
using StatsFuns: logistic
2-
using LazyArrays
32

43
d, n = 100, 10_000
54
X = randn(d, n)
65
w = randn(d)
76
y = Int.(logistic.(X' * w) .> 0.5)
87

9-
function safelogistic(x::T) where {T}
10-
logistic(x) * (1 - 2 * eps(T)) + eps(T)
11-
end
12-
13-
lazyarray(f, x) = LazyArray(Base.broadcasted(f, x))
148

159
@model function dppl_logistic_regression(Xt, y)
1610
N, D = size(Xt)
17-
w ~ filldist(Normal(), D)
18-
y ~ arraydist(lazyarray(x -> Bernoulli(safelogistic(x)), Xt * w))
11+
w ~ product_distribution(Normal.(zeros(D)))
12+
y ~ product_distribution(Bernoulli.(logistic.(Xt * w)))
1913
end
2014

2115
model = dppl_logistic_regression(X', y)

0 commit comments

Comments
 (0)