Skip to content

Commit d92fd56

Browse files
authored
Do not re-evaluate model for Prior (#2644)
* Allow Prior to skip model re-evaluation * remove unneeded `default_chain_type` method * add a test * add a likelihood term too * why not test correctness while we're at it
1 parent b41a4b1 commit d92fd56

File tree

3 files changed

+74
-17
lines changed

3 files changed

+74
-17
lines changed

src/mcmc/Inference.jl

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
136136
stat::N
137137

138138
"""
139-
Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
139+
Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true)
140140
141141
Construct a new `Turing.Inference.Transition` object using the outputs of a
142142
sampler step.
@@ -148,17 +148,38 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
148148
149149
`sampler_transition` is the transition object returned by the sampler
150150
itself and is only used to extract statistics of interest.
151+
152+
By default, the model is re-evaluated in order to obtain values of:
153+
- the values of the parameters as per user parameterisation (`vals_as_in_model`)
154+
- the various components of the log joint probability (`logprior`, `loglikelihood`)
155+
that are guaranteed to be correct.
156+
157+
If you **know** for a fact that the VarInfo `vi` already contains this information,
158+
then you can set `reevaluate=false` to skip the re-evaluation step.
159+
160+
!!! warning
161+
Note that in general this is unsafe and may lead to wrong results.
162+
163+
If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that
164+
the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`,
165+
and `LogLikelihoodAccumulator` set up with the correct values. Note that the
166+
`ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it
167+
must be set up to track `x := y` statements.
151168
"""
152-
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition)
153-
vi = DynamicPPL.setaccs!!(
154-
vi,
155-
(
156-
DynamicPPL.ValuesAsInModelAccumulator(true),
157-
DynamicPPL.LogPriorAccumulator(),
158-
DynamicPPL.LogLikelihoodAccumulator(),
159-
),
160-
)
161-
_, vi = DynamicPPL.evaluate!!(model, vi)
169+
function Transition(
170+
model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true
171+
)
172+
if reevaluate
173+
vi = DynamicPPL.setaccs!!(
174+
vi,
175+
(
176+
DynamicPPL.ValuesAsInModelAccumulator(true),
177+
DynamicPPL.LogPriorAccumulator(),
178+
DynamicPPL.LogLikelihoodAccumulator(),
179+
),
180+
)
181+
_, vi = DynamicPPL.evaluate!!(model, vi)
182+
end
162183

163184
# Extract all the information we need
164185
vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
@@ -175,12 +196,18 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
175196
function Transition(
176197
model::DynamicPPL.Model,
177198
untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
178-
sampler_transition,
199+
sampler_transition;
200+
reevaluate=true,
179201
)
180202
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
181203
# much faster to convert it to a typed varinfo first, hence this method.
182204
# https://github.com/TuringLang/Turing.jl/issues/2604
183-
return Transition(model, DynamicPPL.typed_varinfo(untyped_vi), sampler_transition)
205+
return Transition(
206+
model,
207+
DynamicPPL.typed_varinfo(untyped_vi),
208+
sampler_transition;
209+
reevaluate=reevaluate,
210+
)
184211
end
185212
end
186213

src/mcmc/prior.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@ function AbstractMCMC.step(
1616
sampling_model = DynamicPPL.contextualize(
1717
model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context)
1818
)
19-
_, vi = DynamicPPL.evaluate!!(sampling_model, VarInfo())
20-
return Transition(model, vi, nothing), nothing
19+
vi = VarInfo()
20+
vi = DynamicPPL.setaccs!!(
21+
vi,
22+
(
23+
DynamicPPL.ValuesAsInModelAccumulator(true),
24+
DynamicPPL.LogPriorAccumulator(),
25+
DynamicPPL.LogLikelihoodAccumulator(),
26+
),
27+
)
28+
_, vi = DynamicPPL.evaluate!!(sampling_model, vi)
29+
return Transition(model, vi, nothing; reevaluate=false), nothing
2130
end
22-
23-
DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains

test/mcmc/Inference.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,29 @@ using Turing
142142
@test mean(x[:s][1] for x in chains) 3 atol = 0.11
143143
@test mean(x[:m][1] for x in chains) 0 atol = 0.1
144144
end
145+
146+
@testset "accumulators are set correctly" begin
147+
# Prior() uses `reevaluate=false` when constructing a
148+
# `Turing.Inference.Transition`, so we had better make sure that it
149+
# does capture colon-eq statements, as we can't rely on the default
150+
# `Transition` constructor to do this for us.
151+
@model function coloneq()
152+
x ~ Normal()
153+
10.0 ~ Normal(x)
154+
z := 1.0
155+
return nothing
156+
end
157+
chain = sample(coloneq(), Prior(), N)
158+
@test chain isa MCMCChains.Chains
159+
@test all(x -> x == 1.0, chain[:z])
160+
# And for the same reason we should also make sure that the logp
161+
# components are correctly calculated.
162+
@test isapprox(chain[:logprior], logpdf.(Normal(), chain[:x]))
163+
@test isapprox(chain[:loglikelihood], logpdf.(Normal.(chain[:x]), 10.0))
164+
@test isapprox(chain[:lp], chain[:logprior] .+ chain[:loglikelihood])
165+
# And that the outcome is not influenced by the likelihood
166+
@test mean(chain, :x) 0.0 atol = 0.1
167+
end
145168
end
146169

147170
@testset "chain ordering" begin

0 commit comments

Comments
 (0)