@@ -136,6 +136,37 @@ function initialize_mh_rw(model)
136
136
return AdvancedMH. RWMH (MvNormal (Zeros (d), 0.1 * I))
137
137
end
138
138
139
+ # TODO : Should this go somewhere else?
140
+ # Convert a model into a `Distribution` to allow usage as a proposal in AdvancedMH.jl.
141
+ struct ModelDistribution{M<: DynamicPPL.Model ,V<: DynamicPPL.VarInfo } < :
142
+ ContinuousMultivariateDistribution
143
+ model:: M
144
+ varinfo:: V
145
+ end
146
+ function ModelDistribution (model:: DynamicPPL.Model )
147
+ return ModelDistribution (model, DynamicPPL. VarInfo (model))
148
+ end
149
+
150
+ Base. length (d:: ModelDistribution ) = length (d. varinfo[:])
151
+ function Distributions. _logpdf (d:: ModelDistribution , x:: AbstractVector )
152
+ return logprior (d. model, DynamicPPL. unflatten (d. varinfo, x))
153
+ end
154
+ function Distributions. _rand! (
155
+ rng:: Random.AbstractRNG , d:: ModelDistribution , x:: AbstractVector{<:Real}
156
+ )
157
+ model = d. model
158
+ varinfo = deepcopy (d. varinfo)
159
+ _, varinfo = DynamicPPL. init!! (rng, model, varinfo, DynamicPPL. InitFromPrior ())
160
+ x .= varinfo[:]
161
+ return x
162
+ end
163
+
164
+ function initialize_mh_with_prior_proposal (model)
165
+ return AdvancedMH. MetropolisHastings (
166
+ AdvancedMH. StaticProposal (ModelDistribution (model))
167
+ )
168
+ end
169
+
139
170
function test_initial_params (
140
171
model, sampler, initial_params= DynamicPPL. VarInfo (model)[:]; kwargs...
141
172
)
234
265
@test isapprox (logpdf .(Normal (), chn[:x ]), chn[:lp ])
235
266
end
236
267
end
268
+
269
+ # NOTE: Broken because MH doesn't really follow the `logdensity` interface, but calls
270
+ # it with `NamedTuple` instead of `AbstractVector`.
271
+ # @testset "MH with prior proposal" begin
272
+ # @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
273
+ # sampler = initialize_mh_with_prior_proposal(model);
274
+ # sampler_ext = DynamicPPL.Sampler(externalsampler(sampler; unconstrained=false))
275
+ # @testset "initial_params" begin
276
+ # test_initial_params(model, sampler_ext)
277
+ # end
278
+ # @testset "inference" begin
279
+ # DynamicPPL.TestUtils.test_sampler(
280
+ # [model],
281
+ # sampler_ext,
282
+ # 10_000;
283
+ # discard_initial=1_000,
284
+ # rtol=0.2,
285
+ # sampler_name="AdvancedMH"
286
+ # )
287
+ # end
288
+ # end
289
+ # end
237
290
end
238
291
end
239
292
0 commit comments