Commit f372768

Enforce re-evaluation when constructing Transition
1 parent 57cb82e

File tree: 12 files changed (+151, -178 lines)

ext/TuringDynamicHMCExt.jl

Lines changed: 2 additions & 12 deletions

```diff
@@ -73,14 +73,8 @@ function DynamicPPL.initialstep(
     steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
     Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
 
-    # Update the variables.
-    vi = DynamicPPL.unflatten(vi, Q.q)
-    # TODO(DPPL0.37/penelopeysm): This is obviously incorrect. Fix this.
-    vi = DynamicPPL.setloglikelihood!!(vi, Q.ℓq)
-    vi = DynamicPPL.setlogprior!!(vi, 0.0)
-
     # Create first sample and state.
-    sample = Turing.Inference.Transition(model, vi)
+    sample = Turing.Inference.Transition(model, vi, Q.q, nothing)
     state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
 
     return sample, state
@@ -99,12 +93,8 @@ function AbstractMCMC.step(
     steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
     Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
 
-    # Update the variables.
-    vi = DynamicPPL.unflatten(vi, Q.q)
-    vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
-
     # Create next sample and state.
-    sample = Turing.Inference.Transition(model, vi)
+    sample = Turing.Inference.Transition(model, vi, Q.q, nothing)
     newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
 
     return sample, newstate
```
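The new calling convention above hands DynamicHMC's raw position vector `Q.q` straight to the four-argument `Transition` constructor (introduced in src/mcmc/Inference.jl below), which re-evaluates the model itself instead of trusting the sampler's log-density bookkeeping; `nothing` indicates there is no sampler transition to extract statistics from. A minimal sketch of the constructor in isolation, using a hypothetical toy model in place of the DynamicHMC machinery (assumes the DynamicPPL API targeted by this commit):

```julia
using Turing
import DynamicPPL

# Hypothetical toy model for illustration; not part of the commit.
@model function demo(y)
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo(0.5)
vi = DynamicPPL.VarInfo(model)  # may hold stale values; only its structure is used
params = [0.3]                  # flattened parameters, playing the role of Q.q

# The constructor re-evaluates `model` at `params`, so logprior/loglikelihood
# are always consistent with the parameters. Passing `nothing` means getstats
# contributes an empty NamedTuple.
t = Turing.Inference.Transition(model, vi, params, nothing)
t.logprior + t.loglikelihood    # log joint density at x = 0.3
```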

src/mcmc/Inference.jl

Lines changed: 101 additions & 62 deletions

```diff
@@ -124,85 +124,124 @@ end
 ######################
 # Default Transition #
 ######################
-# Default
-getstats(t) = nothing
+getstats(::Any) = NamedTuple()
 
+# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition,
+# SMCTransition, and PGTransition to Turing.Inference.Transition instead.
 abstract type AbstractTransition end
 
-struct Transition{T,F<:AbstractFloat,S<:Union{NamedTuple,Nothing}} <: AbstractTransition
+struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
     θ::T
-    lp::F # TODO: merge `lp` with `stat`
-    stat::S
-end
-
-Transition(θ, lp) = Transition(θ, lp, nothing)
-function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
-    # TODO(DPPL0.37/penelopeysm): Fix this
-    θ = getparams(model, vi)
-    lp = getlogjoint_internal(vi)
-    return Transition(θ, lp, getstats(t))
+    logprior::F
+    loglikelihood::F
+    stat::N
+
+    """
+        Transition(model::Model, vi::AbstractVarInfo, params::AbstractVector, sampler_transition)
+
+    Construct a new `Turing.Inference.Transition` object using the outputs of a sampler step.
+
+    Here, `vi` represents a VarInfo which in general may have junk contents (both
+    parameters and accumulators). The role of this method is to re-evaluate `model` by inserting
+    the new `params` (provided by the sampler) into the VarInfo `vi`.
+
+    `sampler_transition` is the transition object returned by the sampler itself and is only used
+    to extract statistics of interest.
+
+    !!! warning "Parameters must match varinfo linking status"
+        It is mandatory that the vector of parameters provided line up exactly with how the
+        VarInfo `vi` is linked. Otherwise, this can silently produce incorrect results.
+    """
+    function Transition(
+        model::DynamicPPL.Model,
+        vi::AbstractVarInfo,
+        parameters::AbstractVector,
+        sampler_transition,
+    )
+        # To be safe...
+        vi = deepcopy(vi)
+        # Set the parameters and re-evaluate with the appropriate accumulators
+        vi = DynamicPPL.unflatten(vi, parameters)
+        vi = DynamicPPL.setaccs!!(
+            vi,
+            (
+                DynamicPPL.ValuesAsInModelAccumulator(true),
+                DynamicPPL.LogPriorAccumulator(),
+                DynamicPPL.LogLikelihoodAccumulator(),
+            ),
+        )
+        _, vi = DynamicPPL.evaluate!!(model, vi)
+
+        # Extract all the information we need
+        vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
+        logprior = DynamicPPL.getlogprior(vi)
+        loglikelihood = DynamicPPL.getloglikelihood(vi)
+
+        # Convert values to the format needed (with individual VarNames split up).
+        # TODO(penelopeysm): This wouldn't be necessary if not for MCMCChains's poor
+        # representation...
+        iters = map(
+            DynamicPPL.varname_and_value_leaves,
+            keys(vals_as_in_model),
+            values(vals_as_in_model),
+        )
+        values_split = mapreduce(collect, vcat, iters)
+
+        # Get additional statistics
+        stats = getstats(sampler_transition)
+        return new{typeof(values_split),typeof(logprior),typeof(stats)}(
+            values_split, logprior, loglikelihood, stats
+        )
+    end
+    function Transition(
+        model::DynamicPPL.Model,
+        untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
+        parameters::AbstractVector,
+        sampler_transition,
+    )
+        # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
+        # much faster to convert it to a typed varinfo first, hence this method.
+        # https://github.com/TuringLang/Turing.jl/issues/2604
+        return Transition(
+            model, DynamicPPL.typed_varinfo(untyped_vi), parameters, sampler_transition
+        )
+    end
 end
 
-# TODO(DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
 function metadata(t::Transition)
-    stat = t.stat
-    if stat === nothing
-        return (lp=t.lp,)
-    else
-        return merge((lp=t.lp,), stat)
-    end
+    return merge(
+        t.stat,
+        (
+            lp=t.logprior + t.loglikelihood,
+            logprior=t.logprior,
+            loglikelihood=t.loglikelihood,
+        ),
+    )
+end
+function metadata(vi::AbstractVarInfo)
+    return (
+        lp=DynamicPPL.getlogjoint(vi),
+        logprior=DynamicPPL.getlogprior(vi),
+        loglikelihood=DynamicPPL.getloglikelihood(vi),
+    )
 end
-
-# TODO(DPPL0.37/penelopeysm): Fix this
-DynamicPPL.getlogjoint(t::Transition) = t.lp
-
-# Metadata of VarInfo object
-# TODO(DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
-metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),)
 
 ##########################
 # Chain making utilities #
 ##########################
 
-"""
-    getparams(model, t)
-
-Return a named tuple of parameters.
-"""
-getparams(model, t) = t.θ
-function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
-    # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
-    # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
-    # of the parameters change depending on the realizations. Hence we have to use
-    # `values_as_in_model`, which re-runs the model and extracts the parameters
-    # as they are seen in the model, i.e. in the constrained space. Moreover,
-    # this means that the code below will work both of linked and invlinked `vi`.
-    # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
-    # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
-    vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
-
-    # Obtain an iterator over the flattened parameter names and values.
-    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
-
-    # Materialize the iterators and concatenate.
-    return mapreduce(collect, vcat, iters)
+getparams(::DynamicPPL.Model, t::AbstractTransition) = t.θ
+function getparams(model::DynamicPPL.Model, vi::AbstractVarInfo)
+    t = Transition(model, vi, vi[:], nothing)
+    return getparams(model, t)
 end
-function getparams(
-    model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
-)
-    # values_as_in_model is unconscionably slow for untyped VarInfo. It's
-    # much faster to convert it to a typed varinfo before calling getparams.
-    # https://github.com/TuringLang/Turing.jl/issues/2604
-    return getparams(model, DynamicPPL.typed_varinfo(untyped_vi))
-end
-function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
-    return float(Real)[]
-end
-
 function _params_to_array(model::DynamicPPL.Model, ts::Vector)
     names_set = OrderedSet{VarName}()
     # Extract the parameter names and values from each transition.
     dicts = map(ts) do t
+        # TODO(penelopeysm): Get rid of AbstractVarInfo transitions. see
+        # https://github.com/TuringLang/Turing.jl/issues/2631. That would
+        # allow us to just use t.θ here.
        nms_and_vs = getparams(model, t)
         nms = map(first, nms_and_vs)
         vs = map(last, nms_and_vs)
```
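Because `Transition` now stores `logprior` and `loglikelihood` separately, `metadata` can always emit all three log-density columns, merged on top of whatever the sampler's `getstats` returned. Continuing the hypothetical `t` from the sketch further up:

```julia
# `metadata` merges t.stat with the log-density components, so chains keep the
# familiar :lp column alongside the new :logprior and :loglikelihood columns.
meta = Turing.Inference.metadata(t)
meta.lp == meta.logprior + meta.loglikelihood  # true by construction
```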
```diff
@@ -221,7 +260,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
 end
 
 function get_transition_extras(ts::AbstractVector{<:VarInfo})
-    valmat = reshape([getlogjoint(t) for t in ts], :, 1)
+    valmat = reshape([DynamicPPL.getlogjoint(t) for t in ts], :, 1)
     return [:lp], valmat
 end
```
227266

src/mcmc/emcee.jl

Lines changed: 2 additions & 3 deletions

```diff
@@ -92,13 +92,12 @@ function AbstractMCMC.step(
     )
 
     # Compute the next states.
-    states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states))
+    t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states)
 
     # Compute the next transition and state.
     transition = map(states) do _state
         vi = DynamicPPL.unflatten(vi, _state.params)
-        t = Transition(getparams(model, vi), _state.lp)
-        return t
+        return Transition(model, vi, _state.params, t)
     end
     newstate = EmceeState(vi, states)
```
src/mcmc/ess.jl

Lines changed: 2 additions & 2 deletions

```diff
@@ -31,7 +31,7 @@ function DynamicPPL.initialstep(
         EllipticalSliceSampling.isgaussian(typeof(dist)) ||
             error("ESS only supports Gaussian prior distributions")
     end
-    return Transition(model, vi), vi
+    return Transition(model, vi, vi[:], nothing), vi
 end
 
 function AbstractMCMC.step(
@@ -56,7 +56,7 @@ function AbstractMCMC.step(
     vi = DynamicPPL.unflatten(vi, sample)
     vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)
 
-    return Transition(model, vi), vi
+    return Transition(model, vi, vi[:], nothing), vi
 end
 
 # Prior distribution of considered random variable
```
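Here (and in the Gibbs hunks below) the flattened parameter vector is obtained with `vi[:]`, which reflects the VarInfo's current linking status and therefore satisfies the constructor's linking warning by construction. A small sketch, reusing the hypothetical `model` from the earlier example:

```julia
# `vi[:]` flattens parameters exactly as stored in `vi`, so feeding it back is
# always consistent with `vi`'s own linking status (the invariant that the
# constructor's docstring warns about).
vi = DynamicPPL.VarInfo(model)
params = vi[:]                          # flatten
vi = DynamicPPL.unflatten(vi, params)   # the inverse operation
t = Turing.Inference.Transition(model, vi, vi[:], nothing)
```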

src/mcmc/external_sampler.jl

Lines changed: 10 additions & 51 deletions

```diff
@@ -22,8 +22,6 @@ There are a few more optional functions which you can implement to improve the i
 - `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as a component in Turing's Gibbs sampler, you should make this evaluate to `true`.
 
 - `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires unconstrained space, you should return `true`. This tells Turing to perform linking on the VarInfo before evaluation, and ensures that the parameter values passed to your sampler will always be in unconstrained (Euclidean) space.
-
-- `Turing.Inference.getlogp_external(external_transition, external_state)`: Tell Turing how to extract the log probability density associated with this transition (and state). If you do not specify these, Turing will simply re-evaluate the model with the parameters obtained from `getparams`, which can be inefficient. It is therefore recommended to store the log probability density in either the transition or the state (or both) and override this method.
 """
 struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
        InferenceAlgorithm
@@ -85,26 +83,14 @@ function externalsampler(
     return ExternalSampler(sampler, adtype, Val(unconstrained))
 end
 
-"""
-    getlogp_external(external_transition, external_state)
-
-Get the log probability density associated with the external sampler's
-transition and state. Returns `missing` by default; in this case, an extra
-model evaluation will be needed to calculate the correct log density.
-"""
-getlogp_external(::Any, ::Any) = missing
-getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp
-getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density
-
-struct TuringState{S,V1<:AbstractVarInfo,M,V}
+struct TuringState{S,M,V}
     state::S
-    # Note that this varinfo has the correct parameters and logp obtained from
-    # the state, whereas `ldf.varinfo` will in general have junk inside it.
-    varinfo::V1
+    # Note that in general the VarInfo inside this LogDensityFunction will have
+    # junk parameters and logp. It only exists to provide structure
     ldf::DynamicPPL.LogDensityFunction{M,V}
 end
 
-varinfo(state::TuringState) = state.varinfo
+varinfo(state::TuringState) = state.ldf.varinfo
 varinfo(state::AbstractVarInfo) = state
 
 getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
@@ -115,27 +101,6 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat
 
 getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
 
-function make_updated_varinfo(
-    f::DynamicPPL.LogDensityFunction, external_transition, external_state
-)
-    # Set the parameters.
-    new_parameters = getparams(f.model, external_state)
-    new_varinfo = DynamicPPL.unflatten(f.varinfo, new_parameters)
-    # Set (or recalculate, if needed) the log density.
-    new_logp = getlogp_external(external_transition, external_state)
-    return if ismissing(new_logp)
-        last(DynamicPPL.evaluate!!(f.model, new_varinfo, f.context))
-    else
-        # TODO(DPPL0.37/penelopeysm) This is obviously wrong. Note that we
-        # have the same problem here as in HMC in that the sampler doesn't
-        # tell us about how logp is broken down into prior and likelihood.
-        # We should probably just re-evaluate unconditionally. A bit
-        # unfortunate.
-        DynamicPPL.setlogprior!!(new_varinfo, 0.0)
-        DynamicPPL.setloglikelihood!!(new_varinfo, new_logp)
-    end
-end
-
 # TODO: Do we also support `resume`, etc?
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
@@ -182,13 +147,10 @@ function AbstractMCMC.step(
         )
     end
 
-    # Get the parameters and log density, and set them in the varinfo.
-    new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)
-
-    # Update the `state`
+    new_parameters = getparams(f.model, state_inner)
     return (
-        Transition(f.model, new_varinfo, transition_inner),
-        TuringState(state_inner, new_varinfo, f),
+        Transition(f.model, varinfo, new_parameters, transition_inner),
+        TuringState(state_inner, f),
     )
 end
@@ -207,12 +169,9 @@ function AbstractMCMC.step(
         rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
     )
 
-    # Get the parameters and log density, and set them in the varinfo.
-    new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)
-
-    # Update the `state`
+    new_parameters = getparams(f.model, state_inner)
     return (
-        Transition(f.model, new_varinfo, transition_inner),
-        TuringState(state_inner, new_varinfo, f),
+        Transition(f.model, varinfo, new_parameters, transition_inner),
+        TuringState(state_inner, f),
     )
 end
```
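With `getlogp_external` removed, an external sampler only needs to provide `getparams` and (optionally) `getstats`; Turing now unconditionally re-evaluates the model inside `Transition`, which also yields a correct prior/likelihood split. A hedged sketch of what such an implementation might look like (`MyTransition` is an illustrative name, not part of this commit):

```julia
import Turing
import DynamicPPL

# Hypothetical transition type of some external sampler.
struct MyTransition
    params::Vector{Float64}  # flattened parameter values
    is_accept::Bool          # an example sampler statistic
end

# Tell Turing how to pull parameters and statistics out of the transition;
# the log density is no longer needed, since Transition recomputes it.
Turing.Inference.getparams(::DynamicPPL.Model, t::MyTransition) = t.params
Turing.Inference.getstats(t::MyTransition) = (; is_accept=t.is_accept)
```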

src/mcmc/gibbs.jl

Lines changed: 4 additions & 4 deletions

```diff
@@ -389,7 +389,7 @@ function AbstractMCMC.step(
         initial_params=initial_params,
         kwargs...,
     )
-    return Transition(model, vi), GibbsState(vi, states)
+    return Transition(model, vi, vi[:], nothing), GibbsState(vi, states)
 end
 
 function AbstractMCMC.step_warmup(
@@ -414,7 +414,7 @@ function AbstractMCMC.step_warmup(
         initial_params=initial_params,
         kwargs...,
     )
-    return Transition(model, vi), GibbsState(vi, states)
+    return Transition(model, vi, vi[:], nothing), GibbsState(vi, states)
 end
 
 """
@@ -502,7 +502,7 @@ function AbstractMCMC.step(
     vi, states = gibbs_step_recursive(
         rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs...
     )
-    return Transition(model, vi), GibbsState(vi, states)
+    return Transition(model, vi, vi[:], nothing), GibbsState(vi, states)
 end
 
 function AbstractMCMC.step_warmup(
@@ -522,7 +522,7 @@ function AbstractMCMC.step_warmup(
     vi, states = gibbs_step_recursive(
         rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs...
     )
-    return Transition(model, vi), GibbsState(vi, states)
+    return Transition(model, vi, vi[:], nothing), GibbsState(vi, states)
 end
 
 """
```
