diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl
index 07c8311b4..d6e9afcbb 100644
--- a/src/mcmc/Inference.jl
+++ b/src/mcmc/Inference.jl
@@ -124,19 +124,16 @@ end
 # Default Transition #
 ######################
 
 getstats(::Any) = NamedTuple()
+getstats(nt::NamedTuple) = nt
 
-# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition,
-# SMCTransition, and PGTransition to Turing.Inference.Transition instead.
-abstract type AbstractTransition end
-
-struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
+struct Transition{T,F<:AbstractFloat,N<:NamedTuple}
     θ::T
     logprior::F
     loglikelihood::F
     stat::N
 
     """
-        Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true)
+        Transition(model::Model, vi::AbstractVarInfo, stats; reevaluate=true)
 
     Construct a new `Turing.Inference.Transition` object using the outputs of
     a sampler step.
@@ -146,8 +143,10 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
     have junk contents. The role of this method is to re-evaluate `model`
     and thus set the accumulators to the correct values.
 
-    `sampler_transition` is the transition object returned by the sampler
-    itself and is only used to extract statistics of interest.
+    `stats` is any object on which `Turing.Inference.getstats` can be called to
+    return a NamedTuple of statistics. This could be, for example, the
+    transition returned by an (unwrapped) external sampler; alternatively, it
+    could simply be a NamedTuple itself (for which `getstats` acts as the identity).
 
     By default, the model is re-evaluated in order to obtain values of:
     - the values of the parameters as per user parameterisation (`vals_as_in_model`)
@@ -167,8 +166,11 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
     must be set up to track `x := y` statements.
     """
     function Transition(
-        model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true
+        model::DynamicPPL.Model, vi::AbstractVarInfo, stats; reevaluate=true
     )
+        # Avoid mutating vi as it may be used later e.g. when constructing
+        # sampler states.
+        vi = deepcopy(vi)
         if reevaluate
             vi = DynamicPPL.setaccs!!(
                 vi,
@@ -187,7 +189,7 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
         loglikelihood = DynamicPPL.getloglikelihood(vi)
 
         # Get additional statistics
-        stats = getstats(sampler_transition)
+        stats = getstats(stats)
         return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}(
             vals_as_in_model, logprior, loglikelihood, stats
         )
@@ -196,17 +198,14 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
     function Transition(
         model::DynamicPPL.Model,
         untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
-        sampler_transition;
+        stats;
         reevaluate=true,
     )
         # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
         # much faster to convert it to a typed varinfo first, hence this method.
         # https://github.com/TuringLang/Turing.jl/issues/2604
         return Transition(
-            model,
-            DynamicPPL.typed_varinfo(untyped_vi),
-            sampler_transition;
-            reevaluate=reevaluate,
+            model, DynamicPPL.typed_varinfo(untyped_vi), stats; reevaluate=reevaluate
         )
     end
 end
@@ -318,7 +317,7 @@ getlogevidence(transitions, sampler, state) = missing
 # Default MCMCChains.Chains constructor.
 # This is type piracy (at least for SampleFromPrior).
 function AbstractMCMC.bundle_samples(
-    ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
+    ts::Vector{<:Union{Transition,AbstractVarInfo}},
     model::AbstractModel,
     spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
     state,
@@ -381,7 +380,7 @@ end
 
 # This is type piracy (for SampleFromPrior).
 function AbstractMCMC.bundle_samples(
-    ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
+    ts::Vector{<:Union{Transition,AbstractVarInfo}},
     model::AbstractModel,
     spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
     state,
diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl
index b500d3a46..ab2add975 100644
--- a/src/mcmc/particle_mcmc.jl
+++ b/src/mcmc/particle_mcmc.jl
@@ -135,23 +135,6 @@ function SMC(threshold::Real)
     return SMC(AdvancedPS.resample_systematic, threshold)
 end
 
-struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition
-    "The parameters for any given sample."
-    θ::T
-    "The joint log probability of the sample (NOTE: does not work, always set to zero)."
-    lp::F
-    "The weight of the particle the sample was retrieved from."
-    weight::F
-end
-
-function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight)
-    theta = getparams(model, vi)
-    lp = DynamicPPL.getlogjoint_internal(vi)
-    return SMCTransition(theta, lp, weight)
-end
-
-getstats_with_lp(t::SMCTransition) = (lp=t.lp, weight=t.weight)
-
 struct SMCState{P,F<:AbstractFloat}
     particles::P
     particleindex::Int
@@ -228,7 +211,8 @@ function DynamicPPL.initialstep(
     weight = AdvancedPS.getweight(particles, 1)
 
     # Compute the first transition and the first state.
-    transition = SMCTransition(model, particle.model.f.varinfo, weight)
+    stats = (; weight=weight, logevidence=logevidence)
+    transition = Transition(model, particle.model.f.varinfo, stats)
     state = SMCState(particles, 2, logevidence)
 
     return transition, state
@@ -246,7 +230,8 @@ function AbstractMCMC.step(
     weight = AdvancedPS.getweight(particles, index)
 
     # Compute the transition and the next state.
-    transition = SMCTransition(model, particle.model.f.varinfo, weight)
+    stats = (; weight=weight, logevidence=state.average_logevidence)
+    transition = Transition(model, particle.model.f.varinfo, stats)
     nextstate = SMCState(state.particles, index + 1, state.average_logevidence)
 
     return transition, nextstate
@@ -300,15 +285,6 @@ Equivalent to [`PG`](@ref).
 """
 const CSMC = PG # type alias of PG as Conditional SMC
 
-struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition
-    "The parameters for any given sample."
-    θ::T
-    "The joint log probability of the sample (NOTE: does not work, always set to zero)."
-    lp::F
-    "The log evidence of the sample."
-    logevidence::F
-end
-
 struct PGState
     vi::AbstractVarInfo
     rng::Random.AbstractRNG
@@ -316,16 +292,21 @@ end
 
 get_varinfo(state::PGState) = state.vi
 
-function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence)
-    theta = getparams(model, vi)
-    lp = DynamicPPL.getlogjoint_internal(vi)
-    return PGTransition(theta, lp, logevidence)
-end
-
-getstats_with_lp(t::PGTransition) = (lp=t.lp, logevidence=t.logevidence)
-
-function getlogevidence(samples, sampler::Sampler{<:PG}, state::PGState)
-    return mean(x.logevidence for x in samples)
+function getlogevidence(
+    transitions::AbstractVector{<:Turing.Inference.Transition},
+    sampler::Sampler{<:PG},
+    state::PGState,
+)
+    logevidences = map(transitions) do t
+        if haskey(t.stat, :logevidence)
+            return t.stat.logevidence
+        else
+            # This should not really happen, but if it does we can handle it
+            # gracefully
+            return missing
+        end
+    end
+    return mean(logevidences)
 end
 
 function DynamicPPL.initialstep(
@@ -357,7 +338,7 @@ function DynamicPPL.initialstep(
 
     # Compute the first transition.
     _vi = reference.model.f.varinfo
-    transition = PGTransition(model, _vi, logevidence)
+    transition = Transition(model, _vi, (; logevidence=logevidence))
 
     return transition, PGState(_vi, reference.rng)
 end
@@ -397,7 +378,7 @@ function AbstractMCMC.step(
 
     # Compute the transition.
     _vi = newreference.model.f.varinfo
-    transition = PGTransition(model, _vi, logevidence)
+    transition = Transition(model, _vi, (; logevidence=logevidence))
 
     return transition, PGState(_vi, newreference.rng)
 end
diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl
index 5ca351643..34d7cf9d8 100644
--- a/src/mcmc/sghmc.jl
+++ b/src/mcmc/sghmc.jl
@@ -184,23 +184,6 @@ function SGLD(;
     return SGLD(stepsize, adtype)
 end
 
-struct SGLDTransition{T,F<:Real} <: AbstractTransition
-    "The parameters for any given sample."
-    θ::T
-    "The joint log probability of the sample."
-    lp::F
-    "The stepsize that was used to obtain the sample."
-    stepsize::F
-end
-
-function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize)
-    theta = getparams(model, vi)
-    lp = DynamicPPL.getlogjoint_internal(vi)
-    return SGLDTransition(theta, lp, stepsize)
-end
-
-getstats_with_lp(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize)
-
 struct SGLDState{L,V<:AbstractVarInfo}
     logdensity::L
     vi::V
@@ -220,13 +203,13 @@ function DynamicPPL.initialstep(
     end
 
     # Create first sample and state.
-    sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0)))
+    transition = Transition(model, vi, (; SGLD_stepsize=zero(spl.alg.stepsize(0))))
     ℓ = DynamicPPL.LogDensityFunction(
         model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
     )
     state = SGLDState(ℓ, vi, 1)
 
-    return sample, state
+    return transition, state
 end
 
 function AbstractMCMC.step(
@@ -245,8 +228,8 @@ function AbstractMCMC.step(
     vi = DynamicPPL.unflatten(vi, θ)
 
     # Compute next sample and state.
-    sample = SGLDTransition(model, vi, stepsize)
+    transition = Transition(model, vi, (; SGLD_stepsize=stepsize))
     newstate = SGLDState(ℓ, vi, state.step + 1)
 
-    return sample, newstate
+    return transition, newstate
 end
diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl
index 0fd76be3a..dc8cd42d0 100644
--- a/test/mcmc/gibbs.jl
+++ b/test/mcmc/gibbs.jl
@@ -598,8 +598,8 @@ end
         means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0)
         stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0)
         for vn in keys(means)
-            @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1)
-            @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1)
+            @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.15)
+            @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.15)
         end
     end
 
@@ -651,7 +651,7 @@ end
        chain = sample(
             StableRNG(468),
             model,
-            Gibbs(:b => PG(10), :x => ESS()),
+            Gibbs(:b => PG(20), :x => ESS()),
             2000;
             discard_initial=100,
         )
diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl
index 7a2f5fe1c..ad7373b85 100644
--- a/test/mcmc/particle_mcmc.jl
+++ b/test/mcmc/particle_mcmc.jl
@@ -1,10 +1,11 @@
 module ParticleMCMCTests
 
 using ..Models: gdemo_default
-#using ..Models: MoGtest, MoGtest_default
+using ..SamplerTestUtils: test_chain_logp_metadata
 using AdvancedPS: ResampleWithESSThreshold, resample_systematic, resample_multinomial
 using Distributions: Bernoulli, Beta, Gamma, Normal, sample
 using Random: Random
+using StableRNGs: StableRNG
 using Test: @test, @test_throws, @testset
 using Turing
 
@@ -49,6 +50,10 @@ using Turing
         @test_throws ErrorException sample(fail_smc(), SMC(), 100)
     end
 
+    @testset "chain log-density metadata" begin
+        test_chain_logp_metadata(SMC())
+    end
+
     @testset "logevidence" begin
         Random.seed!(100)
 
@@ -65,7 +70,10 @@ using Turing
         chains_smc = sample(test(), SMC(), 100)
 
         @test all(isone, chains_smc[:x])
+        # the chain itself has a logevidence field
        @test chains_smc.logevidence ≈ -2 * log(2)
+        # but each transition also contains the logevidence
+        @test chains_smc[:logevidence] ≈ fill(chains_smc.logevidence, 100)
     end
 end
 
@@ -88,6 +96,10 @@ end
         @test s.resampler === resample_systematic
     end
 
+    @testset "chain log-density metadata" begin
+        test_chain_logp_metadata(PG(10))
+    end
+
     @testset "logevidence" begin
         Random.seed!(100)
 
@@ -105,6 +117,7 @@ end
 
         @test all(isone, chains_pg[:x])
         @test chains_pg.logevidence ≈ -2 * log(2) atol = 0.01
+        @test chains_pg[:logevidence] ≈ fill(chains_pg.logevidence, 100)
     end
 
     # https://github.com/TuringLang/Turing.jl/issues/1598
@@ -114,6 +127,24 @@ end
         @test length(unique(c[:s])) == 1
     end
 
+    @testset "addlogprob leads to reweighting" begin
+        # Make sure that PG takes @addlogprob! into account. It didn't use to:
+        # https://github.com/TuringLang/Turing.jl/issues/1996
+        @model function addlogprob_demo()
+            x ~ Normal(0, 1)
+            if x < 0
+                @addlogprob! -10.0
+            else
+                # Need a balanced number of addlogprobs in all branches, or
+                # else PG will error
+                @addlogprob! 0.0
+            end
+        end
+        c = sample(StableRNG(468), addlogprob_demo(), PG(10), 100)
+        # Result should be biased towards x > 0.
+        @test mean(c[:x]) > 0.7
+    end
+
     # https://github.com/TuringLang/Turing.jl/issues/2007
     @testset "keyword arguments not supported" begin
         @model kwarg_demo(; x=2) = return x
diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl
index 1671362ed..ee943270c 100644
--- a/test/mcmc/sghmc.jl
+++ b/test/mcmc/sghmc.jl
@@ -2,6 +2,7 @@ module SGHMCTests
 
 using ..Models: gdemo_default
 using ..NumericalTests: check_gdemo
+using ..SamplerTestUtils: test_chain_logp_metadata
 using DynamicPPL.TestUtils.AD: run_ad
 using DynamicPPL.TestUtils: DEMO_MODELS
 using DynamicPPL: DynamicPPL
@@ -32,6 +33,10 @@ using Turing
         chain = sample(rng, gdemo_default, alg, 10_000)
         check_gdemo(chain; atol=0.1)
     end
+
+    @testset "chain log-density metadata" begin
+        test_chain_logp_metadata(SGHMC(; learning_rate=0.02, momentum_decay=0.5))
+    end
 end
 
 @testset "Testing sgld.jl" begin
@@ -46,6 +51,7 @@ end
         sampler = DynamicPPL.Sampler(alg)
         @test sampler isa DynamicPPL.Sampler{<:SGLD}
     end
+
     @testset "sgld inference" begin
         rng = StableRNG(1)
 
@@ -59,6 +65,10 @@ end
         @test s_weighted ≈ 49 / 24 atol = 0.2
         @test m_weighted ≈ 7 / 6 atol = 0.2
     end
+
+    @testset "chain log-density metadata" begin
+        test_chain_logp_metadata(SGLD(; stepsize=PolynomialStepsize(0.25)))
+    end
 end
 
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 9fec2f737..5fb6b2141 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -11,6 +11,7 @@ seed!(23)
 include("test_utils/models.jl")
 include("test_utils/numerical_tests.jl")
+include("test_utils/sampler.jl")
 
 Turing.setprogress!(false)
 
 included_paths, excluded_paths = parse_args(ARGS)
diff --git a/test/test_utils/sampler.jl b/test/test_utils/sampler.jl
new file mode 100644
index 000000000..32a3647f9
--- /dev/null
+++ b/test/test_utils/sampler.jl
@@ -0,0 +1,27 @@
+module SamplerTestUtils
+
+using Turing
+using Test
+
+"""
+Check that when sampling with `spl`, the resulting chain contains correct
+log-density metadata.
+"""
+function test_chain_logp_metadata(spl)
+    @model function f()
+        # some prior term (but importantly, one that is constrained, i.e., can
+        # be linked with a non-identity transform)
+        x ~ LogNormal()
+        # some likelihood term
+        return 1.0 ~ Normal(x)
+    end
+    chn = sample(f(), spl, 100)
+    # Check that the log-prior term is calculated in unlinked space.
+    @test chn[:logprior] ≈ logpdf.(LogNormal(), chn[:x])
+    @test chn[:loglikelihood] ≈ logpdf.(Normal.(chn[:x]), 1.0)
+    # This should always be true, but it also indirectly checks that the
+    # log-joint is also calculated in unlinked space.
+    @test chn[:lp] ≈ chn[:logprior] + chn[:loglikelihood]
+end
+
+end
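
To illustrate the unified stats API introduced in Inference.jl above: a sampler can either define `getstats` for its own transition type, or pass a NamedTuple of statistics directly (as the SMC, PG, and SGLD steps now do). A minimal sketch, with `MySamplerTransition` being a hypothetical type used only for illustration:

    using Turing

    # Hypothetical transition type for a custom sampler.
    struct MySamplerTransition
        acceptance_rate::Float64
    end

    # Defining `getstats` lets `Turing.Inference.Transition` extract this
    # sampler's statistics from its own transition object.
    Turing.Inference.getstats(t::MySamplerTransition) = (; acceptance_rate=t.acceptance_rate)

    # Equivalently, a NamedTuple can be passed straight through, since
    # `getstats` is now the identity on NamedTuples:
    #     Turing.Inference.Transition(model, vi, (; acceptance_rate=0.8))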
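
As the updated tests show, SMC and PG chains now expose the log evidence both as chain-level metadata and as a per-sample column. A sketch of what this looks like in practice, using `gdemo_default` from the test models:

    using Turing

    chn = sample(gdemo_default, SMC(), 100)
    chn.logevidence    # chain-level estimate, as before
    chn[:logevidence]  # per-sample values carried in each Transition's stats

Relatedly, in the new `getlogevidence` method for PG, a transition whose stats lack a `:logevidence` entry contributes `missing`; since `mean` propagates `missing`, a partially populated chain yields a `missing` log evidence rather than throwing an error.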
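
Future sampler test suites can hook into the shared helper the same way the SMC, PG, SGHMC, and SGLD suites above do; for a hypothetical `MySampler`:

    using ..SamplerTestUtils: test_chain_logp_metadata

    @testset "chain log-density metadata" begin
        test_chain_logp_metadata(MySampler())
    end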