Skip to content

Commit e81585b

Browse files
committed
Update pMCMC implementation as per discussion
1 parent 39a7726 commit e81585b

File tree

1 file changed

+17
-33
lines changed

1 file changed

+17
-33
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for
1111
resampling.
1212
"""
1313
function set_all_del!(vi::AbstractVarInfo)
14+
# TODO(penelopeysm): Instead of being a 'del' flag on the VarInfo, we
15+
# could either:
16+
# - keep a boolean 'resample' flag on the trace, or
17+
# - modify the model context appropriately.
18+
# However, this refactoring will have to wait until InitContext is
19+
# merged into DPPL.
1420
for vn in keys(vi)
1521
DynamicPPL.set_flag!(vi, vn, "del")
1622
end
@@ -59,37 +65,30 @@ end
5965
function AdvancedPS.advance!(
6066
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false
6167
)
62-
# We want to increment num produce for the VarInfo stored in the trace. The trace is
63-
# mutable, so we create a new model with the incremented VarInfo and set it in the trace
64-
model = trace.model
65-
model = Accessors.@set model.f.varinfo = DynamicPPL.increment_num_produce!!(
66-
model.f.varinfo
67-
)
68-
trace.model = model
6968
# Make sure we load/reset the rng in the new replaying mechanism
7069
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
7170
score = consume(trace.model.ctask)
7271
return score
7372
end
7473

7574
function AdvancedPS.delete_retained!(trace::TracedModel)
76-
# TODO(DPPL0.37/penelopeysm): Explain this a bit better.
77-
#
7875
# This method is called if, during a CSMC update, we perform a resampling
7976
# and choose the reference particle as the trajectory to carry on from.
8077
# In such a case, we need to ensure that when we continue sampling (i.e.
8178
# the next time we hit tilde_assume), we don't use the values in the
8279
# reference particle but rather sample new values.
83-
# In this implementation, we indiscriminately set the 'del' flag for all
84-
# variables in the VarInfo. This is slightly overkill: it is not necessary
85-
# to set the 'del' flag for variables that were already sampled. However,
86-
# it allows us to avoid using DynamicPPL.set_retained_vns_del!.
80+
#
81+
# Here, we indiscriminately set the 'del' flag for all variables in the
82+
# VarInfo. This is slightly overkill: it is not necessary to set the 'del'
83+
# flag for variables that were already sampled. However, it allows us to
84+
# avoid keeping track of which variables were sampled, which leads to many
85+
# simplifications in the VarInfo data structure.
8786
set_all_del!(trace.varinfo)
8887
return trace
8988
end
9089

9190
function AdvancedPS.reset_model(trace::TracedModel)
92-
return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo)
91+
return trace
9392
end
9493

9594
function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
@@ -213,7 +212,6 @@ function DynamicPPL.initialstep(
213212
)
214213
# Reset the VarInfo.
215214
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
216-
vi = DynamicPPL.reset_num_produce!!(vi)
217215
set_all_del!(vi)
218216
vi = DynamicPPL.resetlogp!!(vi)
219217
vi = DynamicPPL.empty!!(vi)
@@ -344,7 +342,6 @@ function DynamicPPL.initialstep(
344342
)
345343
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
346344
# Reset the VarInfo before new sweep
347-
vi = DynamicPPL.reset_num_produce!!(vi)
348345
set_all_del!(vi)
349346
vi = DynamicPPL.resetlogp!!(vi)
350347

@@ -366,11 +363,6 @@ function DynamicPPL.initialstep(
366363

367364
# Compute the first transition.
368365
_vi = reference.model.f.varinfo
369-
# Unset any 'del' flags before we actually construct the transition.
370-
# This is necessary because the model will be re-evaluated and we
371-
# want to make sure we do use the values in the reference particle
372-
# instead of resampling them.
373-
unset_all_del!(_vi)
374366
transition = PGTransition(model, _vi, logevidence)
375367

376368
return transition, PGState(_vi, reference.rng)
@@ -382,10 +374,10 @@ function AbstractMCMC.step(
382374
# Reset the VarInfo before new sweep.
383375
vi = state.vi
384376
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
385-
vi = DynamicPPL.reset_num_produce!!(vi)
386377
vi = DynamicPPL.resetlogp!!(vi)
387378

388379
# Create reference particle for which the samples will be retained.
380+
unset_all_del!(vi)
389381
reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng))
390382

391383
# For all other particles, do not retain the variables but resample them.
@@ -412,11 +404,6 @@ function AbstractMCMC.step(
412404

413405
# Compute the transition.
414406
_vi = newreference.model.f.varinfo
415-
# Unset any 'del' flags before we actually construct the transition.
416-
# This is necessary because the model will be re-evaluated and we
417-
# want to make sure we do use the values in the reference particle
418-
# instead of resampling them.
419-
unset_all_del!(_vi)
420407
transition = PGTransition(model, _vi, logevidence)
421408

422409
return transition, PGState(_vi, newreference.rng)
@@ -499,12 +486,11 @@ function DynamicPPL.assume(
499486
vi = push!!(vi, vn, r, dist)
500487
elseif DynamicPPL.is_flagged(vi, vn, "del")
501488
DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent
502-
r = rand(trng, dist)
503-
vi[vn] = DynamicPPL.tovec(r)
504489
# TODO(mhauru):
505490
# The below is the only line that differs from assume called on SampleFromPrior.
506-
# Could we just call assume on SampleFromPrior and then `setorder!!` after that?
507-
vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi))
491+
# Could we just call assume on SampleFromPrior with a specific rng?
492+
r = rand(trng, dist)
493+
vi[vn] = DynamicPPL.tovec(r)
508494
else
509495
r = vi[vn]
510496
end
@@ -546,8 +532,6 @@ function AdvancedPS.Trace(
546532
rng::AdvancedPS.TracedRNG,
547533
)
548534
newvarinfo = deepcopy(varinfo)
549-
newvarinfo = DynamicPPL.reset_num_produce!!(newvarinfo)
550-
551535
tmodel = TracedModel(model, sampler, newvarinfo, rng)
552536
newtrace = AdvancedPS.Trace(tmodel, rng)
553537
return newtrace

0 commit comments

Comments
 (0)