|
565 | 565 | end
|
566 | 566 | end
|
567 | 567 |
|
| 568 | + @testset "PG with variable number of observations" begin |
| 569 | + # When sampling from a model with Particle Gibbs, it is mandatory for |
| 570 | + # the number of observations to be the same in all particles, since the |
| 571 | + # observations trigger particle resampling. |
| 572 | + # |
| 573 | + # Up until Turing v0.39, `x ~ dist` statements where `x` was the |
| 574 | + # responsibility of a different (non-PG) Gibbs subsampler used to not |
| 575 | + # count as an observation. Instead, the log-probability `logpdf(dist, x)` |
| 576 | + # would be manually added to the VarInfo's `logp` field and included in the |
| 577 | + # weighting for the _following_ observation. |
| 578 | + # |
| 579 | + # In Turing v0.40, this is now changed: `x ~ dist` uses tilde_observe!! |
| 580 | + # which thus triggers resampling. Thus, for example, the following model |
| 581 | + # does not work any more: |
| 582 | + # |
| 583 | + # @model function f() |
| 584 | + # a ~ Poisson(2.0) |
| 585 | + # x = Vector{Float64}(undef, a) |
| 586 | + # for i in eachindex(x) |
| 587 | + # x[i] ~ Normal() |
| 588 | + # end |
| 589 | + # end |
| 590 | + # sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000) |
| 591 | + # |
| 592 | + # because the number of observations in each particle depends on the value |
| 593 | + # of `a`. |
| 594 | + # |
| 595 | + # This testset checks that ways of working around such a situation. |
| 596 | + |
| 597 | + function test_dynamic_bernoulli(chain) |
| 598 | + means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0) |
| 599 | + stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0) |
| 600 | + for vn in keys(means) |
| 601 | + @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1) |
| 602 | + @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1) |
| 603 | + end |
| 604 | + end |
| 605 | + |
| 606 | + @testset "Coalescing multiple observations into one" begin |
| 607 | + # Instead of observing x[1] and x[2] separately, we lump them into a |
| 608 | + # single distribution. |
| 609 | + @model function dynamic_bernoulli() |
| 610 | + b ~ Bernoulli() |
| 611 | + if b |
| 612 | + dists = [Normal(1.0)] |
| 613 | + else |
| 614 | + dists = [Normal(1.0), Normal(2.0)] |
| 615 | + end |
| 616 | + return x ~ product_distribution(dists) |
| 617 | + end |
| 618 | + model = dynamic_bernoulli() |
| 619 | + # This currently fails because if the global varinfo has `x` with length 2, |
| 620 | + # and the particle sampler has `b = true`, it attempts to calculate the |
| 621 | + # log-likelihood of a length-2 vector with respect to a length-1 |
| 622 | + # distribution. |
| 623 | + @test_throws DimensionMismatch chain = sample( |
| 624 | + StableRNG(468), |
| 625 | + model, |
| 626 | + Gibbs(:b => PG(10), :x => ESS()), |
| 627 | + 2000; |
| 628 | + discard_initial=100, |
| 629 | + ) |
| 630 | + # test_dynamic_bernoulli(chain) |
| 631 | + end |
| 632 | + |
| 633 | + @testset "Inserting @addlogprob!" begin |
| 634 | + # On top of observing x[i], we also add in extra 'observations' |
| 635 | + @model function dynamic_bernoulli_2() |
| 636 | + b ~ Bernoulli() |
| 637 | + x_length = b ? 1 : 2 |
| 638 | + x = Vector{Float64}(undef, x_length) |
| 639 | + for i in 1:x_length |
| 640 | + x[i] ~ Normal(i, 1.0) |
| 641 | + end |
| 642 | + if length(x) == 1 |
| 643 | + # This value is the expectation value of logpdf(Normal(), x) where x ~ Normal(). |
| 644 | + # See discussion in |
| 645 | + # https://github.com/TuringLang/Turing.jl/pull/2629#discussion_r2237323817 |
| 646 | + @addlogprob!(-1.418849) |
| 647 | + end |
| 648 | + end |
| 649 | + model = dynamic_bernoulli_2() |
| 650 | + chain = sample( |
| 651 | + StableRNG(468), |
| 652 | + model, |
| 653 | + Gibbs(:b => PG(10), :x => ESS()), |
| 654 | + 2000; |
| 655 | + discard_initial=100, |
| 656 | + ) |
| 657 | + test_dynamic_bernoulli(chain) |
| 658 | + end |
| 659 | + end |
| 660 | + |
568 | 661 | @testset "Demo model" begin
|
569 | 662 | @testset verbose = true "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
|
570 | 663 | vns = DynamicPPL.TestUtils.varnames(model)
|
|
0 commit comments