We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0db6027 commit 2fe116dCopy full SHA for 2fe116d
models/dppl_logistic_regression.jl
@@ -1,21 +1,15 @@
1
using StatsFuns: logistic
2
-using LazyArrays
3
4
d, n = 100, 10_000
5
X = randn(d, n)
6
w = randn(d)
7
y = Int.(logistic.(X' * w) .> 0.5)
8
9
-function safelogistic(x::T) where {T}
10
- logistic(x) * (1 - 2 * eps(T)) + eps(T)
11
-end
12
-
13
-lazyarray(f, x) = LazyArray(Base.broadcasted(f, x))
14
15
@model function dppl_logistic_regression(Xt, y)
16
N, D = size(Xt)
17
- w ~ filldist(Normal(), D)
18
- y ~ arraydist(lazyarray(x -> Bernoulli(safelogistic(x)), Xt * w))
+ w ~ product_distribution(Normal.(zeros(D)))
+ y ~ product_distribution(Bernoulli.(logistic.(Xt * w)))
19
end
20
21
model = dppl_logistic_regression(X', y)
0 commit comments