Skip to content

Commit a11b75e

Browse files
committed
Fix some tests
1 parent 853f83b commit a11b75e

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

src/sampler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ function AbstractMCMC.step(
5959
)
6060
vi = VarInfo()
6161
strategy = sampler isa SampleFromPrior ? PriorInit() : UniformInit()
62-
DynamicPPL.init!!(rng, model, vi, strategy)
63-
return vi, nothing
62+
_, new_vi = DynamicPPL.init!!(rng, model, vi, strategy)
63+
return new_vi, nothing
6464
end
6565

6666
"""

test/model.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
495495

496496
# Construct a chain with 'sampled values' of β
497497
ground_truth_β = 2
498-
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [])
498+
β_chain = MCMCChains.Chains(
499+
rand(Normal(ground_truth_β, 0.002), 1000),
500+
[];
501+
info=(; varname_to_symbol=Dict(@varname(β) => )),
502+
)
499503

500504
# Generate predictions from that chain
501505
xs_test = [10 + 0.1, 10 + 2 * 0.1]
@@ -541,7 +545,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
541545
@testset "prediction from multiple chains" begin
542546
# Normal linreg model
543547
multiple_β_chain = MCMCChains.Chains(
544-
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
548+
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2),
549+
[];
550+
info=(; varname_to_symbol=Dict(@varname(β) => )),
545551
)
546552
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
547553
@test size(multiple_β_chain, 3) == size(predictions, 3)

test/sampler.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
@test length(chains) == N
2626

2727
# `m` is Gaussian, i.e. no transformation is used, so it
28-
# should have a mean equal to its prior, i.e. 2.
29-
@test mean(vi[@varname(m)] for vi in chains) 2 atol = 0.1
28+
# will be drawn from U[-2, 2] and its mean should be 0.
29+
@test mean(vi[@varname(m)] for vi in chains) 0.0 atol = 0.1
3030

3131
# Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8.
3232
@test mean(vi[@varname(s)] for vi in chains) 1.8 atol = 0.1
@@ -81,10 +81,8 @@
8181
model = coinflip()
8282
sampler = Sampler(alg)
8383
lptrue = logpdf(Binomial(25, 0.2), 10)
84-
let inits = (; p=0.2)
85-
chain = sample(
86-
model, sampler, 1; initial_params=ParamsInit(inits), progress=false
87-
)
84+
let inits = ParamsInit((; p=0.2))
85+
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
8886
@test chain[1].metadata.p.vals == [0.2]
8987
@test getlogjoint(chain[1]) == lptrue
9088

@@ -111,10 +109,8 @@
111109
end
112110
model = twovars()
113111
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
114-
for inits in ([4, -1], (; s=4, m=-1))
115-
chain = sample(
116-
model, sampler, 1; initial_params=ParamsInit(inits), progress=false
117-
)
112+
let inits = ParamsInit((; s=4, m=-1))
113+
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
118114
@test chain[1].metadata.s.vals == [4]
119115
@test chain[1].metadata.m.vals == [-1]
120116
@test getlogjoint(chain[1]) == lptrue
@@ -126,7 +122,7 @@
126122
MCMCThreads(),
127123
1,
128124
10;
129-
initial_params=fill(ParamsInit(inits), 10),
125+
initial_params=fill(inits, 10),
130126
progress=false,
131127
)
132128
for c in chains
@@ -137,10 +133,8 @@
137133
end
138134

139135
# set only m = -1
140-
for inits in ((; s=missing, m=-1), (; m=-1))
141-
chain = sample(
142-
model, sampler, 1; initial_params=ParamsInit(inits), progress=false
143-
)
136+
for inits in (ParamsInit((; s=missing, m=-1)), ParamsInit((; m=-1)))
137+
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
144138
@test !ismissing(chain[1].metadata.s.vals[1])
145139
@test chain[1].metadata.m.vals == [-1]
146140

0 commit comments

Comments
 (0)