Skip to content

Commit 634135f

Browse files
committed
[no ci] initial updates for InitContext
1 parent ff8d01e commit 634135f

File tree

10 files changed

+59
-127
lines changed

10 files changed

+59
-127
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
6464
DistributionsAD = "0.6"
6565
DocStringExtensions = "0.8, 0.9"
6666
DynamicHMC = "3.4"
67-
DynamicPPL = "0.37.2"
67+
DynamicPPL = "0.38"
6868
EllipticalSliceSampling = "0.5, 1, 2"
6969
ForwardDiff = "0.10.3, 1"
7070
Libtask = "0.9.3"
@@ -90,3 +90,6 @@ julia = "1.10.8"
9090
[extras]
9191
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
9292
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
93+
94+
[sources]
95+
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"}

src/mcmc/abstractmcmc.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# TODO: Implement additional checks for certain samplers, e.g.
22
# HMC not supporting discrete parameters.
33
function _check_model(model::DynamicPPL.Model)
4-
# TODO(DPPL0.38/penelopeysm): use InitContext
5-
spl_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(model.context))
6-
return DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true)
4+
new_context = DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext())
5+
new_model = DynamicPPL.contextualize(model, new_context)
6+
return DynamicPPL.check_model(new_model, VarInfo(); error_on_failure=true)
77
end
88
function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm)
99
return _check_model(model)

src/mcmc/emcee.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,16 @@ function AbstractMCMC.step(
4646

4747
# Sample from the prior
4848
n = spl.alg.ensemble.n_walkers
49-
vis = [VarInfo(rng, model, SampleFromPrior()) for _ in 1:n]
49+
vis = [VarInfo(rng, model) for _ in 1:n]
5050

5151
# Update the parameters if provided.
5252
if initial_params !== nothing
53-
length(initial_params) == n ||
54-
throw(ArgumentError("initial parameters have to be specified for each walker"))
55-
vis = map(vis, initial_params) do vi, init
56-
# TODO(DPPL0.38/penelopeysm) This whole thing can be replaced with init!!
57-
vi = DynamicPPL.initialize_parameters!!(vi, init, model)
58-
59-
# Update log joint probability.
60-
spl_model = DynamicPPL.contextualize(
61-
model, DynamicPPL.SamplingContext(rng, SampleFromPrior(), model.context)
62-
)
63-
last(DynamicPPL.evaluate!!(spl_model, vi))
53+
if !(initial_params isa AbstractVector && length(initial_params) == n)
54+
err_msg = "initial_params for `Emcee` must be a vector of initialisation strategies, with length equal to the number of walkers ($n)"
55+
throw(ArgumentError(err_msg))
56+
end
57+
vis = map(vis, initial_params) do vi, strategy
58+
DynamicPPL.init!!(rng, model, vi, strategy)
6459
end
6560
end
6661

src/mcmc/ess.jl

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -82,23 +82,12 @@ EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
8282

8383
# Only define out-of-place sampling
8484
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
85-
varinfo = p.varinfo
86-
# TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
87-
# TODO(DPPL0.38/penelopeysm): This can be replaced with `init!!(p.model,
88-
# p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason
89-
# why we had to use the 'del' flag before this was because
90-
# SampleFromPrior() wouldn't overwrite existing variables.
91-
# The main problem I'm rather unsure about is ESS-within-Gibbs. The
92-
# current implementation I think makes sure to only resample the variables
93-
# that 'belong' to the current ESS sampler. InitContext on the other hand
94-
# would resample all variables in the model (??) Need to think about this
95-
# carefully.
96-
vns = keys(varinfo)
97-
for vn in vns
98-
set_flag!(varinfo, vn, "del")
99-
end
100-
p.model(rng, varinfo)
101-
return varinfo[:]
85+
# TODO(penelopeysm/DPPL 0.38) The main problem I'm rather unsure about is
86+
# ESS-within-Gibbs. The current implementation I think makes sure to only resample the
87+
# variables that 'belong' to the current ESS sampler. InitContext on the other hand
88+
# would resample all variables in the model (??) Need to think about this carefully.
89+
_, vi = DynamicPPL.init!!(p.model, p.varinfo, DynamicPPL.InitFromPrior())
90+
return vi[:]
10291
end
10392

10493
# Mean of prior distribution

src/mcmc/gibbs.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ A context used in the implementation of the Turing.jl Gibbs sampler.
4747
There will be one `GibbsContext` for each iteration of a component sampler.
4848
4949
`target_varnames` is a tuple of `VarName`s that the current component sampler
50-
is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume`
50+
is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume!!`
5151
calls to its child context. For other variables, their values will be fixed to
5252
the values they have in `global_varinfo`.
5353
@@ -140,7 +140,7 @@ function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName})
140140
end
141141

142142
# Tilde pipeline
143-
function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
143+
function DynamicPPL.tilde_assume!!(context::GibbsContext, right, vn, vi)
144144
child_context = DynamicPPL.childcontext(context)
145145

146146
# Note that `child_context` may contain `PrefixContext`s -- in which case
@@ -175,7 +175,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
175175

176176
return if is_target_varname(context, vn)
177177
# Fall back to the default behavior.
178-
DynamicPPL.tilde_assume(child_context, right, vn, vi)
178+
DynamicPPL.tilde_assume!!(child_context, right, vn, vi)
179179
elseif has_conditioned_gibbs(context, vn)
180180
# This branch means that a different sampler is supposed to handle this
181181
# variable. From the perspective of this sampler, this variable is
@@ -191,9 +191,10 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
191191
# presumably a new variable that should be sampled from its prior. We need to add
192192
# this new variable to the global `varinfo` of the context, but not to the local one
193193
# being used by the current sampler.
194-
value, new_global_vi = DynamicPPL.tilde_assume(
195-
child_context,
196-
DynamicPPL.SampleFromPrior(),
194+
value, new_global_vi = DynamicPPL.tilde_assume!!(
195+
# child_context might be a PrefixContext so we have to be careful to not
196+
# overwrite it.
197+
DynamicPPL.setleafcontext(child_context, DynamicPPL.InitContext()),
197198
right,
198199
vn,
199200
get_global_varinfo(context),
@@ -204,7 +205,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
204205
end
205206

206207
# As above but with an RNG.
207-
function DynamicPPL.tilde_assume(
208+
function DynamicPPL.tilde_assume!!(
208209
rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi
209210
)
210211
# See comment in the above, rng-less version of this method for an explanation.
@@ -215,7 +216,7 @@ function DynamicPPL.tilde_assume(
215216
# This branch means that `sampler` is supposed to handle
216217
# this variable. We can thus use its default behaviour, with
217218
# the 'local' sampler-specific VarInfo.
218-
DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi)
219+
DynamicPPL.tilde_assume!!(rng, child_context, sampler, right, vn, vi)
219220
elseif has_conditioned_gibbs(context, vn)
220221
# This branch means that a different sampler is supposed to handle this
221222
# variable. From the perspective of this sampler, this variable is
@@ -231,10 +232,10 @@ function DynamicPPL.tilde_assume(
231232
# presumably a new variable that should be sampled from its prior. We need to add
232233
# this new variable to the global `varinfo` of the context, but not to the local one
233234
# being used by the current sampler.
234-
value, new_global_vi = DynamicPPL.tilde_assume(
235-
rng,
236-
child_context,
237-
DynamicPPL.SampleFromPrior(),
235+
value, new_global_vi = DynamicPPL.tilde_assume!!(
236+
# child_context might be a PrefixContext so we have to be careful to not
237+
# overwrite it.
238+
DynamicPPL.setleafcontext(child_context, DynamicPPL.InitContext(rng)),
238239
right,
239240
vn,
240241
get_global_varinfo(context),

src/mcmc/mh.jl

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,11 @@ function propose!!(
329329
prev_trans = AMH.Transition(vt, prev_state.logjoint_internal, false)
330330

331331
# Make a new transition.
332-
spl_model = DynamicPPL.contextualize(
333-
model, DynamicPPL.SamplingContext(rng, spl, model.context)
334-
)
332+
model = DynamicPPL.setleafcontext(model, MHContext(rng))
335333
densitymodel = AMH.DensityModel(
336334
Base.Fix1(
337335
LogDensityProblems.logdensity,
338-
DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi),
336+
DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi),
339337
),
340338
)
341339
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
@@ -366,13 +364,11 @@ function propose!!(
366364
prev_trans = AMH.Transition(vals, prev_state.logjoint_internal, false)
367365

368366
# Make a new transition.
369-
spl_model = DynamicPPL.contextualize(
370-
model, DynamicPPL.SamplingContext(rng, spl, model.context)
371-
)
367+
model = DynamicPPL.setleafcontext(model, MHContext(rng))
372368
densitymodel = AMH.DensityModel(
373369
Base.Fix1(
374370
LogDensityProblems.logdensity,
375-
DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi),
371+
DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi),
376372
),
377373
)
378374
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
@@ -410,13 +406,20 @@ function AbstractMCMC.step(
410406
return Transition(model, new_state.varinfo, nothing), new_state
411407
end
412408

413-
####
414-
#### Compiler interface, i.e. tilde operators.
415-
####
416-
function DynamicPPL.assume(
417-
rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi
409+
struct MHContext{R} <: DynamicPPL.AbstractContext
410+
rng::R
411+
end
412+
DynamicPPL.NodeTrait(::MHContext) = DynamicPPL.IsLeaf()
413+
414+
function DynamicPPL.tilde_assume!!(
415+
context::MHContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
418416
)
419-
# Just defer to `SampleFromPrior`.
420-
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
421-
return retval
417+
# Allow MH to sample new variables from the prior if it's not already present in the
418+
# VarInfo.
419+
dispatch_ctx = if haskey(vi, vn)
420+
DynamicPPL.DefaultContext()
421+
else
422+
DynamicPPL.InitContext(context.rng, DynamicPPL.InitFromPrior())
423+
end
424+
return DynamicPPL.tilde_assume!!(dispatch_ctx, right, vn, vi)
422425
end

src/mcmc/particle_mcmc.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,9 @@ function set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
446446
return nothing
447447
end
448448

449+
# TODO(penelopeysm / DPPL 0.38): Figure this out
450+
struct ParticleMCMCContext <: DynamicPPL.AbstractContext end
451+
449452
function DynamicPPL.assume(
450453
rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, vi::AbstractVarInfo
451454
)

src/mcmc/prior.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@ function AbstractMCMC.step(
1212
state=nothing;
1313
kwargs...,
1414
)
15-
# TODO(DPPL0.38/penelopeysm): replace with init!!
16-
sampling_model = DynamicPPL.contextualize(
17-
model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context)
18-
)
19-
vi = VarInfo()
2015
vi = DynamicPPL.setaccs!!(
2116
vi,
2217
(
@@ -25,6 +20,6 @@ function AbstractMCMC.step(
2520
DynamicPPL.LogLikelihoodAccumulator(),
2621
),
2722
)
28-
_, vi = DynamicPPL.evaluate!!(sampling_model, vi)
23+
_, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromPrior())
2924
return Transition(model, vi, nothing; reevaluate=false), nothing
3025
end

src/optimisation/Optimisation.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,9 @@ function estimate_mode(
507507
kwargs...,
508508
)
509509
if check_model
510-
spl_model = DynamicPPL.contextualize(
511-
model, DynamicPPL.SamplingContext(model.context)
512-
)
513-
DynamicPPL.check_model(spl_model, DynamicPPL.VarInfo(); error_on_failure=true)
510+
new_context = DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext())
511+
new_model = DynamicPPL.contextualize(model, new_context)
512+
DynamicPPL.check_model(new_model, DynamicPPL.VarInfo(); error_on_failure=true)
514513
end
515514

516515
constraints = ModeEstimationConstraints(lb, ub, cons, lcons, ucons)

test/mcmc/external_sampler.jl

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -136,40 +136,6 @@ function initialize_mh_rw(model)
136136
return AdvancedMH.RWMH(MvNormal(Zeros(d), 0.1 * I))
137137
end
138138

139-
# TODO: Should this go somewhere else?
140-
# Convert a model into a `Distribution` to allow usage as a proposal in AdvancedMH.jl.
141-
struct ModelDistribution{M<:DynamicPPL.Model,V<:DynamicPPL.VarInfo} <:
142-
ContinuousMultivariateDistribution
143-
model::M
144-
varinfo::V
145-
end
146-
function ModelDistribution(model::DynamicPPL.Model)
147-
return ModelDistribution(model, DynamicPPL.VarInfo(model))
148-
end
149-
150-
Base.length(d::ModelDistribution) = length(d.varinfo[:])
151-
function Distributions._logpdf(d::ModelDistribution, x::AbstractVector)
152-
return logprior(d.model, DynamicPPL.unflatten(d.varinfo, x))
153-
end
154-
function Distributions._rand!(
155-
rng::Random.AbstractRNG, d::ModelDistribution, x::AbstractVector{<:Real}
156-
)
157-
model = d.model
158-
varinfo = deepcopy(d.varinfo)
159-
for vn in keys(varinfo)
160-
DynamicPPL.set_flag!(varinfo, vn, "del")
161-
end
162-
DynamicPPL.evaluate!!(model, varinfo, DynamicPPL.SamplingContext(rng))
163-
x .= varinfo[:]
164-
return x
165-
end
166-
167-
function initialize_mh_with_prior_proposal(model)
168-
return AdvancedMH.MetropolisHastings(
169-
AdvancedMH.StaticProposal(ModelDistribution(model))
170-
)
171-
end
172-
173139
function test_initial_params(
174140
model, sampler, initial_params=DynamicPPL.VarInfo(model)[:]; kwargs...
175141
)
@@ -268,28 +234,6 @@ end
268234
@test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp])
269235
end
270236
end
271-
272-
# NOTE: Broken because MH doesn't really follow the `logdensity` interface, but calls
273-
# it with `NamedTuple` instead of `AbstractVector`.
274-
# @testset "MH with prior proposal" begin
275-
# @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
276-
# sampler = initialize_mh_with_prior_proposal(model);
277-
# sampler_ext = DynamicPPL.Sampler(externalsampler(sampler; unconstrained=false))
278-
# @testset "initial_params" begin
279-
# test_initial_params(model, sampler_ext)
280-
# end
281-
# @testset "inference" begin
282-
# DynamicPPL.TestUtils.test_sampler(
283-
# [model],
284-
# sampler_ext,
285-
# 10_000;
286-
# discard_initial=1_000,
287-
# rtol=0.2,
288-
# sampler_name="AdvancedMH"
289-
# )
290-
# end
291-
# end
292-
# end
293237
end
294238
end
295239

0 commit comments

Comments
 (0)