@@ -124,75 +124,94 @@ end
124
124
# #####################
125
125
# Default Transition #
126
126
# #####################
127
- # Default
128
- getstats (t) = nothing
127
+ getstats (:: Any ) = NamedTuple ()
129
128
129
+ # TODO (penelopeysm): Remove this abstract type by converting SGLDTransition,
130
+ # SMCTransition, and PGTransition to Turing.Inference.Transition instead.
130
131
abstract type AbstractTransition end
131
132
132
- struct Transition{T,F<: AbstractFloat ,S <: Union{ NamedTuple,Nothing} } <: AbstractTransition
133
+ struct Transition{T,F<: AbstractFloat ,N <: NamedTuple } <: AbstractTransition
133
134
θ:: T
134
- lp:: F # TODO : merge `lp` with `stat`
135
- stat:: S
136
- end
135
+ logprior:: F
136
+ loglikelihood:: F
137
+ stat:: N
138
+
139
+ """
140
+ Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
141
+
142
+ Construct a new `Turing.Inference.Transition` object using the outputs of a
143
+ sampler step.
144
+
145
+ Here, `vi` represents a VarInfo _for which the appropriate parameters have
146
+ already been set_. However, the accumulators (e.g. logp) may in general
147
+ have junk contents. The role of this method is to re-evaluate `model` and
148
+ thus set the accumulators to the correct values.
149
+
150
+ `sampler_transition` is the transition object returned by the sampler
151
+ itself and is only used to extract statistics of interest.
152
+ """
153
+ function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition)
154
+ vi = DynamicPPL. setaccs!! (
155
+ vi,
156
+ (
157
+ DynamicPPL. ValuesAsInModelAccumulator (true ),
158
+ DynamicPPL. LogPriorAccumulator (),
159
+ DynamicPPL. LogLikelihoodAccumulator (),
160
+ ),
161
+ )
162
+ _, vi = DynamicPPL. evaluate!! (model, vi)
163
+
164
+ # Extract all the information we need
165
+ vals_as_in_model = DynamicPPL. getacc (vi, Val (:ValuesAsInModel )). values
166
+ logprior = DynamicPPL. getlogprior (vi)
167
+ loglikelihood = DynamicPPL. getloglikelihood (vi)
168
+
169
+ # Get additional statistics
170
+ stats = getstats (sampler_transition)
171
+ return new {typeof(vals_as_in_model),typeof(logprior),typeof(stats)} (
172
+ vals_as_in_model, logprior, loglikelihood, stats
173
+ )
174
+ end
137
175
138
- Transition (θ, lp) = Transition (θ, lp, nothing )
139
- function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , t)
140
- # TODO (DPPL0.37/penelopeysm): Fix this
141
- θ = getparams (model, vi)
142
- lp = getlogjoint_internal (vi)
143
- return Transition (θ, lp, getstats (t))
176
+ function Transition (
177
+ model:: DynamicPPL.Model ,
178
+ untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
179
+ sampler_transition,
180
+ )
181
+ # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
182
+ # much faster to convert it to a typed varinfo first, hence this method.
183
+ # https://github.com/TuringLang/Turing.jl/issues/2604
184
+ return Transition (model, DynamicPPL. typed_varinfo (untyped_vi), sampler_transition)
185
+ end
144
186
end
145
187
146
- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
147
188
function metadata (t:: Transition )
148
- stat = t. stat
149
- if stat === nothing
150
- return (lp= t. lp,)
151
- else
152
- return merge ((lp= t. lp,), stat)
153
- end
189
+ return merge (
190
+ t. stat,
191
+ (
192
+ lp= t. logprior + t. loglikelihood,
193
+ logprior= t. logprior,
194
+ loglikelihood= t. loglikelihood,
195
+ ),
196
+ )
197
+ end
198
+ function metadata (vi:: AbstractVarInfo )
199
+ return (
200
+ lp= DynamicPPL. getlogjoint (vi),
201
+ logprior= DynamicPPL. getlogp (vi),
202
+ loglikelihood= DynamicPPL. getloglikelihood (vi),
203
+ )
154
204
end
155
-
156
- # TODO (DPPL0.37/penelopeysm): Fix this
157
- DynamicPPL. getlogjoint (t:: Transition ) = t. lp
158
-
159
- # Metadata of VarInfo object
160
- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
161
- metadata (vi:: AbstractVarInfo ) = (lp= getlogjoint (vi),)
162
205
163
206
# #########################
164
207
# Chain making utilities #
165
208
# #########################
166
209
167
- """
168
- getparams(model, t)
169
-
170
- Return a named tuple of parameters.
171
- """
172
- getparams (model, t) = t. θ
173
- function getparams (model:: DynamicPPL.Model , vi:: DynamicPPL.VarInfo )
174
- # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
175
- # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
176
- # of the parameters change depending on the realizations. Hence we have to use
177
- # `values_as_in_model`, which re-runs the model and extracts the parameters
178
- # as they are seen in the model, i.e. in the constrained space. Moreover,
179
- # this means that the code below will work both of linked and invlinked `vi`.
180
- # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
181
- # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
182
- return DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
210
+ getparams (:: DynamicPPL.Model , t:: AbstractTransition ) = t. θ
211
+ function getparams (model:: DynamicPPL.Model , vi:: AbstractVarInfo )
212
+ t = Transition (model, vi, nothing )
213
+ return getparams (model, t)
183
214
end
184
- function getparams (
185
- model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
186
- )
187
- # values_as_in_model is unconscionably slow for untyped VarInfo. It's
188
- # much faster to convert it to a typed varinfo before calling getparams.
189
- # https://github.com/TuringLang/Turing.jl/issues/2604
190
- return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
191
- end
192
- function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
193
- return Dict {VarName,Any} ()
194
- end
195
-
196
215
function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
197
216
names_set = OrderedSet {VarName} ()
198
217
# Extract the parameter names and values from each transition.
@@ -208,7 +227,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
208
227
iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
209
228
mapreduce (collect, vcat, iters)
210
229
end
211
-
212
230
nms = map (first, nms_and_vs)
213
231
vs = map (last, nms_and_vs)
214
232
for nm in nms
@@ -224,7 +242,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
224
242
end
225
243
226
244
function get_transition_extras (ts:: AbstractVector{<:VarInfo} )
227
- valmat = reshape ([getlogjoint (t) for t in ts], :, 1 )
245
+ valmat = reshape ([DynamicPPL . getlogjoint (t) for t in ts], :, 1 )
228
246
return [:lp ], valmat
229
247
end
230
248
@@ -466,16 +484,17 @@ function transitions_from_chain(
466
484
chain:: MCMCChains.Chains ;
467
485
sampler= DynamicPPL. SampleFromPrior (),
468
486
)
469
- vi = Turing . VarInfo (model)
487
+ vi = VarInfo (model)
470
488
471
489
iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
472
490
transitions = map (iters) do (sample_idx, chain_idx)
473
491
# Set variables present in `chain` and mark those NOT present in chain to be resampled.
492
+ # TODO (DPPL0.37/penelopeysm): Aargh! setval_and_resample!!!! Burn this!!!
474
493
DynamicPPL. setval_and_resample! (vi, chain, sample_idx, chain_idx)
475
494
model (rng, vi, sampler)
476
495
477
496
# Convert `VarInfo` into `NamedTuple` and save.
478
- Transition (model, vi)
497
+ Transition (model, vi, nothing )
479
498
end
480
499
481
500
return transitions
0 commit comments