Skip to content

Commit 1bc2fbf

Browse files
penelopeysm and mhauru authored
Unify Turing Transitions, fix some tests (#2651)
* Unify `Transition` methods * Add tests * Add same test for SGLD/SGHMC * Refactor so that it's nice and organised * Fix failing test on 1.10 * just increase the atol * Make addlogprob test more robust * Remove stray `@show` Co-authored-by: Markus Hauru <[email protected]> --------- Co-authored-by: Markus Hauru <[email protected]>
1 parent bb21e1e commit 1bc2fbf

File tree

8 files changed

+114
-82
lines changed

8 files changed

+114
-82
lines changed

src/mcmc/Inference.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,16 @@ end
124124
# Default Transition #
125125
######################
126126
getstats(::Any) = NamedTuple()
127+
getstats(nt::NamedTuple) = nt
127128

128-
# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition,
129-
# SMCTransition, and PGTransition to Turing.Inference.Transition instead.
130-
abstract type AbstractTransition end
131-
132-
struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
129+
struct Transition{T,F<:AbstractFloat,N<:NamedTuple}
133130
θ::T
134131
logprior::F
135132
loglikelihood::F
136133
stat::N
137134

138135
"""
139-
Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true)
136+
Transition(model::Model, vi::AbstractVarInfo, stats; reevaluate=true)
140137
141138
Construct a new `Turing.Inference.Transition` object using the outputs of a
142139
sampler step.
@@ -146,8 +143,10 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
146143
have junk contents. The role of this method is to re-evaluate `model` and
147144
thus set the accumulators to the correct values.
148145
149-
`sampler_transition` is the transition object returned by the sampler
150-
itself and is only used to extract statistics of interest.
146+
`stats` is any object on which `Turing.Inference.getstats` can be called to
147+
return a NamedTuple of statistics. This could be, for example, the transition
148+
returned by an (unwrapped) external sampler. Or alternatively, it could
149+
simply be a NamedTuple itself (for which `getstats` acts as the identity).
151150
152151
By default, the model is re-evaluated in order to obtain values of:
153152
- the values of the parameters as per user parameterisation (`vals_as_in_model`)
@@ -167,8 +166,11 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
167166
must be set up to track `x := y` statements.
168167
"""
169168
function Transition(
170-
model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true
169+
model::DynamicPPL.Model, vi::AbstractVarInfo, stats; reevaluate=true
171170
)
171+
# Avoid mutating vi as it may be used later e.g. when constructing
172+
# sampler states.
173+
vi = deepcopy(vi)
172174
if reevaluate
173175
vi = DynamicPPL.setaccs!!(
174176
vi,
@@ -187,7 +189,7 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
187189
loglikelihood = DynamicPPL.getloglikelihood(vi)
188190

189191
# Get additional statistics
190-
stats = getstats(sampler_transition)
192+
stats = getstats(stats)
191193
return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}(
192194
vals_as_in_model, logprior, loglikelihood, stats
193195
)
@@ -196,17 +198,14 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
196198
function Transition(
197199
model::DynamicPPL.Model,
198200
untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
199-
sampler_transition;
201+
stats;
200202
reevaluate=true,
201203
)
202204
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
203205
# much faster to convert it to a typed varinfo first, hence this method.
204206
# https://github.com/TuringLang/Turing.jl/issues/2604
205207
return Transition(
206-
model,
207-
DynamicPPL.typed_varinfo(untyped_vi),
208-
sampler_transition;
209-
reevaluate=reevaluate,
208+
model, DynamicPPL.typed_varinfo(untyped_vi), stats; reevaluate=reevaluate
210209
)
211210
end
212211
end
@@ -318,7 +317,7 @@ getlogevidence(transitions, sampler, state) = missing
318317
# Default MCMCChains.Chains constructor.
319318
# This is type piracy (at least for SampleFromPrior).
320319
function AbstractMCMC.bundle_samples(
321-
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
320+
ts::Vector{<:Union{Transition,AbstractVarInfo}},
322321
model::AbstractModel,
323322
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
324323
state,
@@ -381,7 +380,7 @@ end
381380

382381
# This is type piracy (for SampleFromPrior).
383382
function AbstractMCMC.bundle_samples(
384-
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
383+
ts::Vector{<:Union{Transition,AbstractVarInfo}},
385384
model::AbstractModel,
386385
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler},
387386
state,

src/mcmc/particle_mcmc.jl

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -135,23 +135,6 @@ function SMC(threshold::Real)
135135
return SMC(AdvancedPS.resample_systematic, threshold)
136136
end
137137

138-
struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition
139-
"The parameters for any given sample."
140-
θ::T
141-
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
142-
lp::F
143-
"The weight of the particle the sample was retrieved from."
144-
weight::F
145-
end
146-
147-
function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight)
148-
theta = getparams(model, vi)
149-
lp = DynamicPPL.getlogjoint_internal(vi)
150-
return SMCTransition(theta, lp, weight)
151-
end
152-
153-
getstats_with_lp(t::SMCTransition) = (lp=t.lp, weight=t.weight)
154-
155138
struct SMCState{P,F<:AbstractFloat}
156139
particles::P
157140
particleindex::Int
@@ -228,7 +211,8 @@ function DynamicPPL.initialstep(
228211
weight = AdvancedPS.getweight(particles, 1)
229212

230213
# Compute the first transition and the first state.
231-
transition = SMCTransition(model, particle.model.f.varinfo, weight)
214+
stats = (; weight=weight, logevidence=logevidence)
215+
transition = Transition(model, particle.model.f.varinfo, stats)
232216
state = SMCState(particles, 2, logevidence)
233217

234218
return transition, state
@@ -246,7 +230,8 @@ function AbstractMCMC.step(
246230
weight = AdvancedPS.getweight(particles, index)
247231

248232
# Compute the transition and the next state.
249-
transition = SMCTransition(model, particle.model.f.varinfo, weight)
233+
stats = (; weight=weight, logevidence=state.average_logevidence)
234+
transition = Transition(model, particle.model.f.varinfo, stats)
250235
nextstate = SMCState(state.particles, index + 1, state.average_logevidence)
251236

252237
return transition, nextstate
@@ -300,32 +285,28 @@ Equivalent to [`PG`](@ref).
300285
"""
301286
const CSMC = PG # type alias of PG as Conditional SMC
302287

303-
struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition
304-
"The parameters for any given sample."
305-
θ::T
306-
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
307-
lp::F
308-
"The log evidence of the sample."
309-
logevidence::F
310-
end
311-
312288
struct PGState
313289
vi::AbstractVarInfo
314290
rng::Random.AbstractRNG
315291
end
316292

317293
get_varinfo(state::PGState) = state.vi
318294

319-
function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence)
320-
theta = getparams(model, vi)
321-
lp = DynamicPPL.getlogjoint_internal(vi)
322-
return PGTransition(theta, lp, logevidence)
323-
end
324-
325-
getstats_with_lp(t::PGTransition) = (lp=t.lp, logevidence=t.logevidence)
326-
327-
function getlogevidence(samples, sampler::Sampler{<:PG}, state::PGState)
328-
return mean(x.logevidence for x in samples)
295+
function getlogevidence(
296+
transitions::AbstractVector{<:Turing.Inference.Transition},
297+
sampler::Sampler{<:PG},
298+
state::PGState,
299+
)
300+
logevidences = map(transitions) do t
301+
if haskey(t.stat, :logevidence)
302+
return t.stat.logevidence
303+
else
304+
# This should not really happen, but if it does we can handle it
305+
# gracefully
306+
return missing
307+
end
308+
end
309+
return mean(logevidences)
329310
end
330311

331312
function DynamicPPL.initialstep(
@@ -357,7 +338,7 @@ function DynamicPPL.initialstep(
357338

358339
# Compute the first transition.
359340
_vi = reference.model.f.varinfo
360-
transition = PGTransition(model, _vi, logevidence)
341+
transition = Transition(model, _vi, (; logevidence=logevidence))
361342

362343
return transition, PGState(_vi, reference.rng)
363344
end
@@ -397,7 +378,7 @@ function AbstractMCMC.step(
397378

398379
# Compute the transition.
399380
_vi = newreference.model.f.varinfo
400-
transition = PGTransition(model, _vi, logevidence)
381+
transition = Transition(model, _vi, (; logevidence=logevidence))
401382

402383
return transition, PGState(_vi, newreference.rng)
403384
end

src/mcmc/sghmc.jl

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -184,23 +184,6 @@ function SGLD(;
184184
return SGLD(stepsize, adtype)
185185
end
186186

187-
struct SGLDTransition{T,F<:Real} <: AbstractTransition
188-
"The parameters for any given sample."
189-
θ::T
190-
"The joint log probability of the sample."
191-
lp::F
192-
"The stepsize that was used to obtain the sample."
193-
stepsize::F
194-
end
195-
196-
function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize)
197-
theta = getparams(model, vi)
198-
lp = DynamicPPL.getlogjoint_internal(vi)
199-
return SGLDTransition(theta, lp, stepsize)
200-
end
201-
202-
getstats_with_lp(t::SGLDTransition) = (lp=t.lp, SGLD_stepsize=t.stepsize)
203-
204187
struct SGLDState{L,V<:AbstractVarInfo}
205188
logdensity::L
206189
vi::V
@@ -220,13 +203,13 @@ function DynamicPPL.initialstep(
220203
end
221204

222205
# Create first sample and state.
223-
sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0)))
206+
transition = Transition(model, vi, (; SGLD_stepsize=zero(spl.alg.stepsize(0))))
224207
ℓ = DynamicPPL.LogDensityFunction(
225208
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
226209
)
227210
state = SGLDState(ℓ, vi, 1)
228211

229-
return sample, state
212+
return transition, state
230213
end
231214

232215
function AbstractMCMC.step(
@@ -245,8 +228,8 @@ function AbstractMCMC.step(
245228
vi = DynamicPPL.unflatten(vi, θ)
246229

247230
# Compute next sample and state.
248-
sample = SGLDTransition(model, vi, stepsize)
231+
transition = Transition(model, vi, (; SGLD_stepsize=stepsize))
249232
newstate = SGLDState(ℓ, vi, state.step + 1)
250233

251-
return sample, newstate
234+
return transition, newstate
252235
end

test/mcmc/gibbs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,8 +598,8 @@ end
598598
means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0)
599599
stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0)
600600
for vn in keys(means)
601-
@test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1)
602-
@test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1)
601+
@test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.15)
602+
@test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.15)
603603
end
604604
end
605605

@@ -651,7 +651,7 @@ end
651651
chain = sample(
652652
StableRNG(468),
653653
model,
654-
Gibbs(:b => PG(10), :x => ESS()),
654+
Gibbs(:b => PG(20), :x => ESS()),
655655
2000;
656656
discard_initial=100,
657657
)

test/mcmc/particle_mcmc.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
module ParticleMCMCTests
22

33
using ..Models: gdemo_default
4-
#using ..Models: MoGtest, MoGtest_default
4+
using ..SamplerTestUtils: test_chain_logp_metadata
55
using AdvancedPS: ResampleWithESSThreshold, resample_systematic, resample_multinomial
66
using Distributions: Bernoulli, Beta, Gamma, Normal, sample
77
using Random: Random
8+
using StableRNGs: StableRNG
89
using Test: @test, @test_throws, @testset
910
using Turing
1011

@@ -49,6 +50,10 @@ using Turing
4950
@test_throws ErrorException sample(fail_smc(), SMC(), 100)
5051
end
5152

53+
@testset "chain log-density metadata" begin
54+
test_chain_logp_metadata(SMC())
55+
end
56+
5257
@testset "logevidence" begin
5358
Random.seed!(100)
5459

@@ -65,7 +70,10 @@ using Turing
6570
chains_smc = sample(test(), SMC(), 100)
6671

6772
@test all(isone, chains_smc[:x])
73+
# the chain itself has a logevidence field
6874
@test chains_smc.logevidence ≈ -2 * log(2)
75+
# but each transition also contains the logevidence
76+
@test chains_smc[:logevidence] ≈ fill(chains_smc.logevidence, 100)
6977
end
7078
end
7179

@@ -88,6 +96,10 @@ end
8896
@test s.resampler === resample_systematic
8997
end
9098

99+
@testset "chain log-density metadata" begin
100+
test_chain_logp_metadata(PG(10))
101+
end
102+
91103
@testset "logevidence" begin
92104
Random.seed!(100)
93105

@@ -105,6 +117,7 @@ end
105117

106118
@test all(isone, chains_pg[:x])
107119
@test chains_pg.logevidence ≈ -2 * log(2) atol = 0.01
120+
@test chains_pg[:logevidence] ≈ fill(chains_pg.logevidence, 100)
108121
end
109122

110123
# https://github.com/TuringLang/Turing.jl/issues/1598
@@ -114,6 +127,24 @@ end
114127
@test length(unique(c[:s])) == 1
115128
end
116129

130+
@testset "addlogprob leads to reweighting" begin
131+
# Make sure that PG takes @addlogprob! into account. It didn't use to:
132+
# https://github.com/TuringLang/Turing.jl/issues/1996
133+
@model function addlogprob_demo()
134+
x ~ Normal(0, 1)
135+
if x < 0
136+
@addlogprob! -10.0
137+
else
138+
# Need a balanced number of addlogprobs in all branches, or
139+
# else PG will error
140+
@addlogprob! 0.0
141+
end
142+
end
143+
c = sample(StableRNG(468), addlogprob_demo(), PG(10), 100)
144+
# Result should be biased towards x > 0.
145+
@test mean(c[:x]) > 0.7
146+
end
147+
117148
# https://github.com/TuringLang/Turing.jl/issues/2007
118149
@testset "keyword arguments not supported" begin
119150
@model kwarg_demo(; x=2) = return x

test/mcmc/sghmc.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module SGHMCTests
22

33
using ..Models: gdemo_default
44
using ..NumericalTests: check_gdemo
5+
using ..SamplerTestUtils: test_chain_logp_metadata
56
using DynamicPPL.TestUtils.AD: run_ad
67
using DynamicPPL.TestUtils: DEMO_MODELS
78
using DynamicPPL: DynamicPPL
@@ -32,6 +33,10 @@ using Turing
3233
chain = sample(rng, gdemo_default, alg, 10_000)
3334
check_gdemo(chain; atol=0.1)
3435
end
36+
37+
@testset "chain log-density metadata" begin
38+
test_chain_logp_metadata(SGHMC(; learning_rate=0.02, momentum_decay=0.5))
39+
end
3540
end
3641

3742
@testset "Testing sgld.jl" begin
@@ -46,6 +51,7 @@ end
4651
sampler = DynamicPPL.Sampler(alg)
4752
@test sampler isa DynamicPPL.Sampler{<:SGLD}
4853
end
54+
4955
@testset "sgld inference" begin
5056
rng = StableRNG(1)
5157

@@ -59,6 +65,10 @@ end
5965
@test s_weighted ≈ 49 / 24 atol = 0.2
6066
@test m_weighted ≈ 7 / 6 atol = 0.2
6167
end
68+
69+
@testset "chain log-density metadata" begin
70+
test_chain_logp_metadata(SGLD(; stepsize=PolynomialStepsize(0.25)))
71+
end
6272
end
6373

6474
end

0 commit comments

Comments
 (0)