Skip to content

Commit 6c60c3b

Browse files
torfjeldeyebai
andcommitted
Fixes failing tests (#276)
Co-authored-by: Hong Ge <[email protected]>
1 parent 4de6f54 commit 6c60c3b

File tree

4 files changed

+33
-20
lines changed

4 files changed

+33
-20
lines changed

test/compat/ad.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
return logpdf(InverseGamma(2, 3), s) +
1010
logpdf(Normal(0, sqrt(s)), m) +
11-
logpdf(dist, 1.5) + logpdf(dist, 2.0)
11+
logpdf(dist, 1.5) +
12+
logpdf(dist, 2.0)
1213
end
1314

1415
test_model_ad(gdemo_default, logp_gdemo_default)

test/turing/loglikelihoods.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
y = randn()
1414
model = demo(xs, y)
1515
chain = sample(model, MH(), MCMCThreads(), 100, 2)
16-
var_to_likelihoods = pointwise_loglikelihoods(model, chain)
16+
var_to_likelihoods = pointwise_loglikelihoods(
17+
model, MCMCChains.get_sections(chain, :parameters)
18+
)
1719
@test haskey(var_to_likelihoods, "xs[1]")
1820
@test haskey(var_to_likelihoods, "xs[2]")
1921
@test haskey(var_to_likelihoods, "xs[3]")
@@ -32,8 +34,8 @@
3234
results = pointwise_loglikelihoods(model, var_info)
3335
var_to_likelihoods = Dict(string(vn) =>for (vn, ℓ) in results)
3436
s, m = var_info[SampleFromPrior()]
35-
@test logpdf(Normal(m, s), xs[1]) == var_to_likelihoods["xs[1]"]
36-
@test logpdf(Normal(m, s), xs[2]) == var_to_likelihoods["xs[2]"]
37-
@test logpdf(Normal(m, s), xs[3]) == var_to_likelihoods["xs[3]"]
38-
@test logpdf(Normal(m, s), y) == var_to_likelihoods["y"]
37+
@test [logpdf(Normal(m, s), xs[1])] == var_to_likelihoods["xs[1]"]
38+
@test [logpdf(Normal(m, s), xs[2])] == var_to_likelihoods["xs[2]"]
39+
@test [logpdf(Normal(m, s), xs[3])] == var_to_likelihoods["xs[3]"]
40+
@test [logpdf(Normal(m, s), y)] == var_to_likelihoods["y"]
3941
end

test/turing/model.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
chain1 = sample(model1, MH(), 100)
3131
chain2 = sample(model2, MH(), 100)
3232

33-
res11 = generated_quantities(model1, chain1)
34-
res21 = generated_quantities(model2, chain1)
33+
res11 = generated_quantities(model1, MCMCChains.get_sections(chain1, :parameters))
34+
res21 = generated_quantities(model2, MCMCChains.get_sections(chain1, :parameters))
3535

36-
res12 = generated_quantities(model1, chain2)
37-
res22 = generated_quantities(model2, chain2)
36+
res12 = generated_quantities(model1, MCMCChains.get_sections(chain2, :parameters))
37+
res22 = generated_quantities(model2, MCMCChains.get_sections(chain2, :parameters))
3838

3939
# Check that the two different models produce the same values for
4040
# the same chains.
@@ -43,8 +43,8 @@
4343
# Ensure that they're not all the same (some can be, because rejected samples)
4444
@test any(res12[1:(end - 1)] .!= res12[2:end])
4545

46-
test_setval!(model1, chain1)
47-
test_setval!(model2, chain2)
46+
test_setval!(model1, MCMCChains.get_sections(chain1, :parameters))
47+
test_setval!(model2, MCMCChains.get_sections(chain2, :parameters))
4848

4949
# Next level
5050
@model function demo3(xs, ::Type{TV}=Vector{Float64}) where {TV}
@@ -79,11 +79,11 @@
7979
chain3 = sample(model3, MH(), 100)
8080
chain4 = sample(model4, MH(), 100)
8181

82-
res33 = generated_quantities(model3, chain3)
83-
res43 = generated_quantities(model4, chain3)
82+
res33 = generated_quantities(model3, MCMCChains.get_sections(chain3, :parameters))
83+
res43 = generated_quantities(model4, MCMCChains.get_sections(chain3, :parameters))
8484

85-
res34 = generated_quantities(model3, chain4)
86-
res44 = generated_quantities(model4, chain4)
85+
res34 = generated_quantities(model3, MCMCChains.get_sections(chain4, :parameters))
86+
res44 = generated_quantities(model4, MCMCChains.get_sections(chain4, :parameters))
8787

8888
# Check that the two different models produce the same values for
8989
# the same chains.

test/turing/prob_macro.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
model = demo(xval)
1313
varinfo = VarInfo(model)
14-
chain = sample(model, IS(), iters; save_state=true)
14+
chain = MCMCChains.get_sections(
15+
sample(model, IS(), iters; save_state=true), :parameters
16+
)
1517
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())
1618
lps = logpdf.(Normal.(chain["m"], 1), xval)
1719
@test logprob"x = xval | chain = chain" == lps
@@ -40,7 +42,9 @@
4042

4143
model = demo(xval)
4244
varinfo = VarInfo(model)
43-
chain = sample(model, HMC(0.5, 1), iters; save_state=true)
45+
chain = MCMCChains.get_sections(
46+
sample(model, HMC(0.5, 1), iters; save_state=true), :parameters
47+
)
4448
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())
4549

4650
names = namesingroup(chain, "m")
@@ -74,7 +78,10 @@
7478
group = rand(1:4, 100)
7579
n_groups = 4
7680

77-
chain1 = sample(model1(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
81+
chain1 = MCMCChains.get_sections(
82+
sample(model1(y, group, n_groups), NUTS(0.65), 2_000; save_state=true),
83+
:parameters,
84+
)
7885
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain1"
7986

8087
@model function model2(y, group, n_groups)
@@ -85,7 +92,10 @@
8592
end
8693
end
8794

88-
chain2 = sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
95+
chain2 = MCMCChains.get_sections(
96+
sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true),
97+
:parameters,
98+
)
8999
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2"
90100
end
91101
end

0 commit comments

Comments
 (0)