Skip to content

Commit 3afd807

Browse files
committed
Fix externalsampler test correctly
1 parent c315993 commit 3afd807

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

src/mcmc/external_sampler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ function AbstractMCMC.step(
124124
sampler = alg.sampler
125125

126126
# Initialise varinfo with initial params and link the varinfo if needed.
127-
varinfo = DynamicPPL.VarInfo(model)
127+
varinfo = DynamicPPL.VarInfo(rng, model)
128128
_, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params)
129129

130130
if requires_unconstrained_space(alg)

test/mcmc/external_sampler.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,24 @@ function initialize_mh_with_prior_proposal(model)
172172
)
173173
end
174174

175-
function test_initial_params(model, sampler, initial_params=InitFromPrior(); kwargs...)
175+
function test_initial_params(model, sampler; kwargs...)
176+
# Generate some parameters.
177+
dict = DynamicPPL.values_as(VarInfo(model), Dict)
178+
init_strategy = DynamicPPL.InitFromParams(dict)
179+
176180
# Execute the transition with two different RNGs and check that the resulting
177-
# parameter values are the same.
181+
# parameter values are the same. This ensures that the `initial_params` are
182+
# respected (i.e., regardless of the RNG, the first step should always return
183+
# the same parameters).
178184
rng1 = Random.MersenneTwister(42)
179185
rng2 = Random.MersenneTwister(43)
180186

181-
transition1, _ = AbstractMCMC.step(rng1, model, sampler; initial_params, kwargs...)
182-
transition2, _ = AbstractMCMC.step(rng2, model, sampler; initial_params, kwargs...)
187+
transition1, _ = AbstractMCMC.step(
188+
rng1, model, sampler; initial_params=init_strategy, kwargs...
189+
)
190+
transition2, _ = AbstractMCMC.step(
191+
rng2, model, sampler; initial_params=init_strategy, kwargs...
192+
)
183193
vn_to_val1 = DynamicPPL.OrderedDict(transition1.θ)
184194
vn_to_val2 = DynamicPPL.OrderedDict(transition2.θ)
185195
for vn in union(keys(vn_to_val1), keys(vn_to_val2))

0 commit comments

Comments
 (0)