Skip to content

Commit 214597b

Browse files
committed
Use nice functions
1 parent 967ebfb commit 214597b

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,32 @@
44

55
### AdvancedPS models and interface
66

7+
"""
8+
set_all_del!(vi::AbstractVarInfo)
9+
10+
Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for
11+
resampling.
12+
"""
13+
function set_all_del!(vi::AbstractVarInfo)
14+
for vn in keys(vi)
15+
DynamicPPL.set_flag!(vi, vn, "del")
16+
end
17+
return nothing
18+
end
19+
20+
"""
21+
unset_all_del!(vi::AbstractVarInfo)
22+
23+
Unset the "del" flag for all variables in the VarInfo `vi`, thus preventing
24+
them from being resampled.
25+
"""
26+
function unset_all_del!(vi::AbstractVarInfo)
27+
for vn in keys(vi)
28+
DynamicPPL.unset_flag!(vi, vn, "del")
29+
end
30+
return nothing
31+
end
32+
733
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <:
834
AdvancedPS.AbstractGenericModel
935
model::M
@@ -58,9 +84,7 @@ function AdvancedPS.delete_retained!(trace::TracedModel)
5884
# variables in the VarInfo. This is slightly overkill: it is not necessary
5985
# to set the 'del' flag for variables that were already sampled. However,
6086
# it allows us to avoid using DynamicPPL.set_retained_vns_del!.
61-
for vn in keys(trace.varinfo)
62-
DynamicPPL.set_flag!(trace.varinfo, vn, "del")
63-
end
87+
set_all_del!(trace.varinfo)
6488
return trace
6589
end
6690

@@ -190,9 +214,7 @@ function DynamicPPL.initialstep(
190214
# Reset the VarInfo.
191215
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
192216
vi = DynamicPPL.reset_num_produce!!(vi)
193-
for vn in keys(vi)
194-
DynamicPPL.set_flag!(vi, vn, "del")
195-
end
217+
set_all_del!(vi)
196218
vi = DynamicPPL.resetlogp!!(vi)
197219
vi = DynamicPPL.empty!!(vi)
198220

@@ -323,9 +345,7 @@ function DynamicPPL.initialstep(
323345
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
324346
# Reset the VarInfo before new sweep
325347
vi = DynamicPPL.reset_num_produce!!(vi)
326-
for vn in keys(vi)
327-
DynamicPPL.set_flag!(vi, vn, "del")
328-
end
348+
set_all_del!(vi)
329349
vi = DynamicPPL.resetlogp!!(vi)
330350

331351
# Create a new set of particles
@@ -350,9 +370,7 @@ function DynamicPPL.initialstep(
350370
# This is necessary because the model will be re-evaluated and we
351371
# want to make sure we do use the values in the reference particle
352372
# instead of resampling them.
353-
for vn in keys(_vi)
354-
DynamicPPL.unset_flag!(_vi, vn, "del")
355-
end
373+
unset_all_del!(_vi)
356374
transition = PGTransition(model, _vi, logevidence)
357375

358376
return transition, PGState(_vi, reference.rng)
@@ -371,9 +389,7 @@ function AbstractMCMC.step(
371389
reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng))
372390

373391
# For all other particles, do not retain the variables but resample them.
374-
for vn in keys(vi)
375-
DynamicPPL.set_flag!(vi, vn, "del")
376-
end
392+
set_all_del!(vi)
377393

378394
# Create a new set of particles.
379395
num_particles = spl.alg.nparticles
@@ -400,9 +416,7 @@ function AbstractMCMC.step(
400416
# This is necessary because the model will be re-evaluated and we
401417
# want to make sure we do use the values in the reference particle
402418
# instead of resampling them.
403-
for vn in keys(_vi)
404-
DynamicPPL.unset_flag!(_vi, vn, "del")
405-
end
419+
unset_all_del!(_vi)
406420
transition = PGTransition(model, _vi, logevidence)
407421

408422
return transition, PGState(_vi, newreference.rng)

0 commit comments

Comments
 (0)