Skip to content

Unify Turing Transitions, fix some tests #2651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,16 @@ end
# Default Transition #
######################
getstats(::Any) = NamedTuple()
getstats(nt::NamedTuple) = nt

# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition,
# SMCTransition, and PGTransition to Turing.Inference.Transition instead.
abstract type AbstractTransition end

struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
struct Transition{T,F<:AbstractFloat,N<:NamedTuple}
θ::T
logprior::F
loglikelihood::F
stat::N

"""
Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true)
Transition(model::Model, vi::AbstractVarInfo, stats; reevaluate=true)

Construct a new `Turing.Inference.Transition` object using the outputs of a
sampler step.
Expand All @@ -146,8 +143,10 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
have junk contents. The role of this method is to re-evaluate `model` and
thus set the accumulators to the correct values.

`sampler_transition` is the transition object returned by the sampler
itself and is only used to extract statistics of interest.
`stats` is any object on which `Turing.Inference.getstats` can be called to
return a NamedTuple of statistics. This could be, for example, the transition
returned by an (unwrapped) external sampler. Or alternatively, it could
simply be a NamedTuple itself (for which `getstats` acts as the identity).

By default, the model is re-evaluated in order to obtain values of:
- the values of the parameters as per user parameterisation (`vals_as_in_model`)
Expand All @@ -167,8 +166,11 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
must be set up to track `x := y` statements.
"""
function Transition(
model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true
model::DynamicPPL.Model, vi::AbstractVarInfo, stats; reevaluate=true
)
# Avoid mutating vi as it may be used later e.g. when constructing
# sampler states.
vi = deepcopy(vi)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this could be a significant time cost. In which case we could make sure we have a proper copy method for varinfos in DPPL. Would probably be good to have that anyway.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does copy tend to be more performant than deepcopy?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it can be. Depends on your data structures. deepcopy can be slow because it plays it safe with aliasing and such, whereas copy does whatever you make it do.

if reevaluate
vi = DynamicPPL.setaccs!!(
vi,
Expand All @@ -187,7 +189,7 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
loglikelihood = DynamicPPL.getloglikelihood(vi)

# Get additional statistics
stats = getstats(sampler_transition)
stats = getstats(stats)
return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}(
vals_as_in_model, logprior, loglikelihood, stats
)
Expand All @@ -196,17 +198,14 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
function Transition(
model::DynamicPPL.Model,
untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
sampler_transition;
stats;
reevaluate=true,
)
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
# much faster to convert it to a typed varinfo first, hence this method.
# https://github.com/TuringLang/Turing.jl/issues/2604
return Transition(
model,
DynamicPPL.typed_varinfo(untyped_vi),
sampler_transition;
reevaluate=reevaluate,
model, DynamicPPL.typed_varinfo(untyped_vi), stats; reevaluate=reevaluate
)
end
end
Expand Down Expand Up @@ -318,7 +317,7 @@ getlogevidence(transitions, sampler, state) = missing
# Default MCMCChains.Chains constructor.
# This is type piracy (at least for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
ts::Vector{<:Union{Transition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
state,
Expand Down Expand Up @@ -381,7 +380,7 @@ end

# This is type piracy (for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
ts::Vector{<:Union{Transition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
state,
Expand Down
61 changes: 21 additions & 40 deletions src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,23 +135,6 @@ function SMC(threshold::Real)
return SMC(AdvancedPS.resample_systematic, threshold)
end

struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
lp::F
"The weight of the particle the sample was retrieved from."
weight::F
end

function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight)
theta = getparams(model, vi)
lp = DynamicPPL.getlogjoint_internal(vi)
return SMCTransition(theta, lp, weight)
end

getstats_with_lp(t::SMCTransition) = (lp=t.lp, weight=t.weight)

struct SMCState{P,F<:AbstractFloat}
particles::P
particleindex::Int
Expand Down Expand Up @@ -228,7 +211,8 @@ function DynamicPPL.initialstep(
weight = AdvancedPS.getweight(particles, 1)

# Compute the first transition and the first state.
transition = SMCTransition(model, particle.model.f.varinfo, weight)
stats = (; weight=weight, logevidence=logevidence)
transition = Transition(model, particle.model.f.varinfo, stats)
state = SMCState(particles, 2, logevidence)

return transition, state
Expand All @@ -246,7 +230,8 @@ function AbstractMCMC.step(
weight = AdvancedPS.getweight(particles, index)

# Compute the transition and the next state.
transition = SMCTransition(model, particle.model.f.varinfo, weight)
stats = (; weight=weight, logevidence=state.average_logevidence)
transition = Transition(model, particle.model.f.varinfo, stats)
nextstate = SMCState(state.particles, index + 1, state.average_logevidence)

return transition, nextstate
Expand Down Expand Up @@ -300,32 +285,28 @@ Equivalent to [`PG`](@ref).
"""
const CSMC = PG # type alias of PG as Conditional SMC

struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
lp::F
"The log evidence of the sample."
logevidence::F
end

struct PGState
vi::AbstractVarInfo
rng::Random.AbstractRNG
end

get_varinfo(state::PGState) = state.vi

function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence)
theta = getparams(model, vi)
lp = DynamicPPL.getlogjoint_internal(vi)
return PGTransition(theta, lp, logevidence)
end

getstats_with_lp(t::PGTransition) = (lp=t.lp, logevidence=t.logevidence)

function getlogevidence(samples, sampler::Sampler{<:PG}, state::PGState)
return mean(x.logevidence for x in samples)
function getlogevidence(
transitions::AbstractVector{<:Turing.Inference.Transition},
sampler::Sampler{<:PG},
state::PGState,
)
logevidences = map(transitions) do t
if haskey(t.stat, :logevidence)
return t.stat.logevidence
else
# This should not really happen, but if it does we can handle it
# gracefully
return missing
end
end
return mean(logevidences)
end

function DynamicPPL.initialstep(
Expand Down Expand Up @@ -357,7 +338,7 @@ function DynamicPPL.initialstep(

# Compute the first transition.
_vi = reference.model.f.varinfo
transition = PGTransition(model, _vi, logevidence)
transition = Transition(model, _vi, (; logevidence=logevidence))

return transition, PGState(_vi, reference.rng)
end
Expand Down Expand Up @@ -397,7 +378,7 @@ function AbstractMCMC.step(

# Compute the transition.
_vi = newreference.model.f.varinfo
transition = PGTransition(model, _vi, logevidence)
transition = Transition(model, _vi, (; logevidence=logevidence))

return transition, PGState(_vi, newreference.rng)
end
Expand Down
25 changes: 4 additions & 21 deletions src/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,23 +184,6 @@ function SGLD(;
return SGLD(stepsize, adtype)
end

struct SGLDTransition{T,F<:Real} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample."
lp::F
"The stepsize that was used to obtain the sample."
stepsize::F
end

function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize)
theta = getparams(model, vi)
lp = DynamicPPL.getlogjoint_internal(vi)
return SGLDTransition(theta, lp, stepsize)
end

getstats_with_lp(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize)

struct SGLDState{L,V<:AbstractVarInfo}
logdensity::L
vi::V
Expand All @@ -220,13 +203,13 @@ function DynamicPPL.initialstep(
end

# Create first sample and state.
sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0)))
transition = Transition(model, vi, (; SGLD_stepsize=zero(spl.alg.stepsize(0))))
ℓ = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
)
state = SGLDState(ℓ, vi, 1)

return sample, state
return transition, state
end

function AbstractMCMC.step(
Expand All @@ -245,8 +228,8 @@ function AbstractMCMC.step(
vi = DynamicPPL.unflatten(vi, θ)

# Compute next sample and state.
sample = SGLDTransition(model, vi, stepsize)
transition = Transition(model, vi, (; SGLD_stepsize=stepsize))
newstate = SGLDState(ℓ, vi, state.step + 1)

return sample, newstate
return transition, newstate
end
6 changes: 3 additions & 3 deletions test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -598,8 +598,8 @@ end
means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0)
stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0)
for vn in keys(means)
@test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1)
@test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1)
@test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.15)
@test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.15)
end
end

Expand Down Expand Up @@ -651,7 +651,7 @@ end
chain = sample(
StableRNG(468),
model,
Gibbs(:b => PG(10), :x => ESS()),
Gibbs(:b => PG(20), :x => ESS()),
Copy link
Member Author

@penelopeysm penelopeysm Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was failing on 1.10 due to numerical inaccuracy. Kind of unsure why it was only failing on 1.10 but not 1.11 given that we were using StableRNGs. My first guess would be the rng splitting in AdvancedPS.

I just bumped the atol up anyway because this test is so wonky (really we're mostly checking that it samples at all, since the results are incorrect depending on interpretation of model).

2000;
discard_initial=100,
)
Expand Down
33 changes: 32 additions & 1 deletion test/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module ParticleMCMCTests

using ..Models: gdemo_default
#using ..Models: MoGtest, MoGtest_default
using ..SamplerTestUtils: test_chain_logp_metadata
using AdvancedPS: ResampleWithESSThreshold, resample_systematic, resample_multinomial
using Distributions: Bernoulli, Beta, Gamma, Normal, sample
using Random: Random
using StableRNGs: StableRNG
using Test: @test, @test_throws, @testset
using Turing

Expand Down Expand Up @@ -49,6 +50,10 @@ using Turing
@test_throws ErrorException sample(fail_smc(), SMC(), 100)
end

@testset "chain log-density metadata" begin
test_chain_logp_metadata(SMC())
end

@testset "logevidence" begin
Random.seed!(100)

Expand All @@ -65,7 +70,10 @@ using Turing
chains_smc = sample(test(), SMC(), 100)

@test all(isone, chains_smc[:x])
# the chain itself has a logevidence field
@test chains_smc.logevidence ≈ -2 * log(2)
# but each transition also contains the logevidence
@test chains_smc[:logevidence] ≈ fill(chains_smc.logevidence, 100)
end
end

Expand All @@ -88,6 +96,10 @@ end
@test s.resampler === resample_systematic
end

@testset "chain log-density metadata" begin
test_chain_logp_metadata(PG(10))
end

@testset "logevidence" begin
Random.seed!(100)

Expand All @@ -105,6 +117,7 @@ end

@test all(isone, chains_pg[:x])
@test chains_pg.logevidence ≈ -2 * log(2) atol = 0.01
@test chains_pg[:logevidence] ≈ fill(chains_pg.logevidence, 100)
end

# https://github.com/TuringLang/Turing.jl/issues/1598
Expand All @@ -114,6 +127,24 @@ end
@test length(unique(c[:s])) == 1
end

@testset "addlogprob leads to reweighting" begin
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was sort of tested in PG-within-Gibbs, but we didn't have a PG-only test

# Make sure that PG takes @addlogprob! into account. It didn't use to:
# https://github.com/TuringLang/Turing.jl/issues/1996
@model function addlogprob_demo()
x ~ Normal(0, 1)
if x < 0
@addlogprob! -2.0
else
# Need a balanced number of addlogprobs in all branches, or
# else PG will error
@addlogprob! 0.0
end
end
c = sample(addlogprob_demo(), PG(10), 100)
# Result should be biased towards x > 0.
@test mean(c[:x]) > 0.5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@test mean(c[:x]) > 0.5
@test mean(c[:x]) > 0.7

Could this get a bit more clearance from 0.5, with, if necessary, the @addloprob! value being increased? Otherwise there's 50% chance this will pass just by luck.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, x is ~ Normal() so the chance of it being larger than 0.5 is probably lower than 50%. That does assume that the chain has fully mixed though, which might not be valid with 100 iterations, and there's no harm to making the addlogprob value larger, so I'll do that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I don't know why I didn't put in StableRNGs here. Will do so too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I wasn't thinking. I mixed up the cases of x ~ Normal() and x ~ Bernoulli().

end

# https://github.com/TuringLang/Turing.jl/issues/2007
@testset "keyword arguments not supported" begin
@model kwarg_demo(; x=2) = return x
Expand Down
10 changes: 10 additions & 0 deletions test/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module SGHMCTests

using ..Models: gdemo_default
using ..NumericalTests: check_gdemo
using ..SamplerTestUtils: test_chain_logp_metadata
using DynamicPPL.TestUtils.AD: run_ad
using DynamicPPL.TestUtils: DEMO_MODELS
using DynamicPPL: DynamicPPL
Expand Down Expand Up @@ -32,6 +33,10 @@ using Turing
chain = sample(rng, gdemo_default, alg, 10_000)
check_gdemo(chain; atol=0.1)
end

@testset "chain log-density metadata" begin
test_chain_logp_metadata(SGHMC(; learning_rate=0.02, momentum_decay=0.5))
end
end

@testset "Testing sgld.jl" begin
Expand All @@ -46,6 +51,7 @@ end
sampler = DynamicPPL.Sampler(alg)
@test sampler isa DynamicPPL.Sampler{<:SGLD}
end

@testset "sgld inference" begin
rng = StableRNG(1)

Expand All @@ -59,6 +65,10 @@ end
@test s_weighted ≈ 49 / 24 atol = 0.2
@test m_weighted ≈ 7 / 6 atol = 0.2
end

@testset "chain log-density metadata" begin
test_chain_logp_metadata(SGLD(; stepsize=PolynomialStepsize(0.25)))
end
end

end
Loading
Loading