Skip to content

Commit ed43a02

Browse files
committed
[no ci] Fix unused external sampler code
1 parent d810993 commit ed43a02

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

test/mcmc/external_sampler.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,37 @@ function initialize_mh_rw(model)
136136
return AdvancedMH.RWMH(MvNormal(Zeros(d), 0.1 * I))
137137
end
138138

139+
# TODO: Should this go somewhere else?
140+
# Convert a model into a `Distribution` to allow usage as a proposal in AdvancedMH.jl.
141+
struct ModelDistribution{M<:DynamicPPL.Model,V<:DynamicPPL.VarInfo} <:
142+
ContinuousMultivariateDistribution
143+
model::M
144+
varinfo::V
145+
end
146+
function ModelDistribution(model::DynamicPPL.Model)
147+
return ModelDistribution(model, DynamicPPL.VarInfo(model))
148+
end
149+
150+
Base.length(d::ModelDistribution) = length(d.varinfo[:])
151+
function Distributions._logpdf(d::ModelDistribution, x::AbstractVector)
152+
return logprior(d.model, DynamicPPL.unflatten(d.varinfo, x))
153+
end
154+
function Distributions._rand!(
155+
rng::Random.AbstractRNG, d::ModelDistribution, x::AbstractVector{<:Real}
156+
)
157+
model = d.model
158+
varinfo = deepcopy(d.varinfo)
159+
_, varinfo = DynamicPPL.init!!(rng, model, varinfo, DynamicPPL.InitFromPrior())
160+
x .= varinfo[:]
161+
return x
162+
end
163+
164+
function initialize_mh_with_prior_proposal(model)
165+
return AdvancedMH.MetropolisHastings(
166+
AdvancedMH.StaticProposal(ModelDistribution(model))
167+
)
168+
end
169+
139170
function test_initial_params(
140171
model, sampler, initial_params=DynamicPPL.VarInfo(model)[:]; kwargs...
141172
)
@@ -234,6 +265,28 @@ end
234265
@test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp])
235266
end
236267
end
268+
269+
# NOTE: Broken because MH doesn't really follow the `logdensity` interface, but calls
270+
# it with `NamedTuple` instead of `AbstractVector`.
271+
# @testset "MH with prior proposal" begin
272+
# @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
273+
# sampler = initialize_mh_with_prior_proposal(model);
274+
# sampler_ext = DynamicPPL.Sampler(externalsampler(sampler; unconstrained=false))
275+
# @testset "initial_params" begin
276+
# test_initial_params(model, sampler_ext)
277+
# end
278+
# @testset "inference" begin
279+
# DynamicPPL.TestUtils.test_sampler(
280+
# [model],
281+
# sampler_ext,
282+
# 10_000;
283+
# discard_initial=1_000,
284+
# rtol=0.2,
285+
# sampler_name="AdvancedMH"
286+
# )
287+
# end
288+
# end
289+
# end
237290
end
238291
end
239292

0 commit comments

Comments
 (0)