Skip to content

Commit 607b0e7

Browse files
committed
Enforce re-evaluation when constructing Transition
1 parent 9428096 commit 607b0e7

File tree

16 files changed

+194
-224
lines changed

16 files changed

+194
-224
lines changed

ext/TuringDynamicHMCExt.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,9 @@ function DynamicPPL.initialstep(
7373
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
7474
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
7575

76-
# Update the variables.
77-
vi = DynamicPPL.unflatten(vi, Q.q)
78-
# TODO(DPPL0.37/penelopeysm): This is obviously incorrect. Fix this.
79-
vi = DynamicPPL.setloglikelihood!!(vi, Q.ℓq)
80-
vi = DynamicPPL.setlogprior!!(vi, 0.0)
81-
8276
# Create first sample and state.
83-
sample = Turing.Inference.Transition(model, vi)
77+
vi = DynamicPPL.unflatten(vi, Q.q)
78+
sample = Turing.Inference.Transition(model, vi, nothing)
8479
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
8580

8681
return sample, state
@@ -99,12 +94,9 @@ function AbstractMCMC.step(
9994
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
10095
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
10196

102-
# Update the variables.
103-
vi = DynamicPPL.unflatten(vi, Q.q)
104-
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
105-
10697
# Create next sample and state.
107-
sample = Turing.Inference.Transition(model, vi)
98+
vi = DynamicPPL.unflatten(vi, Q.q)
99+
sample = Turing.Inference.Transition(model, vi, nothing)
108100
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
109101

110102
return sample, newstate

src/mcmc/Inference.jl

Lines changed: 77 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -124,75 +124,94 @@ end
124124
######################
125125
# Default Transition #
126126
######################
127-
# Default
128-
getstats(t) = nothing
127+
getstats(::Any) = NamedTuple()
129128

129+
# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition,
130+
# SMCTransition, and PGTransition to Turing.Inference.Transition instead.
130131
abstract type AbstractTransition end
131132

132-
struct Transition{T,F<:AbstractFloat,S<:Union{NamedTuple,Nothing}} <: AbstractTransition
133+
struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
133134
θ::T
134-
lp::F # TODO: merge `lp` with `stat`
135-
stat::S
136-
end
135+
logprior::F
136+
loglikelihood::F
137+
stat::N
138+
139+
"""
140+
Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
141+
142+
Construct a new `Turing.Inference.Transition` object using the outputs of a
143+
sampler step.
144+
145+
Here, `vi` represents a VarInfo _for which the appropriate parameters have
146+
already been set_. However, the accumulators (e.g. logp) may in general
147+
have junk contents. The role of this method is to re-evaluate `model` and
148+
thus set the accumulators to the correct values.
149+
150+
`sampler_transition` is the transition object returned by the sampler
151+
itself and is only used to extract statistics of interest.
152+
"""
153+
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition)
154+
vi = DynamicPPL.setaccs!!(
155+
vi,
156+
(
157+
DynamicPPL.ValuesAsInModelAccumulator(true),
158+
DynamicPPL.LogPriorAccumulator(),
159+
DynamicPPL.LogLikelihoodAccumulator(),
160+
),
161+
)
162+
_, vi = DynamicPPL.evaluate!!(model, vi)
163+
164+
# Extract all the information we need
165+
vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
166+
logprior = DynamicPPL.getlogprior(vi)
167+
loglikelihood = DynamicPPL.getloglikelihood(vi)
168+
169+
# Get additional statistics
170+
stats = getstats(sampler_transition)
171+
return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}(
172+
vals_as_in_model, logprior, loglikelihood, stats
173+
)
174+
end
137175

138-
Transition(θ, lp) = Transition(θ, lp, nothing)
139-
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
140-
# TODO(DPPL0.37/penelopeysm): Fix this
141-
θ = getparams(model, vi)
142-
lp = getlogjoint_internal(vi)
143-
return Transition(θ, lp, getstats(t))
176+
function Transition(
177+
model::DynamicPPL.Model,
178+
untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
179+
sampler_transition,
180+
)
181+
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
182+
# much faster to convert it to a typed varinfo first, hence this method.
183+
# https://github.com/TuringLang/Turing.jl/issues/2604
184+
return Transition(model, DynamicPPL.typed_varinfo(untyped_vi), sampler_transition)
185+
end
144186
end
145187

146-
# TODO(DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
147188
function metadata(t::Transition)
148-
stat = t.stat
149-
if stat === nothing
150-
return (lp=t.lp,)
151-
else
152-
return merge((lp=t.lp,), stat)
153-
end
189+
return merge(
190+
t.stat,
191+
(
192+
lp=t.logprior + t.loglikelihood,
193+
logprior=t.logprior,
194+
loglikelihood=t.loglikelihood,
195+
),
196+
)
197+
end
198+
function metadata(vi::AbstractVarInfo)
199+
return (
200+
lp=DynamicPPL.getlogjoint(vi),
201+
logprior=DynamicPPL.getlogp(vi),
202+
loglikelihood=DynamicPPL.getloglikelihood(vi),
203+
)
154204
end
155-
156-
# TODO(DPPL0.37/penelopeysm): Fix this
157-
DynamicPPL.getlogjoint(t::Transition) = t.lp
158-
159-
# Metadata of VarInfo object
160-
# TODO(DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
161-
metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),)
162205

163206
##########################
164207
# Chain making utilities #
165208
##########################
166209

167-
"""
168-
getparams(model, t)
169-
170-
Return a named tuple of parameters.
171-
"""
172-
getparams(model, t) = t.θ
173-
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
174-
# NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
175-
# Unfortunately, using `invlink` can cause issues in scenarios where the constraints
176-
# of the parameters change depending on the realizations. Hence we have to use
177-
# `values_as_in_model`, which re-runs the model and extracts the parameters
178-
# as they are seen in the model, i.e. in the constrained space. Moreover,
179-
# this means that the code below will work both of linked and invlinked `vi`.
180-
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
181-
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
182-
return DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
210+
getparams(::DynamicPPL.Model, t::AbstractTransition) = t.θ
211+
function getparams(model::DynamicPPL.Model, vi::AbstractVarInfo)
212+
t = Transition(model, vi, nothing)
213+
return getparams(model, t)
183214
end
184-
function getparams(
185-
model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
186-
)
187-
# values_as_in_model is unconscionably slow for untyped VarInfo. It's
188-
# much faster to convert it to a typed varinfo before calling getparams.
189-
# https://github.com/TuringLang/Turing.jl/issues/2604
190-
return getparams(model, DynamicPPL.typed_varinfo(untyped_vi))
191-
end
192-
function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
193-
return Dict{VarName,Any}()
194-
end
195-
196215
function _params_to_array(model::DynamicPPL.Model, ts::Vector)
197216
names_set = OrderedSet{VarName}()
198217
# Extract the parameter names and values from each transition.
@@ -208,7 +227,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
208227
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
209228
mapreduce(collect, vcat, iters)
210229
end
211-
212230
nms = map(first, nms_and_vs)
213231
vs = map(last, nms_and_vs)
214232
for nm in nms
@@ -224,7 +242,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
224242
end
225243

226244
function get_transition_extras(ts::AbstractVector{<:VarInfo})
227-
valmat = reshape([getlogjoint(t) for t in ts], :, 1)
245+
valmat = reshape([DynamicPPL.getlogjoint(t) for t in ts], :, 1)
228246
return [:lp], valmat
229247
end
230248

@@ -466,16 +484,17 @@ function transitions_from_chain(
466484
chain::MCMCChains.Chains;
467485
sampler=DynamicPPL.SampleFromPrior(),
468486
)
469-
vi = Turing.VarInfo(model)
487+
vi = VarInfo(model)
470488

471489
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
472490
transitions = map(iters) do (sample_idx, chain_idx)
473491
# Set variables present in `chain` and mark those NOT present in chain to be resampled.
492+
# TODO(DPPL0.37/penelopeysm): Aargh! setval_and_resample!!!! Burn this!!!
474493
DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx)
475494
model(rng, vi, sampler)
476495

477496
# Convert `VarInfo` into `NamedTuple` and save.
478-
Transition(model, vi)
497+
Transition(model, vi, nothing)
479498
end
480499

481500
return transitions

src/mcmc/emcee.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function AbstractMCMC.step(
6565
end
6666

6767
# Compute initial transition and states.
68-
transition = map(Base.Fix1(Transition, model), vis)
68+
transition = [Transition(model, vi, nothing) for vi in vis]
6969

7070
# TODO: Make compatible with immutable `AbstractVarInfo`.
7171
state = EmceeState(
@@ -92,13 +92,12 @@ function AbstractMCMC.step(
9292
)
9393

9494
# Compute the next states.
95-
states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states))
95+
t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states)
9696

9797
# Compute the next transition and state.
9898
transition = map(states) do _state
9999
vi = DynamicPPL.unflatten(vi, _state.params)
100-
t = Transition(getparams(model, vi), _state.lp)
101-
return t
100+
return Transition(model, vi, t)
102101
end
103102
newstate = EmceeState(vi, states)
104103

src/mcmc/ess.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function DynamicPPL.initialstep(
3131
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
3232
error("ESS only supports Gaussian prior distributions")
3333
end
34-
return Transition(model, vi), vi
34+
return Transition(model, vi, nothing), vi
3535
end
3636

3737
function AbstractMCMC.step(
@@ -56,7 +56,7 @@ function AbstractMCMC.step(
5656
vi = DynamicPPL.unflatten(vi, sample)
5757
vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)
5858

59-
return Transition(model, vi), vi
59+
return Transition(model, vi, nothing), vi
6060
end
6161

6262
# Prior distribution of considered random variable

src/mcmc/external_sampler.jl

Lines changed: 16 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ There are a few more optional functions which you can implement to improve the i
2222
- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as a component in Turing's Gibbs sampler, you should make this evaluate to `true`.
2323
2424
- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires unconstrained space, you should return `true`. This tells Turing to perform linking on the VarInfo before evaluation, and ensures that the parameter values passed to your sampler will always be in unconstrained (Euclidean) space.
25-
26-
- `Turing.Inference.getlogp_external(external_transition, external_state)`: Tell Turing how to extract the log probability density associated with this transition (and state). If you do not specify these, Turing will simply re-evaluate the model with the parameters obtained from `getparams`, which can be inefficient. It is therefore recommended to store the log probability density in either the transition or the state (or both) and override this method.
2725
"""
2826
struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
2927
InferenceAlgorithm
@@ -85,27 +83,21 @@ function externalsampler(
8583
return ExternalSampler(sampler, adtype, Val(unconstrained))
8684
end
8785

88-
"""
89-
getlogp_external(external_transition, external_state)
90-
91-
Get the log probability density associated with the external sampler's
92-
transition and state. Returns `missing` by default; in this case, an extra
93-
model evaluation will be needed to calculate the correct log density.
94-
"""
95-
getlogp_external(::Any, ::Any) = missing
96-
getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp
97-
getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density
98-
99-
struct TuringState{S,V1<:AbstractVarInfo,M,V}
86+
# TODO(penelopeysm): Can't we clean this up somehow?
87+
struct TuringState{S,V1,M,V}
10088
state::S
101-
# Note that this varinfo has the correct parameters and logp obtained from
102-
# the state, whereas `ldf.varinfo` will in general have junk inside it.
89+
# Note that this varinfo must have the correct parameters set; but logp
90+
# does not matter as it will be re-evaluated
10391
varinfo::V1
92+
# Note that in general the VarInfo inside this LogDensityFunction will have
93+
# junk parameters and logp. It only exists to provide structure
10494
ldf::DynamicPPL.LogDensityFunction{M,V}
10595
end
10696

107-
varinfo(state::TuringState) = state.varinfo
108-
varinfo(state::AbstractVarInfo) = state
97+
# get_varinfo should return something from which the correct parameters can be
98+
# obtained, hence we use state.varinfo rather than state.ldf.varinfo
99+
get_varinfo(state::TuringState) = state.varinfo
100+
get_varinfo(state::AbstractVarInfo) = state
109101

110102
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
111103
function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
@@ -115,27 +107,6 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat
115107

116108
getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
117109

118-
function make_updated_varinfo(
119-
f::DynamicPPL.LogDensityFunction, external_transition, external_state
120-
)
121-
# Set the parameters.
122-
new_parameters = getparams(f.model, external_state)
123-
new_varinfo = DynamicPPL.unflatten(f.varinfo, new_parameters)
124-
# Set (or recalculate, if needed) the log density.
125-
new_logp = getlogp_external(external_transition, external_state)
126-
return if ismissing(new_logp)
127-
last(DynamicPPL.evaluate!!(f.model, new_varinfo, f.context))
128-
else
129-
# TODO(DPPL0.37/penelopeysm) This is obviously wrong. Note that we
130-
# have the same problem here as in HMC in that the sampler doesn't
131-
# tell us about how logp is broken down into prior and likelihood.
132-
# We should probably just re-evaluate unconditionally. A bit
133-
# unfortunate.
134-
DynamicPPL.setlogprior!!(new_varinfo, 0.0)
135-
DynamicPPL.setloglikelihood!!(new_varinfo, new_logp)
136-
end
137-
end
138-
139110
# TODO: Do we also support `resume`, etc?
140111
function AbstractMCMC.step(
141112
rng::Random.AbstractRNG,
@@ -182,13 +153,10 @@ function AbstractMCMC.step(
182153
)
183154
end
184155

185-
# Get the parameters and log density, and set them in the varinfo.
186-
new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)
187-
188-
# Update the `state`
156+
new_parameters = getparams(f.model, state_inner)
157+
new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
189158
return (
190-
Transition(f.model, new_varinfo, transition_inner),
191-
TuringState(state_inner, new_varinfo, f),
159+
Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
192160
)
193161
end
194162

@@ -207,12 +175,9 @@ function AbstractMCMC.step(
207175
rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
208176
)
209177

210-
# Get the parameters and log density, and set them in the varinfo.
211-
new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)
212-
213-
# Update the `state`
178+
new_parameters = getparams(f.model, state_inner)
179+
new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
214180
return (
215-
Transition(f.model, new_varinfo, transition_inner),
216-
TuringState(state_inner, new_varinfo, f),
181+
Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
217182
)
218183
end

0 commit comments

Comments
 (0)