@@ -124,85 +124,119 @@ end
124
124
# #####################
125
125
# Default Transition #
126
126
# #####################
127
- # Default
128
- getstats (t) = nothing
127
+ getstats (:: Any ) = NamedTuple ()
129
128
129
+ # TODO (penelopeysm): Remove this abstract type by converting SGLDTransition,
130
+ # SMCTransition, and PGTransition to Turing.Inference.Transition instead.
130
131
abstract type AbstractTransition end
131
132
132
- struct Transition{T,F<: AbstractFloat ,S <: Union{ NamedTuple,Nothing} } <: AbstractTransition
133
+ struct Transition{T,F<: AbstractFloat ,N <: NamedTuple } <: AbstractTransition
133
134
θ:: T
134
- lp:: F # TODO : merge `lp` with `stat`
135
- stat:: S
136
- end
135
+ logprior:: F
136
+ loglikelihood:: F
137
+ stat:: N
138
+
139
+ """
140
+ Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
141
+
142
+ Construct a new `Turing.Inference.Transition` object using the outputs of a
143
+ sampler step.
144
+
145
+ Here, `vi` represents a VarInfo _for which the appropriate parameters have
146
+ already been set_. However, the accumulators (e.g. logp) may in general
147
+ have junk contents. The role of this method is to re-evaluate `model` and
148
+ thus set the accumulators to the correct values.
149
+
150
+ `sampler_transition` is the transition object returned by the sampler
151
+ itself and is only used to extract statistics of interest.
152
+ """
153
+ function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition)
154
+ vi = DynamicPPL. setaccs!! (
155
+ vi,
156
+ (
157
+ DynamicPPL. ValuesAsInModelAccumulator (true ),
158
+ DynamicPPL. LogPriorAccumulator (),
159
+ DynamicPPL. LogLikelihoodAccumulator (),
160
+ ),
161
+ )
162
+ _, vi = DynamicPPL. evaluate!! (model, vi)
163
+
164
+ # Extract all the information we need
165
+ vals_as_in_model = DynamicPPL. getacc (vi, Val (:ValuesAsInModel )). values
166
+ logprior = DynamicPPL. getlogprior (vi)
167
+ loglikelihood = DynamicPPL. getloglikelihood (vi)
168
+
169
+ # Convert values to the format needed (i.e. a Vector of (varname,
170
+ # value) tuples, where value isa Real: all vector-valued varnames must
171
+ # be split up.)
172
+ # TODO (penelopeysm): This wouldn't be necessary if not for MCMCChains's
173
+ # poor representation...
174
+ values_split = if isempty (vals_as_in_model)
175
+ # If there are no values, we return an empty vector.
176
+ # This is the case for models with no parameters.
177
+ Vector {Tuple{VarName,Any}} ()
178
+ else
179
+ iters = map (
180
+ DynamicPPL. varname_and_value_leaves,
181
+ keys (vals_as_in_model),
182
+ values (vals_as_in_model),
183
+ )
184
+ mapreduce (collect, vcat, iters)
185
+ end
137
186
138
- Transition (θ, lp) = Transition (θ, lp, nothing )
139
- function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , t)
140
- # TODO (DPPL0.37/penelopeysm): Fix this
141
- θ = getparams (model, vi)
142
- lp = getlogjoint_internal (vi)
143
- return Transition (θ, lp, getstats (t))
144
- end
187
+ # Get additional statistics
188
+ stats = getstats (sampler_transition)
189
+ return new {typeof(values_split),typeof(logprior),typeof(stats)} (
190
+ values_split, logprior, loglikelihood, stats
191
+ )
192
+ end
145
193
146
- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
147
- function metadata (t:: Transition )
148
- stat = t. stat
149
- if stat === nothing
150
- return (lp= t. lp,)
151
- else
152
- return merge ((lp= t. lp,), stat)
194
+ function Transition (
195
+ model:: DynamicPPL.Model ,
196
+ untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
197
+ sampler_transition,
198
+ )
199
+ # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
200
+ # much faster to convert it to a typed varinfo first, hence this method.
201
+ # https://github.com/TuringLang/Turing.jl/issues/2604
202
+ return Transition (model, DynamicPPL. typed_varinfo (untyped_vi), sampler_transition)
153
203
end
154
204
end
155
205
156
- # TODO (DPPL0.37/penelopeysm): Fix this
157
- DynamicPPL. getlogjoint (t:: Transition ) = t. lp
158
-
159
- # Metadata of VarInfo object
160
- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
161
- metadata (vi:: AbstractVarInfo ) = (lp= getlogjoint (vi),)
206
+ function metadata (t:: Transition )
207
+ return merge (
208
+ t. stat,
209
+ (
210
+ lp= t. logprior + t. loglikelihood,
211
+ logprior= t. logprior,
212
+ loglikelihood= t. loglikelihood,
213
+ ),
214
+ )
215
+ end
216
+ function metadata (vi:: AbstractVarInfo )
217
+ return (
218
+ lp= DynamicPPL. getlogjoint (vi),
219
+ logprior= DynamicPPL. getlogp (vi),
220
+ loglikelihood= DynamicPPL. getloglikelihood (vi),
221
+ )
222
+ end
162
223
163
224
# #########################
164
225
# Chain making utilities #
165
226
# #########################
166
227
167
- """
168
- getparams(model, t)
169
-
170
- Return a named tuple of parameters.
171
- """
172
- getparams (model, t) = t. θ
173
- function getparams (model:: DynamicPPL.Model , vi:: DynamicPPL.VarInfo )
174
- # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
175
- # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
176
- # of the parameters change depending on the realizations. Hence we have to use
177
- # `values_as_in_model`, which re-runs the model and extracts the parameters
178
- # as they are seen in the model, i.e. in the constrained space. Moreover,
179
- # this means that the code below will work both of linked and invlinked `vi`.
180
- # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
181
- # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
182
- vals = DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
183
-
184
- # Obtain an iterator over the flattened parameter names and values.
185
- iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
186
-
187
- # Materialize the iterators and concatenate.
188
- return mapreduce (collect, vcat, iters)
228
+ getparams (:: DynamicPPL.Model , t:: AbstractTransition ) = t. θ
229
+ function getparams (model:: DynamicPPL.Model , vi:: AbstractVarInfo )
230
+ t = Transition (model, vi, nothing )
231
+ return getparams (model, t)
189
232
end
190
- function getparams (
191
- model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
192
- )
193
- # values_as_in_model is unconscionably slow for untyped VarInfo. It's
194
- # much faster to convert it to a typed varinfo before calling getparams.
195
- # https://github.com/TuringLang/Turing.jl/issues/2604
196
- return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
197
- end
198
- function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
199
- return float (Real)[]
200
- end
201
-
202
233
function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
203
234
names_set = OrderedSet {VarName} ()
204
235
# Extract the parameter names and values from each transition.
205
236
dicts = map (ts) do t
237
+ # TODO (penelopeysm): Get rid of AbstractVarInfo transitions. see
238
+ # https://github.com/TuringLang/Turing.jl/issues/2631. That would
239
+ # allow us to just use t.θ here.
206
240
nms_and_vs = getparams (model, t)
207
241
nms = map (first, nms_and_vs)
208
242
vs = map (last, nms_and_vs)
@@ -221,7 +255,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
221
255
end
222
256
223
257
function get_transition_extras (ts:: AbstractVector{<:VarInfo} )
224
- valmat = reshape ([getlogjoint (t) for t in ts], :, 1 )
258
+ valmat = reshape ([DynamicPPL . getlogjoint (t) for t in ts], :, 1 )
225
259
return [:lp ], valmat
226
260
end
227
261
@@ -463,16 +497,17 @@ function transitions_from_chain(
463
497
chain:: MCMCChains.Chains ;
464
498
sampler= DynamicPPL. SampleFromPrior (),
465
499
)
466
- vi = Turing . VarInfo (model)
500
+ vi = VarInfo (model)
467
501
468
502
iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
469
503
transitions = map (iters) do (sample_idx, chain_idx)
470
504
# Set variables present in `chain` and mark those NOT present in chain to be resampled.
505
+ # TODO (DPPL0.37/penelopeysm): Aargh! setval_and_resample!!!! Burn this!!!
471
506
DynamicPPL. setval_and_resample! (vi, chain, sample_idx, chain_idx)
472
507
model (rng, vi, sampler)
473
508
474
509
# Convert `VarInfo` into `NamedTuple` and save.
475
- Transition (model, vi)
510
+ Transition (model, vi, nothing )
476
511
end
477
512
478
513
return transitions
0 commit comments