Skip to content

Commit c414095

Browse files
committed
Don't unflatten inside Transition constructor
TuringLang/DynamicPPL.jl#1001
1 parent 3607830 commit c414095

File tree

12 files changed

+46
-65
lines changed

12 files changed

+46
-65
lines changed

ext/TuringDynamicHMCExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ function DynamicPPL.initialstep(
7474
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
7575

7676
# Create first sample and state.
77-
sample = Turing.Inference.Transition(model, vi, Q.q, nothing)
77+
vi = DynamicPPL.unflatten(vi, Q.q)
78+
sample = Turing.Inference.Transition(model, vi, nothing)
7879
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
7980

8081
return sample, state
@@ -94,7 +95,8 @@ function AbstractMCMC.step(
9495
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
9596

9697
# Create next sample and state.
97-
sample = Turing.Inference.Transition(model, vi, Q.q, nothing)
98+
vi = DynamicPPL.unflatten(vi, Q.q)
99+
sample = Turing.Inference.Transition(model, vi, nothing)
98100
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
99101

100102
return sample, newstate

src/mcmc/Inference.jl

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -137,31 +137,20 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
137137
stat::N
138138

139139
"""
140-
Transition(model::Model, vi::AbstractVarInfo, params::AbstractVector, sampler_transition)
140+
Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
141141
142-
Construct a new `Turing.Inference.Transition` object using the outputs of a sampler step.
142+
Construct a new `Turing.Inference.Transition` object using the outputs of a
143+
sampler step.
143144
144-
Here, `vi` represents a VarInfo which in general may have junk contents (both
145-
parameters and accumulators). The role of this method is to re-evaluate `model` by inserting
146-
the new `params` (provided by the sampler) into the VarInfo `vi`.
145+
Here, `vi` represents a VarInfo _for which the appropriate parameters have
146+
already been set_. However, the accumulators (e.g. logp) may in general
147+
have junk contents. The role of this method is to re-evaluate `model` and
148+
thus set the accumulators to the correct values.
147149
148-
`sampler_transition` is the transition object returned by the sampler itself and is only used
149-
to extract statistics of interest.
150-
151-
!!! warning "Parameters must match varinfo linking status"
152-
It is mandatory that the vector of parameters provided line up exactly with how the
153-
VarInfo `vi` is linked. Otherwise, this can silently produce incorrect results.
150+
`sampler_transition` is the transition object returned by the sampler
151+
itself and is only used to extract statistics of interest.
154152
"""
155-
function Transition(
156-
model::DynamicPPL.Model,
157-
vi::AbstractVarInfo,
158-
parameters::AbstractVector,
159-
sampler_transition,
160-
)
161-
# To be safe...
162-
vi = deepcopy(vi)
163-
# Set the parameters and re-evaluate with the appropriate accumulators
164-
vi = DynamicPPL.unflatten(vi, parameters)
153+
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition)
165154
vi = DynamicPPL.setaccs!!(
166155
vi,
167156
(
@@ -193,18 +182,16 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
193182
values_split, logprior, loglikelihood, stats
194183
)
195184
end
185+
196186
function Transition(
197187
model::DynamicPPL.Model,
198188
untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
199-
parameters::AbstractVector,
200189
sampler_transition,
201190
)
202191
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
203192
# much faster to convert it to a typed varinfo first, hence this method.
204193
# https://github.com/TuringLang/Turing.jl/issues/2604
205-
return Transition(
206-
model, DynamicPPL.typed_varinfo(untyped_vi), parameters, sampler_transition
207-
)
194+
return Transition(model, DynamicPPL.typed_varinfo(untyped_vi), sampler_transition)
208195
end
209196
end
210197

@@ -232,7 +219,7 @@ end
232219

233220
getparams(::DynamicPPL.Model, t::AbstractTransition) = t.θ
234221
function getparams(model::DynamicPPL.Model, vi::AbstractVarInfo)
235-
t = Transition(model, vi, vi[:], nothing)
222+
t = Transition(model, vi, nothing)
236223
return getparams(model, t)
237224
end
238225
function _params_to_array(model::DynamicPPL.Model, ts::Vector)
@@ -502,16 +489,17 @@ function transitions_from_chain(
502489
chain::MCMCChains.Chains;
503490
sampler=DynamicPPL.SampleFromPrior(),
504491
)
505-
vi = Turing.VarInfo(model)
492+
vi = VarInfo(model)
506493

507494
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
508495
transitions = map(iters) do (sample_idx, chain_idx)
509496
# Set variables present in `chain` and mark those NOT present in chain to be resampled.
497+
# TODO(DPPL0.37/penelopeysm): Aargh! setval_and_resample!!!! Burn this!!!
510498
DynamicPPL.setval_and_resample!(vi, chain, sample_idx, chain_idx)
511499
model(rng, vi, sampler)
512500

513501
# Convert `VarInfo` into `NamedTuple` and save.
514-
Transition(model, vi)
502+
Transition(model, vi, nothing)
515503
end
516504

517505
return transitions

src/mcmc/emcee.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function AbstractMCMC.step(
6565
end
6666

6767
# Compute initial transition and states.
68-
transition = [Transition(model, vi, vi[:], nothing) for vi in vis]
68+
transition = [Transition(model, vi, nothing) for vi in vis]
6969

7070
# TODO: Make compatible with immutable `AbstractVarInfo`.
7171
state = EmceeState(
@@ -97,7 +97,7 @@ function AbstractMCMC.step(
9797
# Compute the next transition and state.
9898
transition = map(states) do _state
9999
vi = DynamicPPL.unflatten(vi, _state.params)
100-
return Transition(model, vi, _state.params, t)
100+
return Transition(model, vi, t)
101101
end
102102
newstate = EmceeState(vi, states)
103103

src/mcmc/ess.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function DynamicPPL.initialstep(
3131
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
3232
error("ESS only supports Gaussian prior distributions")
3333
end
34-
return Transition(model, vi, vi[:], nothing), vi
34+
return Transition(model, vi, nothing), vi
3535
end
3636

3737
function AbstractMCMC.step(
@@ -56,7 +56,7 @@ function AbstractMCMC.step(
5656
vi = DynamicPPL.unflatten(vi, sample)
5757
vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)
5858

59-
return Transition(model, vi, vi[:], nothing), vi
59+
return Transition(model, vi, nothing), vi
6060
end
6161

6262
# Prior distribution of considered random variable

src/mcmc/external_sampler.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,8 @@ function AbstractMCMC.step(
148148
end
149149

150150
new_parameters = getparams(f.model, state_inner)
151-
return (
152-
Transition(f.model, f.varinfo, new_parameters, transition_inner),
153-
TuringState(state_inner, f),
154-
)
151+
vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
152+
return (Transition(f.model, vi, transition_inner), TuringState(state_inner, f))
155153
end
156154

157155
function AbstractMCMC.step(
@@ -170,8 +168,6 @@ function AbstractMCMC.step(
170168
)
171169

172170
new_parameters = getparams(f.model, state_inner)
173-
return (
174-
Transition(f.model, f.varinfo, new_parameters, transition_inner),
175-
TuringState(state_inner, f),
176-
)
171+
vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
172+
return (Transition(f.model, vi, transition_inner), TuringState(state_inner, f))
177173
end

src/mcmc/gibbs.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ function AbstractMCMC.step(
389389
initial_params=initial_params,
390390
kwargs...,
391391
)
392-
return Transition(model, vi, vi[:], nothing), GibbsState(vi, states)
392+
return Transition(model, vi, nothing), GibbsState(vi, states)
393393
end
394394

395395
function AbstractMCMC.step_warmup(
@@ -414,7 +414,7 @@ function AbstractMCMC.step_warmup(
414414
initial_params=initial_params,
415415
kwargs...,
416416
)
417-
return Transition(model, vi, vi[:], nothing), GibbsState(vi, states)
417+
return Transition(model, vi, nothing), GibbsState(vi, states)
418418
end
419419

420420
"""
@@ -502,7 +502,7 @@ function AbstractMCMC.step(
502502
vi, states = gibbs_step_recursive(
503503
rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs...
504504
)
505-
return Transition(model, vi, vi[:], nothing), GibbsState(vi, states)
505+
return Transition(model, vi, nothing), GibbsState(vi, states)
506506
end
507507

508508
function AbstractMCMC.step_warmup(
@@ -522,7 +522,7 @@ function AbstractMCMC.step_warmup(
522522
vi, states = gibbs_step_recursive(
523523
rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs...
524524
)
525-
return Transition(model, vi, vi[:], nothing), GibbsState(vi, states)
525+
return Transition(model, vi, nothing), GibbsState(vi, states)
526526
end
527527

528528
"""

src/mcmc/hmc.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,15 +234,15 @@ function DynamicPPL.initialstep(
234234
)
235235
end
236236

237-
# Update VarInfo based on acceptance
237+
# Update VarInfo parameters based on acceptance
238238
new_params = if t.stat.is_accept
239239
t.z.θ
240240
else
241241
theta
242242
end
243243
vi = DynamicPPL.unflatten(vi, new_params)
244244

245-
transition = Transition(model, vi, new_params, t)
245+
transition = Transition(model, vi, t)
246246
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
247247

248248
return transition, state
@@ -283,14 +283,11 @@ function AbstractMCMC.step(
283283
# Update variables
284284
vi = state.vi
285285
if t.stat.is_accept
286-
new_params = t.z.θ
287286
vi = DynamicPPL.unflatten(vi, new_params)
288-
else
289-
new_params = vi[:]
290287
end
291288

292289
# Compute next transition and state.
293-
transition = Transition(model, vi, new_params, t)
290+
transition = Transition(model, vi, t)
294291
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
295292

296293
return transition, newstate

src/mcmc/is.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler
3131
function DynamicPPL.initialstep(
3232
rng::AbstractRNG, model::Model, spl::Sampler{<:IS}, vi::AbstractVarInfo; kwargs...
3333
)
34-
return Transition(model, vi, vi[:], nothing), nothing
34+
return Transition(model, vi, nothing), nothing
3535
end
3636

3737
function AbstractMCMC.step(
3838
rng::Random.AbstractRNG, model::Model, spl::Sampler{<:IS}, ::Nothing; kwargs...
3939
)
4040
vi = VarInfo(rng, model, spl)
41-
return Transition(model, vi, vi[:], nothing), nothing
41+
return Transition(model, vi, nothing), nothing
4242
end
4343

4444
# Calculate evidence.

src/mcmc/mh.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ function DynamicPPL.initialstep(
364364
# just link everything before sampling.
365365
vi = maybe_link!!(vi, spl, spl.alg.proposals, model)
366366

367-
return Transition(model, vi, vi[:], nothing), vi
367+
return Transition(model, vi, nothing), vi
368368
end
369369

370370
function AbstractMCMC.step(
@@ -375,7 +375,7 @@ function AbstractMCMC.step(
375375
# 2. A bunch of NamedTuples that specify the proposal space
376376
new_vi = propose!!(rng, vi, model, spl, spl.alg.proposals)
377377

378-
return Transition(model, new_vi, new_vi[:], nothing), new_vi
378+
return Transition(model, new_vi, nothing), new_vi
379379
end
380380

381381
####

src/mcmc/prior.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function AbstractMCMC.step(
1717
model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context)
1818
)
1919
_, vi = DynamicPPL.evaluate!!(sampling_model, VarInfo())
20-
return Transition(model, vi, vi[:], nothing), nothing
20+
return Transition(model, vi, nothing), nothing
2121
end
2222

2323
DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains

0 commit comments

Comments
 (0)