@@ -17,8 +17,8 @@ using DynamicPPL:
17
17
setindex!!,
18
18
push!!,
19
19
setlogp!!,
20
- getlogp,
21
20
getlogjoint,
21
+ getlogjoint_internal,
22
22
VarName,
23
23
getsym,
24
24
getdist,
@@ -123,71 +123,94 @@ end
123
123
# #####################
124
124
# Default Transition #
125
125
# #####################
126
- # Default
127
- getstats (t) = nothing
126
+ getstats (:: Any ) = NamedTuple ()
128
127
128
+ # TODO (penelopeysm): Remove this abstract type by converting SGLDTransition,
129
+ # SMCTransition, and PGTransition to Turing.Inference.Transition instead.
129
130
abstract type AbstractTransition end
130
131
131
- struct Transition{T,F<: AbstractFloat ,S <: Union{ NamedTuple,Nothing} } <: AbstractTransition
132
+ struct Transition{T,F<: AbstractFloat ,N <: NamedTuple } <: AbstractTransition
132
133
θ:: T
133
- lp:: F # TODO : merge `lp` with `stat`
134
- stat:: S
135
- end
136
-
137
- Transition (θ, lp) = Transition (θ, lp, nothing )
138
- function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , t)
139
- θ = getparams (model, vi)
140
- lp = getlogjoint (vi)
141
- return Transition (θ, lp, getstats (t))
142
- end
134
+ logprior:: F
135
+ loglikelihood:: F
136
+ stat:: N
137
+
138
+ """
139
+ Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
140
+
141
+ Construct a new `Turing.Inference.Transition` object using the outputs of a
142
+ sampler step.
143
+
144
+ Here, `vi` represents a VarInfo _for which the appropriate parameters have
145
+ already been set_. However, the accumulators (e.g. logp) may in general
146
+ have junk contents. The role of this method is to re-evaluate `model` and
147
+ thus set the accumulators to the correct values.
148
+
149
+ `sampler_transition` is the transition object returned by the sampler
150
+ itself and is only used to extract statistics of interest.
151
+ """
152
+ function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition)
153
+ vi = DynamicPPL. setaccs!! (
154
+ vi,
155
+ (
156
+ DynamicPPL. ValuesAsInModelAccumulator (true ),
157
+ DynamicPPL. LogPriorAccumulator (),
158
+ DynamicPPL. LogLikelihoodAccumulator (),
159
+ ),
160
+ )
161
+ _, vi = DynamicPPL. evaluate!! (model, vi)
162
+
163
+ # Extract all the information we need
164
+ vals_as_in_model = DynamicPPL. getacc (vi, Val (:ValuesAsInModel )). values
165
+ logprior = DynamicPPL. getlogprior (vi)
166
+ loglikelihood = DynamicPPL. getloglikelihood (vi)
167
+
168
+ # Get additional statistics
169
+ stats = getstats (sampler_transition)
170
+ return new {typeof(vals_as_in_model),typeof(logprior),typeof(stats)} (
171
+ vals_as_in_model, logprior, loglikelihood, stats
172
+ )
173
+ end
143
174
144
- function metadata (t:: Transition )
145
- stat = t. stat
146
- if stat === nothing
147
- return (lp= t. lp,)
148
- else
149
- return merge ((lp= t. lp,), stat)
175
+ function Transition (
176
+ model:: DynamicPPL.Model ,
177
+ untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
178
+ sampler_transition,
179
+ )
180
+ # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
181
+ # much faster to convert it to a typed varinfo first, hence this method.
182
+ # https://github.com/TuringLang/Turing.jl/issues/2604
183
+ return Transition (model, DynamicPPL. typed_varinfo (untyped_vi), sampler_transition)
150
184
end
151
185
end
152
186
153
- DynamicPPL. getlogjoint (t:: Transition ) = t. lp
154
-
155
- # Metadata of VarInfo object
156
- metadata (vi:: AbstractVarInfo ) = (lp= getlogjoint (vi),)
187
+ function getstats_with_lp (t:: Transition )
188
+ return merge (
189
+ t. stat,
190
+ (
191
+ lp= t. logprior + t. loglikelihood,
192
+ logprior= t. logprior,
193
+ loglikelihood= t. loglikelihood,
194
+ ),
195
+ )
196
+ end
197
+ function getstats_with_lp (vi:: AbstractVarInfo )
198
+ return (
199
+ lp= DynamicPPL. getlogjoint (vi),
200
+ logprior= DynamicPPL. getlogprior (vi),
201
+ loglikelihood= DynamicPPL. getloglikelihood (vi),
202
+ )
203
+ end
157
204
158
205
# #########################
159
206
# Chain making utilities #
160
207
# #########################
161
208
162
- """
163
- getparams(model, t)
164
-
165
- Return a named tuple of parameters.
166
- """
167
- getparams (model, t) = t. θ
168
- function getparams (model:: DynamicPPL.Model , vi:: DynamicPPL.VarInfo )
169
- # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
170
- # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
171
- # of the parameters change depending on the realizations. Hence we have to use
172
- # `values_as_in_model`, which re-runs the model and extracts the parameters
173
- # as they are seen in the model, i.e. in the constrained space. Moreover,
174
- # this means that the code below will work both of linked and invlinked `vi`.
175
- # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
176
- # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
177
- return DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
178
- end
179
- function getparams (
180
- model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
181
- )
182
- # values_as_in_model is unconscionably slow for untyped VarInfo. It's
183
- # much faster to convert it to a typed varinfo before calling getparams.
184
- # https://github.com/TuringLang/Turing.jl/issues/2604
185
- return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
209
+ getparams (:: DynamicPPL.Model , t:: AbstractTransition ) = t. θ
210
+ function getparams (model:: DynamicPPL.Model , vi:: AbstractVarInfo )
211
+ t = Transition (model, vi, nothing )
212
+ return getparams (model, t)
186
213
end
187
- function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
188
- return Dict {VarName,Any} ()
189
- end
190
-
191
214
function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
192
215
names_set = OrderedSet {VarName} ()
193
216
# Extract the parameter names and values from each transition.
@@ -203,7 +226,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
203
226
iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
204
227
mapreduce (collect, vcat, iters)
205
228
end
206
-
207
229
nms = map (first, nms_and_vs)
208
230
vs = map (last, nms_and_vs)
209
231
for nm in nms
@@ -218,14 +240,9 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
218
240
return names, vals
219
241
end
220
242
221
- function get_transition_extras (ts:: AbstractVector{<:VarInfo} )
222
- valmat = reshape ([getlogjoint (t) for t in ts], :, 1 )
223
- return [:lp ], valmat
224
- end
225
-
226
243
function get_transition_extras (ts:: AbstractVector )
227
- # Extract all metadata.
228
- extra_data = map (metadata , ts)
244
+ # Extract stats + log probabilities from each transition or VarInfo
245
+ extra_data = map (getstats_with_lp , ts)
229
246
return names_values (extra_data)
230
247
end
231
248
@@ -334,7 +351,7 @@ function AbstractMCMC.bundle_samples(
334
351
vals = map (values (sym_to_vns)) do vns
335
352
map (Base. Fix1 (getindex, params), vns)
336
353
end
337
- return merge (NamedTuple (zip (keys (sym_to_vns), vals)), metadata (t))
354
+ return merge (NamedTuple (zip (keys (sym_to_vns), vals)), getstats_with_lp (t))
338
355
end
339
356
end
340
357
@@ -396,84 +413,4 @@ function DynamicPPL.get_matching_type(
396
413
return Array{T,N}
397
414
end
398
415
399
- # #############
400
- # Utilities #
401
- # #############
402
-
403
- """
404
-
405
- transitions_from_chain(
406
- [rng::AbstractRNG,]
407
- model::Model,
408
- chain::MCMCChains.Chains;
409
- sampler = DynamicPPL.SampleFromPrior()
410
- )
411
-
412
- Execute `model` conditioned on each sample in `chain`, and return resulting transitions.
413
-
414
- The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`.
415
-
416
- # Details
417
-
418
- In a bit more detail, the process is as follows:
419
- 1. For every `sample` in `chain`
420
- 1. For every `variable` in `sample`
421
- 1. Set `variable` in `model` to its value in `sample`
422
- 2. Execute `model` with variables fixed as above, sampling variables NOT present
423
- in `chain` using `SampleFromPrior`
424
- 3. Return sampled variables and log-joint
425
-
426
- # Example
427
- ```julia-repl
428
- julia> using Turing
429
-
430
- julia> @model function demo()
431
- m ~ Normal(0, 1)
432
- x ~ Normal(m, 1)
433
- end;
434
-
435
- julia> m = demo();
436
-
437
- julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
438
-
439
- julia> transitions = Turing.Inference.transitions_from_chain(m, chain);
440
-
441
- julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints
442
- 2-element Array{Float64,1}:
443
- -3.6294991938628374
444
- -2.5697948166987845
445
-
446
- julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
447
- 2-element Array{Array{Float64,1},1}:
448
- [-2.0844148956440796]
449
- [-1.704630494695469]
450
- ```
451
- """
452
- function transitions_from_chain (
453
- model:: DynamicPPL.Model , chain:: MCMCChains.Chains ; kwargs...
454
- )
455
- return transitions_from_chain (Random. default_rng (), model, chain; kwargs... )
456
- end
457
-
458
- function transitions_from_chain (
459
- rng:: Random.AbstractRNG ,
460
- model:: DynamicPPL.Model ,
461
- chain:: MCMCChains.Chains ;
462
- sampler= DynamicPPL. SampleFromPrior (),
463
- )
464
- vi = Turing. VarInfo (model)
465
-
466
- iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
467
- transitions = map (iters) do (sample_idx, chain_idx)
468
- # Set variables present in `chain` and mark those NOT present in chain to be resampled.
469
- DynamicPPL. setval_and_resample! (vi, chain, sample_idx, chain_idx)
470
- model (rng, vi, sampler)
471
-
472
- # Convert `VarInfo` into `NamedTuple` and save.
473
- Transition (model, vi)
474
- end
475
-
476
- return transitions
477
- end
478
-
479
416
end # module
0 commit comments