@@ -11,6 +11,12 @@ Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for
11
11
resampling.
12
12
"""
13
13
function set_all_del! (vi:: AbstractVarInfo )
14
+ # TODO (penelopeysm): Instead of being a 'del' flag on the VarInfo, we
15
+ # could either:
16
+ # - keep a boolean 'resample' flag on the trace, or
17
+ # - modify the model context appropriately.
18
+ # However, this refactoring will have to wait until InitContext is
19
+ # merged into DPPL.
14
20
for vn in keys (vi)
15
21
DynamicPPL. set_flag! (vi, vn, " del" )
16
22
end
59
65
function AdvancedPS. advance! (
60
66
trace:: AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}} , isref:: Bool = false
61
67
)
62
- # We want to increment num produce for the VarInfo stored in the trace. The trace is
63
- # mutable, so we create a new model with the incremented VarInfo and set it in the trace
64
- model = trace. model
65
- model = Accessors. @set model. f. varinfo = DynamicPPL. increment_num_produce!! (
66
- model. f. varinfo
67
- )
68
- trace. model = model
69
68
# Make sure we load/reset the rng in the new replaying mechanism
70
69
isref ? AdvancedPS. load_state! (trace. rng) : AdvancedPS. save_state! (trace. rng)
71
70
score = consume (trace. model. ctask)
72
71
return score
73
72
end
74
73
75
74
function AdvancedPS. delete_retained! (trace:: TracedModel )
76
- # TODO (DPPL0.37/penelopeysm): Explain this a bit better.
77
- #
78
75
# This method is called if, during a CSMC update, we perform a resampling
79
76
# and choose the reference particle as the trajectory to carry on from.
80
77
# In such a case, we need to ensure that when we continue sampling (i.e.
81
78
# the next time we hit tilde_assume), we don't use the values in the
82
79
# reference particle but rather sample new values.
83
- # In this implementation, we indiscriminately set the 'del' flag for all
84
- # variables in the VarInfo. This is slightly overkill: it is not necessary
85
- # to set the 'del' flag for variables that were already sampled. However,
86
- # it allows us to avoid using DynamicPPL.set_retained_vns_del!.
80
+ #
81
+ # Here, we indiscriminately set the 'del' flag for all variables in the
82
+ # VarInfo. This is slightly overkill: it is not necessary to set the 'del'
83
+ # flag for variables that were already sampled. However, it allows us to
84
+ # avoid keeping track of which variables were sampled, which leads to many
85
+ # simplifications in the VarInfo data structure.
87
86
set_all_del! (trace. varinfo)
88
87
return trace
89
88
end
90
89
91
90
function AdvancedPS. reset_model (trace:: TracedModel )
92
- return Accessors . @set trace. varinfo = DynamicPPL . reset_num_produce!! (trace . varinfo)
91
+ return trace
93
92
end
94
93
95
94
function Libtask. TapedTask (taped_globals, model:: TracedModel ; kwargs... )
@@ -213,7 +212,6 @@ function DynamicPPL.initialstep(
213
212
)
214
213
# Reset the VarInfo.
215
214
vi = DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
216
- vi = DynamicPPL. reset_num_produce!! (vi)
217
215
set_all_del! (vi)
218
216
vi = DynamicPPL. resetlogp!! (vi)
219
217
vi = DynamicPPL. empty!! (vi)
@@ -344,7 +342,6 @@ function DynamicPPL.initialstep(
344
342
)
345
343
vi = DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
346
344
# Reset the VarInfo before new sweep
347
- vi = DynamicPPL. reset_num_produce!! (vi)
348
345
set_all_del! (vi)
349
346
vi = DynamicPPL. resetlogp!! (vi)
350
347
@@ -366,11 +363,6 @@ function DynamicPPL.initialstep(
366
363
367
364
# Compute the first transition.
368
365
_vi = reference. model. f. varinfo
369
- # Unset any 'del' flags before we actually construct the transition.
370
- # This is necessary because the model will be re-evaluated and we
371
- # want to make sure we do use the values in the reference particle
372
- # instead of resampling them.
373
- unset_all_del! (_vi)
374
366
transition = PGTransition (model, _vi, logevidence)
375
367
376
368
return transition, PGState (_vi, reference. rng)
@@ -382,10 +374,10 @@ function AbstractMCMC.step(
382
374
# Reset the VarInfo before new sweep.
383
375
vi = state. vi
384
376
vi = DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
385
- vi = DynamicPPL. reset_num_produce!! (vi)
386
377
vi = DynamicPPL. resetlogp!! (vi)
387
378
388
379
# Create reference particle for which the samples will be retained.
380
+ unset_all_del! (vi)
389
381
reference = AdvancedPS. forkr (AdvancedPS. Trace (model, spl, vi, state. rng))
390
382
391
383
# For all other particles, do not retain the variables but resample them.
@@ -412,11 +404,6 @@ function AbstractMCMC.step(
412
404
413
405
# Compute the transition.
414
406
_vi = newreference. model. f. varinfo
415
- # Unset any 'del' flags before we actually construct the transition.
416
- # This is necessary because the model will be re-evaluated and we
417
- # want to make sure we do use the values in the reference particle
418
- # instead of resampling them.
419
- unset_all_del! (_vi)
420
407
transition = PGTransition (model, _vi, logevidence)
421
408
422
409
return transition, PGState (_vi, newreference. rng)
@@ -499,12 +486,11 @@ function DynamicPPL.assume(
499
486
vi = push!! (vi, vn, r, dist)
500
487
elseif DynamicPPL. is_flagged (vi, vn, " del" )
501
488
DynamicPPL. unset_flag! (vi, vn, " del" ) # Reference particle parent
502
- r = rand (trng, dist)
503
- vi[vn] = DynamicPPL. tovec (r)
504
489
# TODO (mhauru):
505
490
# The below is the only line that differs from assume called on SampleFromPrior.
506
- # Could we just call assume on SampleFromPrior and then `setorder!!` after that?
507
- vi = DynamicPPL. setorder!! (vi, vn, DynamicPPL. get_num_produce (vi))
491
+ # Could we just call assume on SampleFromPrior with a specific rng?
492
+ r = rand (trng, dist)
493
+ vi[vn] = DynamicPPL. tovec (r)
508
494
else
509
495
r = vi[vn]
510
496
end
@@ -546,8 +532,6 @@ function AdvancedPS.Trace(
546
532
rng:: AdvancedPS.TracedRNG ,
547
533
)
548
534
newvarinfo = deepcopy (varinfo)
549
- newvarinfo = DynamicPPL. reset_num_produce!! (newvarinfo)
550
-
551
535
tmodel = TracedModel (model, sampler, newvarinfo, rng)
552
536
newtrace = AdvancedPS. Trace (tmodel, rng)
553
537
return newtrace
0 commit comments