|
4 | 4 |
|
5 | 5 | ### AdvancedPS models and interface
|
6 | 6 |
|
| 7 | +""" |
| 8 | + set_all_del!(vi::AbstractVarInfo) |
| 9 | +
|
| 10 | +Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for |
| 11 | +resampling. |
| 12 | +""" |
| 13 | +function set_all_del!(vi::AbstractVarInfo) |
| 14 | + # TODO(penelopeysm): Instead of being a 'del' flag on the VarInfo, we |
| 15 | + # could either: |
| 16 | + # - keep a boolean 'resample' flag on the trace, or |
| 17 | + # - modify the model context appropriately. |
| 18 | + # However, this refactoring will have to wait until InitContext is |
| 19 | + # merged into DPPL. |
| 20 | + for vn in keys(vi) |
| 21 | + DynamicPPL.set_flag!(vi, vn, "del") |
| 22 | + end |
| 23 | + return nothing |
| 24 | +end |
| 25 | + |
| 26 | +""" |
| 27 | + unset_all_del!(vi::AbstractVarInfo) |
| 28 | +
|
| 29 | +Unset the "del" flag for all variables in the VarInfo `vi`, thus preventing |
| 30 | +them from being resampled. |
| 31 | +""" |
| 32 | +function unset_all_del!(vi::AbstractVarInfo) |
| 33 | + for vn in keys(vi) |
| 34 | + DynamicPPL.unset_flag!(vi, vn, "del") |
| 35 | + end |
| 36 | + return nothing |
| 37 | +end |
| 38 | + |
7 | 39 | struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <:
|
8 | 40 | AdvancedPS.AbstractGenericModel
|
9 | 41 | model::M
|
|
33 | 65 | function AdvancedPS.advance!(
|
34 | 66 | trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false
|
35 | 67 | )
|
36 |
| - # We want to increment num produce for the VarInfo stored in the trace. The trace is |
37 |
| - # mutable, so we create a new model with the incremented VarInfo and set it in the trace |
38 |
| - model = trace.model |
39 |
| - model = Accessors.@set model.f.varinfo = DynamicPPL.increment_num_produce!!( |
40 |
| - model.f.varinfo |
41 |
| - ) |
42 |
| - trace.model = model |
43 | 68 | # Make sure we load/reset the rng in the new replaying mechanism
|
44 | 69 | isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
|
45 | 70 | score = consume(trace.model.ctask)
|
46 | 71 | return score
|
47 | 72 | end
|
48 | 73 |
|
49 | 74 | function AdvancedPS.delete_retained!(trace::TracedModel)
|
50 |
| - DynamicPPL.set_retained_vns_del!(trace.varinfo) |
| 75 | + # This method is called if, during a CSMC update, we perform a resampling |
| 76 | + # and choose the reference particle as the trajectory to carry on from. |
| 77 | + # In such a case, we need to ensure that when we continue sampling (i.e. |
| 78 | + # the next time we hit tilde_assume), we don't use the values in the |
| 79 | + # reference particle but rather sample new values. |
| 80 | + # |
| 81 | + # Here, we indiscriminately set the 'del' flag for all variables in the |
| 82 | + # VarInfo. This is slightly overkill: it is not necessary to set the 'del' |
| 83 | + # flag for variables that were already sampled. However, it allows us to |
| 84 | + # avoid keeping track of which variables were sampled, which leads to many |
| 85 | + # simplifications in the VarInfo data structure. |
| 86 | + set_all_del!(trace.varinfo) |
51 | 87 | return trace
|
52 | 88 | end
|
53 | 89 |
|
54 | 90 | function AdvancedPS.reset_model(trace::TracedModel)
|
55 |
| - return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo) |
| 91 | + return trace |
56 | 92 | end
|
57 | 93 |
|
58 | 94 | function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
|
@@ -176,8 +212,7 @@ function DynamicPPL.initialstep(
|
176 | 212 | )
|
177 | 213 | # Reset the VarInfo.
|
178 | 214 | vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
|
179 |
| - vi = DynamicPPL.reset_num_produce!!(vi) |
180 |
| - DynamicPPL.set_retained_vns_del!(vi) |
| 215 | + set_all_del!(vi) |
181 | 216 | vi = DynamicPPL.resetlogp!!(vi)
|
182 | 217 | vi = DynamicPPL.empty!!(vi)
|
183 | 218 |
|
@@ -307,8 +342,7 @@ function DynamicPPL.initialstep(
|
307 | 342 | )
|
308 | 343 | vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
|
309 | 344 | # Reset the VarInfo before new sweep
|
310 |
| - vi = DynamicPPL.reset_num_produce!!(vi) |
311 |
| - DynamicPPL.set_retained_vns_del!(vi) |
| 345 | + set_all_del!(vi) |
312 | 346 | vi = DynamicPPL.resetlogp!!(vi)
|
313 | 347 |
|
314 | 348 | # Create a new set of particles
|
@@ -339,14 +373,15 @@ function AbstractMCMC.step(
|
339 | 373 | )
|
340 | 374 | # Reset the VarInfo before new sweep.
|
341 | 375 | vi = state.vi
|
342 |
| - vi = DynamicPPL.reset_num_produce!!(vi) |
| 376 | + vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) |
343 | 377 | vi = DynamicPPL.resetlogp!!(vi)
|
344 | 378 |
|
345 | 379 | # Create reference particle for which the samples will be retained.
|
| 380 | + unset_all_del!(vi) |
346 | 381 | reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng))
|
347 | 382 |
|
348 | 383 | # For all other particles, do not retain the variables but resample them.
|
349 |
| - DynamicPPL.set_retained_vns_del!(vi) |
| 384 | + set_all_del!(vi) |
350 | 385 |
|
351 | 386 | # Create a new set of particles.
|
352 | 387 | num_particles = spl.alg.nparticles
|
@@ -451,12 +486,11 @@ function DynamicPPL.assume(
|
451 | 486 | vi = push!!(vi, vn, r, dist)
|
452 | 487 | elseif DynamicPPL.is_flagged(vi, vn, "del")
|
453 | 488 | DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent
|
454 |
| - r = rand(trng, dist) |
455 |
| - vi[vn] = DynamicPPL.tovec(r) |
456 | 489 | # TODO(mhauru):
|
457 | 490 | # The below is the only line that differs from assume called on SampleFromPrior.
|
458 |
| - # Could we just call assume on SampleFromPrior and then `setorder!!` after that? |
459 |
| - vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) |
| 491 | + # Could we just call assume on SampleFromPrior with a specific rng? |
| 492 | + r = rand(trng, dist) |
| 493 | + vi[vn] = DynamicPPL.tovec(r) |
460 | 494 | else
|
461 | 495 | r = vi[vn]
|
462 | 496 | end
|
@@ -498,8 +532,6 @@ function AdvancedPS.Trace(
|
498 | 532 | rng::AdvancedPS.TracedRNG,
|
499 | 533 | )
|
500 | 534 | newvarinfo = deepcopy(varinfo)
|
501 |
| - newvarinfo = DynamicPPL.reset_num_produce!!(newvarinfo) |
502 |
| - |
503 | 535 | tmodel = TracedModel(model, sampler, newvarinfo, rng)
|
504 | 536 | newtrace = AdvancedPS.Trace(tmodel, rng)
|
505 | 537 | return newtrace
|
|
0 commit comments