@@ -172,14 +172,24 @@ function initialize_mh_with_prior_proposal(model)
172
172
)
173
173
end
174
174
175
- function test_initial_params (model, sampler, initial_params= InitFromPrior (); kwargs... )
175
+ function test_initial_params (model, sampler; kwargs... )
176
+ # Generate some parameters.
177
+ dict = DynamicPPL. values_as (VarInfo (model), Dict)
178
+ init_strategy = DynamicPPL. InitFromParams (dict)
179
+
176
180
# Execute the transition with two different RNGs and check that the resulting
177
- # parameter values are the same.
181
+ # parameter values are the same. This ensures that the `initial_params` are
182
+ # respected (i.e., regardless of the RNG, the first step should always return
183
+ # the same parameters).
178
184
rng1 = Random. MersenneTwister (42 )
179
185
rng2 = Random. MersenneTwister (43 )
180
186
181
- transition1, _ = AbstractMCMC. step (rng1, model, sampler; initial_params, kwargs... )
182
- transition2, _ = AbstractMCMC. step (rng2, model, sampler; initial_params, kwargs... )
187
+ transition1, _ = AbstractMCMC. step (
188
+ rng1, model, sampler; initial_params= init_strategy, kwargs...
189
+ )
190
+ transition2, _ = AbstractMCMC. step (
191
+ rng2, model, sampler; initial_params= init_strategy, kwargs...
192
+ )
183
193
vn_to_val1 = DynamicPPL. OrderedDict (transition1. θ)
184
194
vn_to_val2 = DynamicPPL. OrderedDict (transition2. θ)
185
195
for vn in union (keys (vn_to_val1), keys (vn_to_val2))
0 commit comments