@@ -137,31 +137,20 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
137
137
stat:: N
138
138
139
139
"""
140
- Transition(model::Model, vi::AbstractVarInfo, params::AbstractVector, sampler_transition)
140
+ Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
141
141
142
- Construct a new `Turing.Inference.Transition` object using the outputs of a sampler step.
142
+ Construct a new `Turing.Inference.Transition` object using the outputs of a
143
+ sampler step.
143
144
144
- Here, `vi` represents a VarInfo which in general may have junk contents (both
145
- parameters and accumulators). The role of this method is to re-evaluate `model` by inserting
146
- the new `params` (provided by the sampler) into the VarInfo `vi`.
145
+ Here, `vi` represents a VarInfo _for which the appropriate parameters have
146
+ already been set_. However, the accumulators (e.g. logp) may in general
147
+ have junk contents. The role of this method is to re-evaluate `model` and
148
+ thus set the accumulators to the correct values.
147
149
148
- `sampler_transition` is the transition object returned by the sampler itself and is only used
149
- to extract statistics of interest.
150
-
151
- !!! warning "Parameters must match varinfo linking status"
152
- It is mandatory that the vector of parameters provided line up exactly with how the
153
- VarInfo `vi` is linked. Otherwise, this can silently produce incorrect results.
150
+ `sampler_transition` is the transition object returned by the sampler
151
+ itself and is only used to extract statistics of interest.
154
152
"""
155
- function Transition (
156
- model:: DynamicPPL.Model ,
157
- vi:: AbstractVarInfo ,
158
- parameters:: AbstractVector ,
159
- sampler_transition,
160
- )
161
- # To be safe...
162
- vi = deepcopy (vi)
163
- # Set the parameters and re-evaluate with the appropriate accumulators
164
- vi = DynamicPPL. unflatten (vi, parameters)
153
+ function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition)
165
154
vi = DynamicPPL. setaccs!! (
166
155
vi,
167
156
(
@@ -193,18 +182,16 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
193
182
values_split, logprior, loglikelihood, stats
194
183
)
195
184
end
185
+
196
186
function Transition (
197
187
model:: DynamicPPL.Model ,
198
188
untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
199
- parameters:: AbstractVector ,
200
189
sampler_transition,
201
190
)
202
191
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
203
192
# much faster to convert it to a typed varinfo first, hence this method.
204
193
# https://github.com/TuringLang/Turing.jl/issues/2604
205
- return Transition (
206
- model, DynamicPPL. typed_varinfo (untyped_vi), parameters, sampler_transition
207
- )
194
+ return Transition (model, DynamicPPL. typed_varinfo (untyped_vi), sampler_transition)
208
195
end
209
196
end
210
197
232
219
233
220
getparams (:: DynamicPPL.Model , t:: AbstractTransition ) = t. θ
234
221
function getparams (model:: DynamicPPL.Model , vi:: AbstractVarInfo )
235
- t = Transition (model, vi, vi[:], nothing )
222
+ t = Transition (model, vi, nothing )
236
223
return getparams (model, t)
237
224
end
238
225
function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
@@ -502,16 +489,17 @@ function transitions_from_chain(
502
489
chain:: MCMCChains.Chains ;
503
490
sampler= DynamicPPL. SampleFromPrior (),
504
491
)
505
- vi = Turing . VarInfo (model)
492
+ vi = VarInfo (model)
506
493
507
494
iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
508
495
transitions = map (iters) do (sample_idx, chain_idx)
509
496
# Set variables present in `chain` and mark those NOT present in chain to be resampled.
497
+ # TODO (DPPL0.37/penelopeysm): Aargh! setval_and_resample!!!! Burn this!!!
510
498
DynamicPPL. setval_and_resample! (vi, chain, sample_idx, chain_idx)
511
499
model (rng, vi, sampler)
512
500
513
501
# Convert `VarInfo` into `NamedTuple` and save.
514
- Transition (model, vi)
502
+ Transition (model, vi, nothing )
515
503
end
516
504
517
505
return transitions
0 commit comments