
Commit 75913e9

Enforce re-evaluation when constructing Transition
1 parent d7a46e1 commit 75913e9

16 files changed (+215 −229 lines)

ext/TuringDynamicHMCExt.jl

Lines changed: 4 additions & 12 deletions
@@ -73,14 +73,9 @@ function DynamicPPL.initialstep(
     steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
     Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)

-    # Update the variables.
-    vi = DynamicPPL.unflatten(vi, Q.q)
-    # TODO(DPPL0.37/penelopeysm): This is obviously incorrect. Fix this.
-    vi = DynamicPPL.setloglikelihood!!(vi, Q.ℓq)
-    vi = DynamicPPL.setlogprior!!(vi, 0.0)
-
     # Create first sample and state.
-    sample = Turing.Inference.Transition(model, vi)
+    vi = DynamicPPL.unflatten(vi, Q.q)
+    sample = Turing.Inference.Transition(model, vi, nothing)
     state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)

     return sample, state
@@ -99,12 +94,9 @@ function AbstractMCMC.step(
     steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
     Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)

-    # Update the variables.
-    vi = DynamicPPL.unflatten(vi, Q.q)
-    vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
-
     # Create next sample and state.
-    sample = Turing.Inference.Transition(model, vi)
+    vi = DynamicPPL.unflatten(vi, Q.q)
+    sample = Turing.Inference.Transition(model, vi, nothing)
    newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)

     return sample, newstate
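
The deleted lines stored DynamicHMC's `Q.ℓq`, which is the log density of the whole (possibly linked) target, as the log-likelihood, with the log-prior zeroed out. The replacement writes only the new position `Q.q` into the VarInfo and leaves all log-density bookkeeping to the three-argument `Transition` constructor (see src/mcmc/Inference.jl below), which re-evaluates the model. A minimal sketch of the resulting sampler-side contract, where `q_new` is a hypothetical placeholder for a proposed position:

    # Sketch only: the sampler sets parameter values, nothing else.
    vi = DynamicPPL.unflatten(vi, q_new)
    # Transition re-evaluates `model` internally, so the prior/likelihood
    # split always comes from the model itself, never from sampler state.
    sample = Turing.Inference.Transition(model, vi, nothing)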

src/mcmc/Inference.jl

Lines changed: 98 additions & 63 deletions
@@ -124,85 +124,119 @@ end
 ######################
 # Default Transition #
 ######################
-# Default
-getstats(t) = nothing
+getstats(::Any) = NamedTuple()

+# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition,
+# SMCTransition, and PGTransition to Turing.Inference.Transition instead.
 abstract type AbstractTransition end

-struct Transition{T,F<:AbstractFloat,S<:Union{NamedTuple,Nothing}} <: AbstractTransition
+struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
     θ::T
-    lp::F # TODO: merge `lp` with `stat`
-    stat::S
-end
+    logprior::F
+    loglikelihood::F
+    stat::N
+
+    """
+        Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
+
+    Construct a new `Turing.Inference.Transition` object using the outputs of a
+    sampler step.
+
+    Here, `vi` represents a VarInfo _for which the appropriate parameters have
+    already been set_. However, the accumulators (e.g. logp) may in general
+    have junk contents. The role of this method is to re-evaluate `model` and
+    thus set the accumulators to the correct values.
+
+    `sampler_transition` is the transition object returned by the sampler
+    itself and is only used to extract statistics of interest.
+    """
+    function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition)
+        vi = DynamicPPL.setaccs!!(
+            vi,
+            (
+                DynamicPPL.ValuesAsInModelAccumulator(true),
+                DynamicPPL.LogPriorAccumulator(),
+                DynamicPPL.LogLikelihoodAccumulator(),
+            ),
+        )
+        _, vi = DynamicPPL.evaluate!!(model, vi)
+
+        # Extract all the information we need
+        vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
+        logprior = DynamicPPL.getlogprior(vi)
+        loglikelihood = DynamicPPL.getloglikelihood(vi)
+
+        # Convert values to the format needed (i.e. a Vector of (varname,
+        # value) tuples, where value isa Real: all vector-valued varnames must
+        # be split up.)
+        # TODO(penelopeysm): This wouldn't be necessary if not for MCMCChains's
+        # poor representation...
+        values_split = if isempty(vals_as_in_model)
+            # If there are no values, we return an empty vector.
+            # This is the case for models with no parameters.
+            Vector{Tuple{VarName,Any}}()
+        else
+            iters = map(
+                DynamicPPL.varname_and_value_leaves,
+                keys(vals_as_in_model),
+                values(vals_as_in_model),
+            )
+            mapreduce(collect, vcat, iters)
+        end

-Transition(θ, lp) = Transition(θ, lp, nothing)
-function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
-    # TODO(DPPL0.37/penelopeysm): Fix this
-    θ = getparams(model, vi)
-    lp = getlogjoint_internal(vi)
-    return Transition(θ, lp, getstats(t))
-end
+        # Get additional statistics
+        stats = getstats(sampler_transition)
+        return new{typeof(values_split),typeof(logprior),typeof(stats)}(
+            values_split, logprior, loglikelihood, stats
+        )
+    end

-# TODO(DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
-function metadata(t::Transition)
-    stat = t.stat
-    if stat === nothing
-        return (lp=t.lp,)
-    else
-        return merge((lp=t.lp,), stat)
+    function Transition(
+        model::DynamicPPL.Model,
+        untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
+        sampler_transition,
+    )
+        # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
+        # much faster to convert it to a typed varinfo first, hence this method.
+        # https://github.com/TuringLang/Turing.jl/issues/2604
+        return Transition(model, DynamicPPL.typed_varinfo(untyped_vi), sampler_transition)
     end
 end

-# TODO(DPPL0.37/penelopeysm): Fix this
-DynamicPPL.getlogjoint(t::Transition) = t.lp
-
-# Metadata of VarInfo object
-# TODO(DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
-metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),)
+function metadata(t::Transition)
+    return merge(
+        t.stat,
+        (
+            lp=t.logprior + t.loglikelihood,
+            logprior=t.logprior,
+            loglikelihood=t.loglikelihood,
+        ),
+    )
+end
+function metadata(vi::AbstractVarInfo)
+    return (
+        lp=DynamicPPL.getlogjoint(vi),
+        logprior=DynamicPPL.getlogprior(vi),
+        loglikelihood=DynamicPPL.getloglikelihood(vi),
+    )
+end

 ##########################
 # Chain making utilities #
 ##########################

-"""
-    getparams(model, t)
-
-Return a named tuple of parameters.
-"""
-getparams(model, t) = t.θ
-function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
-    # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
-    # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
-    # of the parameters change depending on the realizations. Hence we have to use
-    # `values_as_in_model`, which re-runs the model and extracts the parameters
-    # as they are seen in the model, i.e. in the constrained space. Moreover,
-    # this means that the code below will work both of linked and invlinked `vi`.
-    # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
-    # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
-    vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
-
-    # Obtain an iterator over the flattened parameter names and values.
-    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
-
-    # Materialize the iterators and concatenate.
-    return mapreduce(collect, vcat, iters)
+getparams(::DynamicPPL.Model, t::AbstractTransition) = t.θ
+function getparams(model::DynamicPPL.Model, vi::AbstractVarInfo)
+    t = Transition(model, vi, nothing)
+    return getparams(model, t)
 end
-function getparams(
-    model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
-)
-    # values_as_in_model is unconscionably slow for untyped VarInfo. It's
-    # much faster to convert it to a typed varinfo before calling getparams.
-    # https://github.com/TuringLang/Turing.jl/issues/2604
-    return getparams(model, DynamicPPL.typed_varinfo(untyped_vi))
-end
-function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
-    return float(Real)[]
-end
-
 function _params_to_array(model::DynamicPPL.Model, ts::Vector)
     names_set = OrderedSet{VarName}()
     # Extract the parameter names and values from each transition.
     dicts = map(ts) do t
+        # TODO(penelopeysm): Get rid of AbstractVarInfo transitions. see
+        # https://github.com/TuringLang/Turing.jl/issues/2631. That would
+        # allow us to just use t.θ here.
         nms_and_vs = getparams(model, t)
         nms = map(first, nms_and_vs)
         vs = map(last, nms_and_vs)
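
As a usage sketch of the new constructor (the model and data are invented for illustration; the accumulator API is the DynamicPPL 0.37 one used above):

    using Turing
    using Turing: DynamicPPL

    @model function demo(x)
        m ~ Normal(0, 1)    # contributes to t.logprior
        x ~ Normal(m, 1)    # contributes to t.loglikelihood
    end

    model = demo(1.0)
    vi = DynamicPPL.VarInfo(model)  # accumulators may hold junk; that is fine
    t = Turing.Inference.Transition(model, vi, nothing)
    t.logprior + t.loglikelihood    # joint log density, recomputed by re-evaluation

`metadata(t)` then reports `lp`, `logprior`, and `loglikelihood` as separate chain statistics, merged with whatever `getstats` extracted from the sampler transition.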
@@ -221,7 +255,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
 end

 function get_transition_extras(ts::AbstractVector{<:VarInfo})
-    valmat = reshape([getlogjoint(t) for t in ts], :, 1)
+    valmat = reshape([DynamicPPL.getlogjoint(t) for t in ts], :, 1)
     return [:lp], valmat
 end

@@ -463,16 +497,17 @@ function transitions_from_chain(
     chain::MCMCChains.Chains;
     sampler=DynamicPPL.SampleFromPrior(),
 )
-    vi = Turing.VarInfo(model)
+    vi = VarInfo(model)

     iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
     transitions = map(iters) do (sample_idx, chain_idx)
         # Set variables present in `chain` and mark those NOT present in chain to be resampled.
+        # TODO(DPPL0.37/penelopeysm): Aargh! setval_and_resample!!!! Burn this!!!
         DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx)
         model(rng, vi, sampler)

         # Convert `VarInfo` into `NamedTuple` and save.
-        Transition(model, vi)
+        Transition(model, vi, nothing)
     end

     return transitions
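
The leaf-splitting in `Transition` relies on `DynamicPPL.varname_and_value_leaves`, which decomposes an array-valued variable into one `(VarName, scalar)` pair per element so that MCMCChains can store each as its own column. Roughly:

    using DynamicPPL

    # Each element of an array-valued variable becomes its own
    # (VarName, scalar) pair, approximately (x[1], 1.0) and (x[2], 2.0).
    leaves = collect(DynamicPPL.varname_and_value_leaves(@varname(x), [1.0, 2.0]))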

src/mcmc/emcee.jl

Lines changed: 3 additions & 4 deletions
@@ -65,7 +65,7 @@ function AbstractMCMC.step(
     end

     # Compute initial transition and states.
-    transition = map(Base.Fix1(Transition, model), vis)
+    transition = [Transition(model, vi, nothing) for vi in vis]

     # TODO: Make compatible with immutable `AbstractVarInfo`.
     state = EmceeState(
@@ -92,13 +92,12 @@
     )

     # Compute the next states.
-    states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states))
+    t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states)

     # Compute the next transition and state.
     transition = map(states) do _state
         vi = DynamicPPL.unflatten(vi, _state.params)
-        t = Transition(getparams(model, vi), _state.lp)
-        return t
+        return Transition(model, vi, t)
     end
     newstate = EmceeState(vi, states)

src/mcmc/ess.jl

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ function DynamicPPL.initialstep(
         EllipticalSliceSampling.isgaussian(typeof(dist)) ||
             error("ESS only supports Gaussian prior distributions")
     end
-    return Transition(model, vi), vi
+    return Transition(model, vi, nothing), vi
 end

 function AbstractMCMC.step(
@@ -56,7 +56,7 @@ function AbstractMCMC.step(
     vi = DynamicPPL.unflatten(vi, sample)
     vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)

-    return Transition(model, vi), vi
+    return Transition(model, vi, nothing), vi
 end

 # Prior distribution of considered random variable

src/mcmc/external_sampler.jl

Lines changed: 16 additions & 51 deletions
@@ -22,8 +22,6 @@ There are a few more optional functions which you can implement to improve the i
 - `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as a component in Turing's Gibbs sampler, you should make this evaluate to `true`.

 - `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires unconstrained space, you should return `true`. This tells Turing to perform linking on the VarInfo before evaluation, and ensures that the parameter values passed to your sampler will always be in unconstrained (Euclidean) space.
-
-- `Turing.Inference.getlogp_external(external_transition, external_state)`: Tell Turing how to extract the log probability density associated with this transition (and state). If you do not specify these, Turing will simply re-evaluate the model with the parameters obtained from `getparams`, which can be inefficient. It is therefore recommended to store the log probability density in either the transition or the state (or both) and override this method.
 """
 struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
        InferenceAlgorithm
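
With `getlogp_external` removed, the optional interface consists purely of capability declarations; the log density is now always recovered by re-evaluating the model inside `Transition`. A sketch of a custom sampler opting in, where `MySampler` is a hypothetical AbstractMCMC sampler:

    using Turing
    import AbstractMCMC

    struct MySampler <: AbstractMCMC.AbstractSampler end

    # Allow use as a component sampler inside Turing's Gibbs sampler.
    Turing.Inference.isgibbscomponent(::MySampler) = true
    # Ask Turing to link the VarInfo so parameters arrive unconstrained.
    Turing.Inference.requires_unconstrained_space(::MySampler) = true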
@@ -85,27 +83,21 @@ function externalsampler(
     return ExternalSampler(sampler, adtype, Val(unconstrained))
 end

-"""
-    getlogp_external(external_transition, external_state)
-
-Get the log probability density associated with the external sampler's
-transition and state. Returns `missing` by default; in this case, an extra
-model evaluation will be needed to calculate the correct log density.
-"""
-getlogp_external(::Any, ::Any) = missing
-getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp
-getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density
-
-struct TuringState{S,V1<:AbstractVarInfo,M,V}
+# TODO(penelopeysm): Can't we clean this up somehow?
+struct TuringState{S,V1,M,V}
     state::S
-    # Note that this varinfo has the correct parameters and logp obtained from
-    # the state, whereas `ldf.varinfo` will in general have junk inside it.
+    # Note that this varinfo must have the correct parameters set; but logp
+    # does not matter as it will be re-evaluated
     varinfo::V1
+    # Note that in general the VarInfo inside this LogDensityFunction will have
+    # junk parameters and logp. It only exists to provide structure
     ldf::DynamicPPL.LogDensityFunction{M,V}
 end

-varinfo(state::TuringState) = state.varinfo
-varinfo(state::AbstractVarInfo) = state
+# get_varinfo should return something from which the correct parameters can be
+# obtained, hence we use state.varinfo rather than state.ldf.varinfo
+get_varinfo(state::TuringState) = state.varinfo
+get_varinfo(state::AbstractVarInfo) = state

 getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
 function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
@@ -115,27 +107,6 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat

 getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params

-function make_updated_varinfo(
-    f::DynamicPPL.LogDensityFunction, external_transition, external_state
-)
-    # Set the parameters.
-    new_parameters = getparams(f.model, external_state)
-    new_varinfo = DynamicPPL.unflatten(f.varinfo, new_parameters)
-    # Set (or recalculate, if needed) the log density.
-    new_logp = getlogp_external(external_transition, external_state)
-    return if ismissing(new_logp)
-        last(DynamicPPL.evaluate!!(f.model, new_varinfo, f.context))
-    else
-        # TODO(DPPL0.37/penelopeysm) This is obviously wrong. Note that we
-        # have the same problem here as in HMC in that the sampler doesn't
-        # tell us about how logp is broken down into prior and likelihood.
-        # We should probably just re-evaluate unconditionally. A bit
-        # unfortunate.
-        DynamicPPL.setlogprior!!(new_varinfo, 0.0)
-        DynamicPPL.setloglikelihood!!(new_varinfo, new_logp)
-    end
-end
-
 # TODO: Do we also support `resume`, etc?
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
@@ -182,13 +153,10 @@ function AbstractMCMC.step(
         )
     end

-    # Get the parameters and log density, and set them in the varinfo.
-    new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)
-
-    # Update the `state`
+    new_parameters = getparams(f.model, state_inner)
+    new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
     return (
-        Transition(f.model, new_varinfo, transition_inner),
-        TuringState(state_inner, new_varinfo, f),
+        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
     )
 end

@@ -207,12 +175,9 @@ function AbstractMCMC.step(
         rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
     )

-    # Get the parameters and log density, and set them in the varinfo.
-    new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)
-
-    # Update the `state`
+    new_parameters = getparams(f.model, state_inner)
+    new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
     return (
-        Transition(f.model, new_varinfo, transition_inner),
-        TuringState(state_inner, new_varinfo, f),
+        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
     )
 end
