-
Notifications
You must be signed in to change notification settings - Fork 2
Open
Labels
this-repo: Something to do with just this repo
Description
There is a Bayesian neural network (BNN) test in `test/mcmc/hmc.jl` (reproduced below), and there is also a BNN example in the docs; check whether the two are the same.
@testset "multivariate support" begin
    # Bayesian neural network with one hidden layer of 3 tanh units on 2-D
    # inputs and a logistic output. This is a smoke test that HMC can sample a
    # model mixing multivariate (`MvNormal`) and univariate (`Normal`) priors.

    # Forward pass. `x` is a 2-element input; `b1` (length 3) and the
    # 2-element vectors `w11`, `w12`, `w13` (stacked as columns, then
    # transposed) form the hidden layer; `wo` (length 3) and `bo` form the
    # output layer. Returns a value in (0, 1) used as a Bernoulli probability.
    function nn(x, b1, w11, w12, w13, bo, wo)
        h = tanh.([w11 w12 w13]' * x .+ b1)
        return logistic(dot(wo, h) + bo)
    end

    # Training data: four clusters of M points each. Points in the
    # [0,5]x[0,5] and [-6,-1]x[-6,-1] squares are labelled 1; points in the
    # two off-diagonal squares are labelled 0 (an XOR-like pattern).
    # A seeded RNG keeps the data, and therefore the test, reproducible.
    data_rng = StableRNG(seed)
    N = 20
    M = N ÷ 4
    x1s = rand(data_rng, M) * 5
    x2s = rand(data_rng, M) * 5
    xt1s = [[x1s[i]; x2s[i]] for i in 1:M]
    append!(xt1s, [[x1s[i] - 6; x2s[i] - 6] for i in 1:M])
    xt0s = [[x1s[i]; x2s[i] - 6] for i in 1:M]
    append!(xt0s, [[x1s[i] - 6; x2s[i]] for i in 1:M])
    xs = [xt1s; xt0s]
    ts = [ones(M); ones(M); zeros(M); zeros(M)]

    # Prior hyperparameters.
    alpha = 0.16 # regularization term
    var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior
    # d x d diagonal covariance matrix with `var_prior` on the diagonal.
    prior_cov(d) = [i == j ? var_prior : 0.0 for i in 1:d, j in 1:d]

    @model function bnn(ts)
        b1 ~ MvNormal(zeros(3), prior_cov(3))
        w11 ~ MvNormal(zeros(2), prior_cov(2))
        w12 ~ MvNormal(zeros(2), prior_cov(2))
        w13 ~ MvNormal(zeros(2), prior_cov(2))
        # `Normal` is parameterised by a standard deviation (unlike the
        # covariance passed to `MvNormal`), so take the square root to give
        # `bo` the same prior variance as the other parameters.
        bo ~ Normal(0, sqrt(var_prior))
        wo ~ MvNormal(zeros(3), prior_cov(3))
        # Condition on every data point. Drawing a random subset here would
        # make the log density differ between evaluations (invalid for HMC)
        # and would use the unseeded global RNG, making the test flaky.
        for i in eachindex(ts)
            y = nn(xs[i], b1, w11, w12, w13, bo, wo)
            ts[i] ~ Bernoulli(y)
        end
        return b1, w11, w12, w13, bo, wo
    end

    # Sampling
    chain = sample(StableRNG(seed), bnn(ts), HMC(0.1, 5; adtype=adbackend), 10)
end
Metadata
Metadata
Assignees
Labels
this-repo: Something to do with just this repo