Skip to content

Commit 7fb5fa0

Browse files
committed
Remove all uses of set_retained_vns_del!
1 parent c59ad26 commit 7fb5fa0

File tree

1 file changed

+34
-3
lines changed

1 file changed

+34
-3
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,20 @@ function AdvancedPS.advance!(
4747
end
4848

4949
function AdvancedPS.delete_retained!(trace::TracedModel)
50-
DynamicPPL.set_retained_vns_del!(trace.varinfo)
50+
# TODO(DPPL0.37/penelopeysm): Explain this a bit better.
51+
#
52+
# This method is called if, during a CSMC update, we perform a resampling
53+
# and choose the reference particle as the trajectory to carry on from.
54+
# In such a case, we need to ensure that when we continue sampling (i.e.
55+
# the next time we hit tilde_assume), we don't use the values in the
56+
# reference particle but rather sample new values.
57+
# In this implementation, we indiscriminately set the 'del' flag for all
58+
# variables in the VarInfo. This is slightly overkill: it is not necessary
59+
# to set the 'del' flag for variables that were already sampled. However,
60+
# it allows us to avoid using DynamicPPL.set_retained_vns_del!.
61+
for vn in keys(trace.varinfo)
62+
DynamicPPL.set_flag!(trace.varinfo, vn, "del")
63+
end
5164
return trace
5265
end
5366

@@ -177,7 +190,9 @@ function DynamicPPL.initialstep(
177190
# Reset the VarInfo.
178191
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
179192
vi = DynamicPPL.reset_num_produce!!(vi)
180-
DynamicPPL.set_retained_vns_del!(vi)
193+
for vn in keys(vi)
194+
DynamicPPL.set_flag!(vi, vn, "del")
195+
end
181196
vi = DynamicPPL.resetlogp!!(vi)
182197
vi = DynamicPPL.empty!!(vi)
183198

@@ -308,7 +323,9 @@ function DynamicPPL.initialstep(
308323
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
309324
# Reset the VarInfo before new sweep
310325
vi = DynamicPPL.reset_num_produce!!(vi)
311-
DynamicPPL.set_retained_vns_del!(vi)
326+
for vn in keys(vi)
327+
DynamicPPL.set_flag!(vi, vn, "del")
328+
end
312329
vi = DynamicPPL.resetlogp!!(vi)
313330

314331
# Create a new set of particles
@@ -329,6 +346,13 @@ function DynamicPPL.initialstep(
329346

330347
# Compute the first transition.
331348
_vi = reference.model.f.varinfo
349+
# Unset any 'del' flags before we actually construct the transition.
350+
# This is necessary because the model will be re-evaluated and we
351+
# want to make sure we do use the values in the reference particle
352+
# instead of resampling them.
353+
for vn in keys(_vi)
354+
DynamicPPL.unset_flag!(_vi, vn, "del")
355+
end
332356
transition = PGTransition(model, _vi, logevidence)
333357

334358
return transition, PGState(_vi, reference.rng)
@@ -372,6 +396,13 @@ function AbstractMCMC.step(
372396

373397
# Compute the transition.
374398
_vi = newreference.model.f.varinfo
399+
# Unset any 'del' flags before we actually construct the transition.
400+
# This is necessary because the model will be re-evaluated and we
401+
# want to make sure we do use the values in the reference particle
402+
# instead of resampling them.
403+
for vn in keys(_vi)
404+
DynamicPPL.unset_flag!(_vi, vn, "del")
405+
end
375406
transition = PGTransition(model, _vi, logevidence)
376407

377408
return transition, PGState(_vi, newreference.rng)

0 commit comments

Comments
 (0)