Skip to content

Commit 7124864

Browse files
penelopeysm and mhauru authored
"Fixes" for PG-in-Gibbs (#2629)
* WIP PMCMC work * Fixes to ProduceLogLikelihoodAccumulator * inline definition of `set_retained_vns_del!` * Fix ProduceLogLikelihoodAcc * Remove all uses of `set_retained_vns_del!` * Use nice functions * Remove PG tests with dynamic number of Gibbs-conditioned-observations * Fix essential/container tests * Update pMCMC implementation as per discussion * remove extra printing statements * revert unneeded changes * Add back (some kind of) dynamic model test * fix rebase * Add a todo comment for dynamic model tests --------- Co-authored-by: Markus Hauru <[email protected]>
1 parent c062867 commit 7124864

File tree

3 files changed

+144
-52
lines changed

3 files changed

+144
-52
lines changed

src/mcmc/particle_mcmc.jl

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,38 @@
44

55
### AdvancedPS models and interface
66

7+
"""
8+
set_all_del!(vi::AbstractVarInfo)
9+
10+
Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for
11+
resampling.
12+
"""
13+
function set_all_del!(vi::AbstractVarInfo)
14+
# TODO(penelopeysm): Instead of being a 'del' flag on the VarInfo, we
15+
# could either:
16+
# - keep a boolean 'resample' flag on the trace, or
17+
# - modify the model context appropriately.
18+
# However, this refactoring will have to wait until InitContext is
19+
# merged into DPPL.
20+
for vn in keys(vi)
21+
DynamicPPL.set_flag!(vi, vn, "del")
22+
end
23+
return nothing
24+
end
25+
26+
"""
27+
unset_all_del!(vi::AbstractVarInfo)
28+
29+
Unset the "del" flag for all variables in the VarInfo `vi`, thus preventing
30+
them from being resampled.
31+
"""
32+
function unset_all_del!(vi::AbstractVarInfo)
33+
for vn in keys(vi)
34+
DynamicPPL.unset_flag!(vi, vn, "del")
35+
end
36+
return nothing
37+
end
38+
739
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <:
840
AdvancedPS.AbstractGenericModel
941
model::M
@@ -33,26 +65,30 @@ end
3365
function AdvancedPS.advance!(
3466
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false
3567
)
36-
# We want to increment num produce for the VarInfo stored in the trace. The trace is
37-
# mutable, so we create a new model with the incremented VarInfo and set it in the trace
38-
model = trace.model
39-
model = Accessors.@set model.f.varinfo = DynamicPPL.increment_num_produce!!(
40-
model.f.varinfo
41-
)
42-
trace.model = model
4368
# Make sure we load/reset the rng in the new replaying mechanism
4469
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
4570
score = consume(trace.model.ctask)
4671
return score
4772
end
4873

4974
function AdvancedPS.delete_retained!(trace::TracedModel)
50-
DynamicPPL.set_retained_vns_del!(trace.varinfo)
75+
# This method is called if, during a CSMC update, we perform a resampling
76+
# and choose the reference particle as the trajectory to carry on from.
77+
# In such a case, we need to ensure that when we continue sampling (i.e.
78+
# the next time we hit tilde_assume), we don't use the values in the
79+
# reference particle but rather sample new values.
80+
#
81+
# Here, we indiscriminately set the 'del' flag for all variables in the
82+
# VarInfo. This is slightly overkill: it is not necessary to set the 'del'
83+
# flag for variables that were already sampled. However, it allows us to
84+
# avoid keeping track of which variables were sampled, which leads to many
85+
# simplifications in the VarInfo data structure.
86+
set_all_del!(trace.varinfo)
5187
return trace
5288
end
5389

5490
function AdvancedPS.reset_model(trace::TracedModel)
55-
return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo)
91+
return trace
5692
end
5793

5894
function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
@@ -176,8 +212,7 @@ function DynamicPPL.initialstep(
176212
)
177213
# Reset the VarInfo.
178214
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
179-
vi = DynamicPPL.reset_num_produce!!(vi)
180-
DynamicPPL.set_retained_vns_del!(vi)
215+
set_all_del!(vi)
181216
vi = DynamicPPL.resetlogp!!(vi)
182217
vi = DynamicPPL.empty!!(vi)
183218

@@ -307,8 +342,7 @@ function DynamicPPL.initialstep(
307342
)
308343
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
309344
# Reset the VarInfo before new sweep
310-
vi = DynamicPPL.reset_num_produce!!(vi)
311-
DynamicPPL.set_retained_vns_del!(vi)
345+
set_all_del!(vi)
312346
vi = DynamicPPL.resetlogp!!(vi)
313347

314348
# Create a new set of particles
@@ -339,14 +373,15 @@ function AbstractMCMC.step(
339373
)
340374
# Reset the VarInfo before new sweep.
341375
vi = state.vi
342-
vi = DynamicPPL.reset_num_produce!!(vi)
376+
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
343377
vi = DynamicPPL.resetlogp!!(vi)
344378

345379
# Create reference particle for which the samples will be retained.
380+
unset_all_del!(vi)
346381
reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng))
347382

348383
# For all other particles, do not retain the variables but resample them.
349-
DynamicPPL.set_retained_vns_del!(vi)
384+
set_all_del!(vi)
350385

351386
# Create a new set of particles.
352387
num_particles = spl.alg.nparticles
@@ -451,12 +486,11 @@ function DynamicPPL.assume(
451486
vi = push!!(vi, vn, r, dist)
452487
elseif DynamicPPL.is_flagged(vi, vn, "del")
453488
DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent
454-
r = rand(trng, dist)
455-
vi[vn] = DynamicPPL.tovec(r)
456489
# TODO(mhauru):
457490
# The below is the only line that differs from assume called on SampleFromPrior.
458-
# Could we just call assume on SampleFromPrior and then `setorder!!` after that?
459-
vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi))
491+
# Could we just call assume on SampleFromPrior with a specific rng?
492+
r = rand(trng, dist)
493+
vi[vn] = DynamicPPL.tovec(r)
460494
else
461495
r = vi[vn]
462496
end
@@ -498,8 +532,6 @@ function AdvancedPS.Trace(
498532
rng::AdvancedPS.TracedRNG,
499533
)
500534
newvarinfo = deepcopy(varinfo)
501-
newvarinfo = DynamicPPL.reset_num_produce!!(newvarinfo)
502-
503535
tmodel = TracedModel(model, sampler, newvarinfo, rng)
504536
newtrace = AdvancedPS.Trace(tmodel, rng)
505537
return newtrace

test/essential/container.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ using Turing
1919

2020
@testset "constructor" begin
2121
vi = DynamicPPL.VarInfo()
22+
vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator())
2223
sampler = Sampler(PG(10))
2324
model = test()
2425
trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG())
@@ -46,6 +47,7 @@ using Turing
4647
return a, b
4748
end
4849
vi = DynamicPPL.VarInfo()
50+
vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator())
4951
sampler = Sampler(PG(10))
5052
model = normal()
5153

test/mcmc/gibbs.jl

Lines changed: 89 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ end
207207
val ~ Normal(s, 1)
208208
1.0 ~ Normal(s + m, 1)
209209

210-
n := m + 1
211-
xs = M(undef, n)
210+
n := m
211+
xs = M(undef, 5)
212212
for i in eachindex(xs)
213213
xs[i] ~ Beta(0.5, 0.5)
214214
end
@@ -565,40 +565,98 @@ end
565565
end
566566
end
567567

568-
# The below test used to sample incorrectly before
569-
# https://github.com/TuringLang/Turing.jl/pull/2328
570-
@testset "dynamic model with ESS" begin
571-
@model function dynamic_model_for_ess()
572-
b ~ Bernoulli()
573-
x_length = b ? 1 : 2
574-
x = Vector{Float64}(undef, x_length)
575-
for i in 1:x_length
576-
x[i] ~ Normal(i, 1.0)
568+
@testset "PG with variable number of observations" begin
569+
# When sampling from a model with Particle Gibbs, it is mandatory for
570+
# the number of observations to be the same in all particles, since the
571+
# observations trigger particle resampling.
572+
#
573+
# Up until Turing v0.39, `x ~ dist` statements where `x` was the
574+
# responsibility of a different (non-PG) Gibbs subsampler used to not
575+
# count as an observation. Instead, the log-probability `logpdf(dist, x)`
576+
# would be manually added to the VarInfo's `logp` field and included in the
577+
# weighting for the _following_ observation.
578+
#
579+
# In Turing v0.40, this is now changed: `x ~ dist` uses tilde_observe!!
580+
# which thus triggers resampling. Thus, for example, the following model
581+
# does not work any more:
582+
#
583+
# @model function f()
584+
# a ~ Poisson(2.0)
585+
# x = Vector{Float64}(undef, a)
586+
# for i in eachindex(x)
587+
# x[i] ~ Normal()
588+
# end
589+
# end
590+
# sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000)
591+
#
592+
# because the number of observations in each particle depends on the value
593+
# of `a`.
594+
#
595+
# This testset checks ways of working around such a situation.
596+
597+
function test_dynamic_bernoulli(chain)
598+
means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0)
599+
stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0)
600+
for vn in keys(means)
601+
@test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1)
602+
@test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1)
577603
end
578604
end
579605

580-
m = dynamic_model_for_ess()
581-
chain = sample(m, Gibbs(:b => PG(10), :x => ESS()), 2000; discard_initial=100)
582-
means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0)
583-
stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0)
584-
for vn in keys(means)
585-
@test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1)
586-
@test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1)
606+
# TODO(DPPL0.37/penelopeysm): decide what to do with these tests
607+
@testset "Coalescing multiple observations into one" begin
608+
# Instead of observing x[1] and x[2] separately, we lump them into a
609+
# single distribution.
610+
@model function dynamic_bernoulli()
611+
b ~ Bernoulli()
612+
if b
613+
dists = [Normal(1.0)]
614+
else
615+
dists = [Normal(1.0), Normal(2.0)]
616+
end
617+
return x ~ product_distribution(dists)
618+
end
619+
model = dynamic_bernoulli()
620+
# This currently fails because if the global varinfo has `x` with length 2,
621+
# and the particle sampler has `b = true`, it attempts to calculate the
622+
# log-likelihood of a length-2 vector with respect to a length-1
623+
# distribution.
624+
@test_throws DimensionMismatch chain = sample(
625+
StableRNG(468),
626+
model,
627+
Gibbs(:b => PG(10), :x => ESS()),
628+
2000;
629+
discard_initial=100,
630+
)
631+
# test_dynamic_bernoulli(chain)
587632
end
588-
end
589633

590-
@testset "dynamic model with dot tilde" begin
591-
@model function dynamic_model_with_dot_tilde(
592-
num_zs=10, (::Type{M})=Vector{Float64}
593-
) where {M}
594-
z = Vector{Int}(undef, num_zs)
595-
z .~ Poisson(1.0)
596-
num_ms = sum(z)
597-
m = M(undef, num_ms)
598-
return m .~ Normal(1.0, 1.0)
599-
end
600-
model = dynamic_model_with_dot_tilde()
601-
sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4)), 100)
634+
@testset "Inserting @addlogprob!" begin
635+
# On top of observing x[i], we also add in extra 'observations'
636+
@model function dynamic_bernoulli_2()
637+
b ~ Bernoulli()
638+
x_length = b ? 1 : 2
639+
x = Vector{Float64}(undef, x_length)
640+
for i in 1:x_length
641+
x[i] ~ Normal(i, 1.0)
642+
end
643+
if length(x) == 1
644+
# This value is the expectation value of logpdf(Normal(), x) where x ~ Normal().
645+
# See discussion in
646+
# https://github.com/TuringLang/Turing.jl/pull/2629#discussion_r2237323817
647+
@addlogprob!(-1.418849)
648+
end
649+
end
650+
model = dynamic_bernoulli_2()
651+
chain = sample(
652+
StableRNG(468),
653+
model,
654+
Gibbs(:b => PG(10), :x => ESS()),
655+
2000;
656+
discard_initial=100,
657+
)
658+
test_dynamic_bernoulli(chain)
659+
end
602660
end
603661

604662
@testset "Demo model" begin

0 commit comments

Comments (0)