@@ -136,7 +136,7 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
136
136
stat:: N
137
137
138
138
"""
139
- Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
139
+ Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true )
140
140
141
141
Construct a new `Turing.Inference.Transition` object using the outputs of a
142
142
sampler step.
@@ -148,17 +148,38 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
148
148
149
149
`sampler_transition` is the transition object returned by the sampler
150
150
itself and is only used to extract statistics of interest.
151
+
152
+ By default, the model is re-evaluated in order to obtain values of:
153
+ - the values of the parameters as per user parameterisation (`vals_as_in_model`)
154
+ - the various components of the log joint probability (`logprior`, `loglikelihood`)
155
+ that are guaranteed to be correct.
156
+
157
+ If you **know** for a fact that the VarInfo `vi` already contains this information,
158
+ then you can set `reevaluate=false` to skip the re-evaluation step.
159
+
160
+ !!! warning
161
+ Note that in general this is unsafe and may lead to wrong results.
162
+
163
+ If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that
164
+ the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`,
165
+ and `LogLikelihoodAccumulator` set up with the correct values. Note that the
166
+ `ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it
167
+ must be set up to track `x := y` statements.
151
168
"""
152
- function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition)
153
- vi = DynamicPPL. setaccs!! (
154
- vi,
155
- (
156
- DynamicPPL. ValuesAsInModelAccumulator (true ),
157
- DynamicPPL. LogPriorAccumulator (),
158
- DynamicPPL. LogLikelihoodAccumulator (),
159
- ),
160
- )
161
- _, vi = DynamicPPL. evaluate!! (model, vi)
169
+ function Transition (
170
+ model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition; reevaluate= true
171
+ )
172
+ if reevaluate
173
+ vi = DynamicPPL. setaccs!! (
174
+ vi,
175
+ (
176
+ DynamicPPL. ValuesAsInModelAccumulator (true ),
177
+ DynamicPPL. LogPriorAccumulator (),
178
+ DynamicPPL. LogLikelihoodAccumulator (),
179
+ ),
180
+ )
181
+ _, vi = DynamicPPL. evaluate!! (model, vi)
182
+ end
162
183
163
184
# Extract all the information we need
164
185
vals_as_in_model = DynamicPPL. getacc (vi, Val (:ValuesAsInModel )). values
@@ -175,12 +196,18 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
175
196
function Transition (
176
197
model:: DynamicPPL.Model ,
177
198
untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
178
- sampler_transition,
199
+ sampler_transition;
200
+ reevaluate= true ,
179
201
)
180
202
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
181
203
# much faster to convert it to a typed varinfo first, hence this method.
182
204
# https://github.com/TuringLang/Turing.jl/issues/2604
183
- return Transition (model, DynamicPPL. typed_varinfo (untyped_vi), sampler_transition)
205
+ return Transition (
206
+ model,
207
+ DynamicPPL. typed_varinfo (untyped_vi),
208
+ sampler_transition;
209
+ reevaluate= reevaluate,
210
+ )
184
211
end
185
212
end
186
213
0 commit comments