Skip to content

Commit bf18516

Browse files
committed
[no ci] initial updates for InitContext
1 parent 10f960e commit bf18516

File tree

11 files changed

+69
-75
lines changed

11 files changed

+69
-75
lines changed

HISTORY.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# 0.41.0
22

3+
## DynamicPPL 0.38
4+
5+
Lorem ipsum dynamicppl sit amet
6+
7+
## Initial step in MCMC sampling
8+
39
HMC and NUTS samplers no longer take an extra single step before starting the chain.
410
This means that if you do not discard any samples at the start, the first sample will be the initial parameters (which may be user-provided).
511

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
6464
DistributionsAD = "0.6"
6565
DocStringExtensions = "0.8, 0.9"
6666
DynamicHMC = "3.4"
67-
DynamicPPL = "0.37.2"
67+
DynamicPPL = "0.38"
6868
EllipticalSliceSampling = "0.5, 1, 2"
6969
ForwardDiff = "0.10.3, 1"
7070
Libtask = "0.9.3"
@@ -90,3 +90,6 @@ julia = "1.10.8"
9090
[extras]
9191
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
9292
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
93+
94+
[sources]
95+
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"}

src/mcmc/abstractmcmc.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# TODO: Implement additional checks for certain samplers, e.g.
22
# HMC not supporting discrete parameters.
33
function _check_model(model::DynamicPPL.Model)
4-
# TODO(DPPL0.38/penelopeysm): use InitContext
5-
spl_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(model.context))
6-
return DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true)
4+
new_context = DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext())
5+
new_model = DynamicPPL.contextualize(model, new_context)
6+
return DynamicPPL.check_model(new_model, VarInfo(); error_on_failure=true)
77
end
88
function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm)
99
return _check_model(model)

src/mcmc/emcee.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,19 @@ function AbstractMCMC.step(
4646

4747
# Sample from the prior
4848
n = spl.alg.ensemble.n_walkers
49-
vis = [VarInfo(rng, model, SampleFromPrior()) for _ in 1:n]
49+
vis = [VarInfo(rng, model) for _ in 1:n]
5050

5151
# Update the parameters if provided.
5252
if initial_params !== nothing
53-
length(initial_params) == n ||
54-
throw(ArgumentError("initial parameters have to be specified for each walker"))
55-
vis = map(vis, initial_params) do vi, init
56-
# TODO(DPPL0.38/penelopeysm) This whole thing can be replaced with init!!
57-
vi = DynamicPPL.initialize_parameters!!(vi, init, model)
58-
59-
# Update log joint probability.
60-
spl_model = DynamicPPL.contextualize(
61-
model, DynamicPPL.SamplingContext(rng, SampleFromPrior(), model.context)
62-
)
63-
last(DynamicPPL.evaluate!!(spl_model, vi))
53+
if !(
54+
initial_params isa AbstractVector{<:DynamicPPL.AbstractInitStrategy} &&
55+
length(initial_params) == n
56+
)
57+
err_msg = "initial_params for `Emcee` must be a vector of `DynamicPPL.AbstractInitStrategy`, with length equal to the number of walkers ($n)"
58+
throw(ArgumentError(err_msg))
59+
end
60+
vis = map(vis, initial_params) do vi, strategy
61+
DynamicPPL.init!!(rng, model, vi, strategy)
6462
end
6563
end
6664

src/mcmc/ess.jl

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -82,23 +82,12 @@ EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
8282

8383
# Only define out-of-place sampling
8484
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
85-
varinfo = p.varinfo
86-
# TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
87-
# TODO(DPPL0.38/penelopeysm): This can be replaced with `init!!(p.model,
88-
# p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason
89-
# why we had to use the 'del' flag before this was because
90-
# SampleFromPrior() wouldn't overwrite existing variables.
91-
# The main problem I'm rather unsure about is ESS-within-Gibbs. The
92-
# current implementation I think makes sure to only resample the variables
93-
# that 'belong' to the current ESS sampler. InitContext on the other hand
94-
# would resample all variables in the model (??) Need to think about this
95-
# carefully.
96-
vns = keys(varinfo)
97-
for vn in vns
98-
set_flag!(varinfo, vn, "del")
99-
end
100-
p.model(rng, varinfo)
101-
return varinfo[:]
85+
# TODO(penelopeysm/DPPL 0.38) The main problem I'm rather unsure about is
86+
# ESS-within-Gibbs. The current implementation I think makes sure to only resample the
87+
# variables that 'belong' to the current ESS sampler. InitContext on the other hand
88+
# would resample all variables in the model (??) Need to think about this carefully.
89+
_, vi = DynamicPPL.init!!(p.model, p.varinfo, DynamicPPL.InitFromPrior())
90+
return vi[:]
10291
end
10392

10493
# Mean of prior distribution

src/mcmc/gibbs.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ A context used in the implementation of the Turing.jl Gibbs sampler.
4747
There will be one `GibbsContext` for each iteration of a component sampler.
4848
4949
`target_varnames` is a a tuple of `VarName`s that the current component sampler
50-
is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume`
50+
is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume!!`
5151
calls to its child context. For other variables, their values will be fixed to
5252
the values they have in `global_varinfo`.
5353
@@ -140,7 +140,7 @@ function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName})
140140
end
141141

142142
# Tilde pipeline
143-
function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
143+
function DynamicPPL.tilde_assume!!(context::GibbsContext, right, vn, vi)
144144
child_context = DynamicPPL.childcontext(context)
145145

146146
# Note that `child_context` may contain `PrefixContext`s -- in which case
@@ -175,7 +175,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
175175

176176
return if is_target_varname(context, vn)
177177
# Fall back to the default behavior.
178-
DynamicPPL.tilde_assume(child_context, right, vn, vi)
178+
DynamicPPL.tilde_assume!!(child_context, right, vn, vi)
179179
elseif has_conditioned_gibbs(context, vn)
180180
# This branch means that a different sampler is supposed to handle this
181181
# variable. From the perspective of this sampler, this variable is
@@ -191,9 +191,10 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
191191
# presumably a new variable that should be sampled from its prior. We need to add
192192
# this new variable to the global `varinfo` of the context, but not to the local one
193193
# being used by the current sampler.
194-
value, new_global_vi = DynamicPPL.tilde_assume(
195-
child_context,
196-
DynamicPPL.SampleFromPrior(),
194+
value, new_global_vi = DynamicPPL.tilde_assume!!(
195+
# child_context might be a PrefixContext so we have to be careful to not
196+
# overwrite it.
197+
DynamicPPL.setleafcontext(child_context, DynamicPPL.InitContext()),
197198
right,
198199
vn,
199200
get_global_varinfo(context),
@@ -204,7 +205,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
204205
end
205206

206207
# As above but with an RNG.
207-
function DynamicPPL.tilde_assume(
208+
function DynamicPPL.tilde_assume!!(
208209
rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi
209210
)
210211
# See comment in the above, rng-less version of this method for an explanation.
@@ -215,7 +216,7 @@ function DynamicPPL.tilde_assume(
215216
# This branch means that that `sampler` is supposed to handle
216217
# this variable. We can thus use its default behaviour, with
217218
# the 'local' sampler-specific VarInfo.
218-
DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi)
219+
DynamicPPL.tilde_assume!!(rng, child_context, sampler, right, vn, vi)
219220
elseif has_conditioned_gibbs(context, vn)
220221
# This branch means that a different sampler is supposed to handle this
221222
# variable. From the perspective of this sampler, this variable is
@@ -231,10 +232,10 @@ function DynamicPPL.tilde_assume(
231232
# presumably a new variable that should be sampled from its prior. We need to add
232233
# this new variable to the global `varinfo` of the context, but not to the local one
233234
# being used by the current sampler.
234-
value, new_global_vi = DynamicPPL.tilde_assume(
235-
rng,
236-
child_context,
237-
DynamicPPL.SampleFromPrior(),
235+
value, new_global_vi = DynamicPPL.tilde_assume!!(
236+
# child_context might be a PrefixContext so we have to be careful to not
237+
# overwrite it.
238+
DynamicPPL.setleafcontext(child_context, DynamicPPL.InitContext(rng)),
238239
right,
239240
vn,
240241
get_global_varinfo(context),

src/mcmc/mh.jl

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,11 @@ function propose!!(
329329
prev_trans = AMH.Transition(vt, prev_state.logjoint_internal, false)
330330

331331
# Make a new transition.
332-
spl_model = DynamicPPL.contextualize(
333-
model, DynamicPPL.SamplingContext(rng, spl, model.context)
334-
)
332+
model = DynamicPPL.setleafcontext(model, MHContext(rng))
335333
densitymodel = AMH.DensityModel(
336334
Base.Fix1(
337335
LogDensityProblems.logdensity,
338-
DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi),
336+
DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi),
339337
),
340338
)
341339
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
@@ -366,13 +364,11 @@ function propose!!(
366364
prev_trans = AMH.Transition(vals, prev_state.logjoint_internal, false)
367365

368366
# Make a new transition.
369-
spl_model = DynamicPPL.contextualize(
370-
model, DynamicPPL.SamplingContext(rng, spl, model.context)
371-
)
367+
model = DynamicPPL.setleafcontext(model, MHContext(rng))
372368
densitymodel = AMH.DensityModel(
373369
Base.Fix1(
374370
LogDensityProblems.logdensity,
375-
DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi),
371+
DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi),
376372
),
377373
)
378374
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
@@ -410,13 +406,20 @@ function AbstractMCMC.step(
410406
return Transition(model, new_state.varinfo, nothing), new_state
411407
end
412408

413-
####
414-
#### Compiler interface, i.e. tilde operators.
415-
####
416-
function DynamicPPL.assume(
417-
rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi
409+
struct MHContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
410+
rng::R
411+
end
412+
DynamicPPL.NodeTrait(::MHContext) = DynamicPPL.IsLeaf()
413+
414+
function DynamicPPL.tilde_assume!!(
415+
context::MHContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
418416
)
419-
# Just defer to `SampleFromPrior`.
420-
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
421-
return retval
417+
# Allow MH to sample new variables from the prior if it's not already present in the
418+
# VarInfo.
419+
dispatch_ctx = if haskey(vi, vn)
420+
DynamicPPL.DefaultContext()
421+
else
422+
DynamicPPL.InitContext(context.rng, DynamicPPL.InitFromPrior())
423+
end
424+
return DynamicPPL.tilde_assume!!(dispatch_ctx, right, vn, vi)
422425
end

src/mcmc/particle_mcmc.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,9 @@ function set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
446446
return nothing
447447
end
448448

449+
# TODO(penelopeysm / DPPL 0.38): Figure this out
450+
struct ParticleMCMCContext <: DynamicPPL.AbstractContext end
451+
449452
function DynamicPPL.assume(
450453
rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, vi::AbstractVarInfo
451454
)

src/mcmc/prior.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@ function AbstractMCMC.step(
1212
state=nothing;
1313
kwargs...,
1414
)
15-
# TODO(DPPL0.38/penelopeysm): replace with init!!
16-
sampling_model = DynamicPPL.contextualize(
17-
model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context)
18-
)
19-
vi = VarInfo()
2015
vi = DynamicPPL.setaccs!!(
2116
vi,
2217
(
@@ -25,6 +20,6 @@ function AbstractMCMC.step(
2520
DynamicPPL.LogLikelihoodAccumulator(),
2621
),
2722
)
28-
_, vi = DynamicPPL.evaluate!!(sampling_model, vi)
23+
_, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromPrior())
2924
return Transition(model, vi, nothing; reevaluate=false), nothing
3025
end

src/optimisation/Optimisation.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -508,10 +508,9 @@ function estimate_mode(
508508
kwargs...,
509509
)
510510
if check_model
511-
spl_model = DynamicPPL.contextualize(
512-
model, DynamicPPL.SamplingContext(model.context)
513-
)
514-
DynamicPPL.check_model(spl_model, DynamicPPL.VarInfo(); error_on_failure=true)
511+
new_context = DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext())
512+
new_model = DynamicPPL.contextualize(model, new_context)
513+
DynamicPPL.check_model(new_model, DynamicPPL.VarInfo(); error_on_failure=true)
515514
end
516515

517516
constraints = ModeEstimationConstraints(lb, ub, cons, lcons, ucons)

0 commit comments

Comments
 (0)