Skip to content

Commit eb2732a

Browse files
committed
Add back (some kind of) dynamic model test
1 parent eaa1724 commit eb2732a

File tree

1 file changed

+93
-0
lines changed

1 file changed

+93
-0
lines changed

test/mcmc/gibbs.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,99 @@ end
565565
end
566566
end
567567

568+
@testset "PG with variable number of observations" begin
569+
# When sampling from a model with Particle Gibbs, it is mandatory for
570+
# the number of observations to be the same in all particles, since the
571+
# observations trigger particle resampling.
572+
#
573+
# Up until Turing v0.39, `x ~ dist` statements where `x` was the
574+
# responsibility of a different (non-PG) Gibbs subsampler used to not
575+
# count as an observation. Instead, the log-probability `logpdf(dist, x)`
576+
# would be manually added to the VarInfo's `logp` field and included in the
577+
# weighting for the _following_ observation.
578+
#
579+
# In Turing v0.40, this is now changed: `x ~ dist` uses tilde_observe!!
580+
# which thus triggers resampling. Thus, for example, the following model
581+
# does not work any more:
582+
#
583+
# @model function f()
584+
# a ~ Poisson(2.0)
585+
# x = Vector{Float64}(undef, a)
586+
# for i in eachindex(x)
587+
# x[i] ~ Normal()
588+
# end
589+
# end
590+
# sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000)
591+
#
592+
# because the number of observations in each particle depends on the value
593+
# of `a`.
594+
#
595+
# This testset checks ways of working around this limitation.
596+
597+
# Check posterior summary statistics of a chain sampled from one of the
# dynamic Bernoulli models below. `skipmissing` is needed because `x[2]`
# only exists in iterations where the model took the length-2 branch.
function test_dynamic_bernoulli(chain)
    # varname => (expected mean, expected std)
    expected = Dict(
        :b => (0.5, 0.5),
        "x[1]" => (1.0, 1.0),
        "x[2]" => (2.0, 1.0),
    )
    for (vn, (m, s)) in pairs(expected)
        draws = skipmissing(chain[:, vn, 1])
        @test isapprox(mean(draws), m; atol=0.1)
        @test isapprox(std(draws), s; atol=0.1)
    end
end
605+
606+
@testset "Coalescing multiple observations into one" begin
    # Instead of observing x[1] and x[2] separately, lump them into a
    # single product distribution, so there is exactly one observe
    # statement no matter which branch `b` selects.
    @model function dynamic_bernoulli()
        b ~ Bernoulli()
        dists = b ? [Normal(1.0)] : [Normal(1.0), Normal(2.0)]
        return x ~ product_distribution(dists)
    end
    model = dynamic_bernoulli()
    # This currently fails: when the global varinfo holds an `x` of
    # length 2 but the particle being evaluated has `b = true`, the model
    # computes the log-likelihood of a length-2 vector against a
    # length-1 distribution.
    @test_throws DimensionMismatch chain = sample(
        StableRNG(468),
        model,
        Gibbs(:b => PG(10), :x => ESS()),
        2000;
        discard_initial=100,
    )
    # Re-enable once the DimensionMismatch above is fixed:
    # test_dynamic_bernoulli(chain)
end
632+
633+
@testset "Inserting @addlogprob!" begin
    # On top of observing each x[i], manually add an extra 'observation'
    # in the short branch so that every particle triggers the same
    # number of resampling steps.
    @model function dynamic_bernoulli_2()
        b ~ Bernoulli()
        x = Vector{Float64}(undef, b ? 1 : 2)
        for i in eachindex(x)
            x[i] ~ Normal(i, 1.0)
        end
        if length(x) == 1
            # Expectation value of logpdf(Normal(), x) where x ~ Normal(),
            # i.e. -(1 + log(2π))/2. See discussion in
            # https://github.com/TuringLang/Turing.jl/pull/2629#discussion_r2237323817
            # NOTE(review): -(1 + log(2π))/2 ≈ -1.41894, which differs from
            # the constant below in the fifth decimal — confirm against the
            # linked discussion before changing it.
            @addlogprob!(-1.418849)
        end
    end
    model = dynamic_bernoulli_2()
    chain = sample(
        StableRNG(468),
        model,
        Gibbs(:b => PG(10), :x => ESS()),
        2000;
        discard_initial=100,
    )
    test_dynamic_bernoulli(chain)
end
659+
end
660+
568661
@testset "Demo model" begin
569662
@testset verbose = true "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
570663
vns = DynamicPPL.TestUtils.varnames(model)

0 commit comments

Comments
 (0)