Closed
Description
Hello,
I was trying to use the macro Turing.@addlogprob! in a model that I wanted to sample with the PG sampler, but it did not work.
I know that this macro is intended for advanced users. However, while trying to understand the problem, I found that PG and SMC also fail with broadcasting (the function called demo_array in the code below).
I implemented the model that your quickstart calls coinflip (here called demo) in four ways, with different levels of 'complexity'. I tried four samplers: HMC, MH, SMC and PG.
All four implementations work with HMC and MH.
With PG and SMC, only the original function works. The other three give p ≈ 0.5, which is the mean of the Beta(1, 1) prior, as if the observations were ignored:
using Turing
using Pkg
Pkg.status("Turing")

@model function demo(y)
    # Our prior belief about the probability of heads in a coin.
    p ~ Beta(1, 1)
    # The number of observations.
    N = length(y)
    for n in 1:N
        # Heads or tails of a coin are drawn from a Bernoulli distribution.
        y[n] ~ Bernoulli(p)
    end
end

@model function demo_array(x)
    # 'array': broadcasted observation
    p ~ Beta(1, 1)
    x .~ Bernoulli(p)
    return
end

@model function demo_addprob(x)
    # 'logprob': likelihood added manually
    p ~ Beta(1, 1)
    loglik = loglikelihood(Bernoulli(p), x)
    Turing.@addlogprob!(loglik)
    return
end

# 'PPL': model written directly against the DynamicPPL internals
function function_demo_PPL(model, varinfo, context, x)
    p, varinfo = DynamicPPL.tilde_assume!!(
        context,
        Beta(1, 1),
        Turing.@varname(p),
        varinfo,
    )
    DynamicPPL.dot_tilde_observe!!(context, Bernoulli(p), x, Turing.@varname(x), varinfo)
end
demo_PPL(x) = Turing.Model(function_demo_PPL, (; x))

data = [true for _ in 1:20]
p_model = Dict("demo" => demo,
               "array" => demo_array,
               "add_prob" => demo_addprob,
               "PPL" => demo_PPL)
p_sampler = Dict("HMC" => HMC(0.05, 10), "SMC" => SMC(), "MH" => MH(), "PG" => PG(20))

results = []
for model_name in ["demo", "array", "add_prob", "PPL"]
    for (sampler_name, sampler) in p_sampler
        c = sample(p_model[model_name](data), sampler, 1000)
        push!(results, [model_name, sampler_name, mean(c[:p])])
    end
end

# Posterior is Beta(1 + 20, 1 + 0), whose mean is 21/22.
println("Expected value p=", round(21/22, digits=2))
for (model_name, sampler, p) in results
    println("Model name: ", model_name, ", sampler: ", sampler, ", p=", round(p, digits=2))
end
This returns:
Status `~/alignExp/Project.toml`
⌃ [fce5fe82] Turing v0.21.13
Info Packages marked with ⌃ have new versions available and may be upgradable.
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:08
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:03
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:02
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:02
Expected value p=0.95
Model name: demo, sampler: HMC, p=0.95
Model name: demo, sampler: SMC, p=0.93
Model name: demo, sampler: MH, p=0.95
Model name: demo, sampler: PG, p=0.96
Model name: array, sampler: HMC, p=0.95
Model name: array, sampler: SMC, p=0.5
Model name: array, sampler: MH, p=0.96
Model name: array, sampler: PG, p=0.51
Model name: add_prob, sampler: HMC, p=0.96
Model name: add_prob, sampler: SMC, p=0.5
Model name: add_prob, sampler: MH, p=0.95
Model name: add_prob, sampler: PG, p=0.5
Model name: PPL, sampler: HMC, p=0.96
Model name: PPL, sampler: SMC, p=0.5
Model name: PPL, sampler: MH, p=0.95
Model name: PPL, sampler: PG, p=0.5
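The value 0.5 reported by SMC and PG for array, add_prob and PPL is exactly the prior mean, while 0.95 is the Beta(21, 1) posterior mean. The toy sketch below is plain Julia (no Turing) and uses a simple importance sampler with the prior as proposal, not Turing's actual SMC/PG code; `bernoulli_loglik` and `toy_is_mean` are hypothetical helpers. It only illustrates that a particle-style estimator lands near 0.95 when the observation log-likelihood enters the particle weights, and near 0.5 when that contribution is dropped, which is consistent with (but does not prove) the likelihood term being lost for these three models:

```julia
using Random

# Hypothetical helper: log-likelihood of Bernoulli(p) observations, written out by hand.
bernoulli_loglik(p, y) = sum(yi ? log(p) : log1p(-p) for yi in y)

# Hypothetical helper: self-normalised importance sampling with the Beta(1, 1) prior
# (i.e. Uniform(0, 1)) as proposal. With use_likelihood = false every particle keeps
# the same weight, mimicking a sampler that never sees the observations.
function toy_is_mean(y; n_particles = 100_000, use_likelihood = true, rng = MersenneTwister(1))
    ps = rand(rng, n_particles)                       # prior draws of p
    logw = use_likelihood ? [bernoulli_loglik(p, y) for p in ps] : zeros(n_particles)
    w = exp.(logw .- maximum(logw))                   # stabilise before normalising
    return sum(w .* ps) / sum(w)                      # weighted estimate of E[p]
end

y = fill(true, 20)
exact = (1 + count(y)) / (2 + length(y))              # Beta(1 + heads, 1 + tails) mean = 21/22
with_lik = toy_is_mean(y)
without_lik = toy_is_mean(y; use_likelihood = false)
println("exact              = ", round(exact, digits = 2))
println("with likelihood    ≈ ", round(with_lik, digits = 3))
println("without likelihood ≈ ", round(without_lik, digits = 3))
```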