# #####################
# Default Transition #
# #####################
"""
    getstats(sampler_transition)

Fallback method: samplers whose transition objects carry no additional
statistics yield an empty `NamedTuple`.
"""
function getstats(::Any)
    return NamedTuple()
end
# Supertype for transition objects produced by Turing's samplers.
# TODO(penelopeysm): Remove this abstract type by converting SGLDTransition,
# SMCTransition, and PGTransition to Turing.Inference.Transition instead.
abstract type AbstractTransition end
# A `Transition` bundles a single MCMC draw's parameter values together with
# the log-density components and any sampler-specific statistics.
struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
    θ::T               # parameter values as (VarName, value) leaf pairs
    logprior::F        # log-prior density of θ
    loglikelihood::F   # log-likelihood of θ
    stat::N            # sampler-specific statistics (possibly empty)

    """
        Transition(model::Model, vi::AbstractVarInfo, params::AbstractVector, sampler_transition)

    Construct a new `Turing.Inference.Transition` object using the outputs of a sampler step.

    Here, `vi` represents a VarInfo which in general may have junk contents (both
    parameters and accumulators). The role of this method is to re-evaluate `model` by inserting
    the new `params` (provided by the sampler) into the VarInfo `vi`.

    `sampler_transition` is the transition object returned by the sampler itself and is only used
    to extract statistics of interest.

    !!! warning "Parameters must match varinfo linking status"
        It is mandatory that the vector of parameters provided line up exactly with how the
        VarInfo `vi` is linked. Otherwise, this can silently produce incorrect results.
    """
    function Transition(
        model::DynamicPPL.Model,
        vi::AbstractVarInfo,
        parameters::AbstractVector,
        sampler_transition,
    )
        # To be safe: never mutate the caller's VarInfo.
        vi = deepcopy(vi)
        # Set the parameters and re-evaluate with the appropriate accumulators.
        vi = DynamicPPL.unflatten(vi, parameters)
        vi = DynamicPPL.setaccs!!(
            vi,
            (
                DynamicPPL.ValuesAsInModelAccumulator(true),
                DynamicPPL.LogPriorAccumulator(),
                DynamicPPL.LogLikelihoodAccumulator(),
            ),
        )
        _, vi = DynamicPPL.evaluate!!(model, vi)

        # Extract all the information we need from the re-evaluated VarInfo.
        vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
        logprior = DynamicPPL.getlogprior(vi)
        loglikelihood = DynamicPPL.getloglikelihood(vi)

        # Convert values to the format needed (with individual VarNames split up).
        # TODO(penelopeysm): This wouldn't be necessary if not for MCMCChains's poor
        # representation...
        iters = map(
            DynamicPPL.varname_and_value_leaves,
            keys(vals_as_in_model),
            values(vals_as_in_model),
        )
        values_split = mapreduce(collect, vcat, iters)

        # Get additional statistics from the sampler's own transition object.
        stats = getstats(sampler_transition)
        return new{typeof(values_split),typeof(logprior),typeof(stats)}(
            values_split, logprior, loglikelihood, stats
        )
    end
    function Transition(
        model::DynamicPPL.Model,
        untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
        parameters::AbstractVector,
        sampler_transition,
    )
        # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
        # much faster to convert it to a typed varinfo first, hence this method.
        # https://github.com/TuringLang/Turing.jl/issues/2604
        return Transition(
            model, DynamicPPL.typed_varinfo(untyped_vi), parameters, sampler_transition
        )
    end
end
145
210
146
"""
    metadata(t::Transition)

Return a `NamedTuple` of the transition's metadata: the sampler statistics
plus the log-density components (`lp`, `logprior`, `loglikelihood`).
"""
function metadata(t::Transition)
    # Log-density fields derived from the transition; listed second in the
    # merge so they take precedence over identically-named sampler stats.
    logp_fields = (
        lp=t.logprior + t.loglikelihood,
        logprior=t.logprior,
        loglikelihood=t.loglikelihood,
    )
    return merge(t.stat, logp_fields)
end
"""
    metadata(vi::AbstractVarInfo)

Return a `NamedTuple` of log-density metadata (`lp`, `logprior`,
`loglikelihood`) read directly from a VarInfo.
"""
function metadata(vi::AbstractVarInfo)
    return (
        lp=DynamicPPL.getlogjoint(vi),
        # Fixed: was `DynamicPPL.getlogp(vi)`, which does not return the
        # prior component; use `getlogprior` to match the sibling accessors
        # and the `Transition` constructor.
        logprior=DynamicPPL.getlogprior(vi),
        loglikelihood=DynamicPPL.getloglikelihood(vi),
    )
end
# #########################
# Chain making utilities #
# #########################
- """
168
- getparams(model, t)
169
-
170
- Return a named tuple of parameters.
171
- """
172
- getparams (model, t) = t. θ
173
- function getparams (model:: DynamicPPL.Model , vi:: DynamicPPL.VarInfo )
174
- # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
175
- # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
176
- # of the parameters change depending on the realizations. Hence we have to use
177
- # `values_as_in_model`, which re-runs the model and extracts the parameters
178
- # as they are seen in the model, i.e. in the constrained space. Moreover,
179
- # this means that the code below will work both of linked and invlinked `vi`.
180
- # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
181
- # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
182
- vals = DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
183
-
184
- # Obtain an iterator over the flattened parameter names and values.
185
- iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
186
-
187
- # Materialize the iterators and concatenate.
188
- return mapreduce (collect, vcat, iters)
233
+ getparams (:: DynamicPPL.Model , t:: AbstractTransition ) = t. θ
234
+ function getparams (model:: DynamicPPL.Model , vi:: AbstractVarInfo )
235
+ t = Transition (model, vi, vi[:], nothing )
236
+ return getparams (model, t)
189
237
end
190
- function getparams (
191
- model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
192
- )
193
- # values_as_in_model is unconscionably slow for untyped VarInfo. It's
194
- # much faster to convert it to a typed varinfo before calling getparams.
195
- # https://github.com/TuringLang/Turing.jl/issues/2604
196
- return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
197
- end
198
- function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
199
- return float (Real)[]
200
- end
201
-
202
238
function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
203
239
names_set = OrderedSet {VarName} ()
204
240
# Extract the parameter names and values from each transition.
205
241
dicts = map (ts) do t
242
+ # TODO (penelopeysm): Get rid of AbstractVarInfo transitions. see
243
+ # https://github.com/TuringLang/Turing.jl/issues/2631. That would
244
+ # allow us to just use t.θ here.
206
245
nms_and_vs = getparams (model, t)
207
246
nms = map (first, nms_and_vs)
208
247
vs = map (last, nms_and_vs)
@@ -221,7 +260,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
221
260
end
222
261
223
262
# Extra chain columns for VarInfo-based transitions: a single `:lp` column
# containing each VarInfo's log-joint density.
function get_transition_extras(ts::AbstractVector{<:VarInfo})
    logjoints = map(DynamicPPL.getlogjoint, ts)
    return [:lp], reshape(logjoints, :, 1)
end
227
266
0 commit comments