Commit 6532d96

Implement AbstractMCMC.{to,from}_samples (again) (#1112)
* Implement `AbstractMCMC.{to,from}_samples` (again)
* Port docs as well
* Simplify `predict` and `returned` implementation
* Fix test
* Fix test
* fix test (again)
1 parent ab6f38a commit 6532d96

13 files changed: +398 -109 lines changed


HISTORY.md

Lines changed: 7 additions & 0 deletions
@@ -1,5 +1,12 @@
 # DynamicPPL Changelog

+## 0.38.8
+
+Added a new exported struct, `DynamicPPL.ParamsWithStats`.
+This can broadly be used to represent the output of a model: it consists of an `OrderedDict` of `VarName` parameters and their values, along with a `stats` NamedTuple which can hold arbitrary data, such as (but not limited to) log-probabilities.
+
+Implemented the functions `AbstractMCMC.to_samples` and `AbstractMCMC.from_samples`, which convert between an `MCMCChains.Chains` object and a matrix of `DynamicPPL.ParamsWithStats` objects.
+
 ## 0.38.7

 Made a small tweak to DynamicPPL's compiler output to avoid potential undefined variables when resuming model functions midway through (e.g. with Libtask in Turing's SMC/PG samplers).

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.38.7"
+version = "0.38.8"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]

 [compat]
 ADTypes = "1"
-AbstractMCMC = "5"
+AbstractMCMC = "5.10"
 AbstractPPL = "0.13.1"
 Accessors = "0.1"
 BangBang = "0.4.1"

docs/Project.toml

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 [deps]
+AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -24,6 +25,6 @@ FillArrays = "0.13, 1"
 ForwardDiff = "0.10, 1"
 JET = "0.9, 0.10, 0.11"
 LogDensityProblems = "2"
-MarginalLogDensities = "0.4"
 MCMCChains = "5, 6, 7"
+MarginalLogDensities = "0.4"
 StableRNGs = "1"

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ using Distributions
 using DocumenterMermaid
 # load MCMCChains package extension to make `predict` available
 using MCMCChains
+using AbstractMCMC: AbstractMCMC
 using MarginalLogDensities: MarginalLogDensities

 # Need this to document a method which uses a type inside the extension...

docs/src/api.md

Lines changed: 26 additions & 0 deletions
@@ -505,3 +505,29 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
 DynamicPPL.Experimental.determine_suitable_varinfo
 DynamicPPL.Experimental.is_suitable_varinfo
 ```
+
+### Converting VarInfos to/from chains
+
+It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis.
+
+This can be accomplished by first converting each `VarInfo` into a `ParamsWithStats` object:
+
+```@docs
+DynamicPPL.ParamsWithStats
+```
+
+Once you have a **matrix** of these, you can convert them into a chains object using:
+
+```@docs
+AbstractMCMC.from_samples(::Type{MCMCChains.Chains}, ::AbstractMatrix{<:DynamicPPL.ParamsWithStats})
+```
+
+If you only have a vector you can use `hcat` to convert it into an `N×1` matrix first.
+
+Furthermore, one can convert chains back into a collection of parameter dictionaries and/or stats with:
+
+```@docs
+AbstractMCMC.to_samples(::Type{DynamicPPL.ParamsWithStats}, ::MCMCChains.Chains)
+```
+
+With these, you can (for example) extract the parameter dictionaries and use `InitFromParams` to re-evaluate a model at each point in the chain.
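
The round trip described in the added documentation could look roughly like the sketch below. It assumes DynamicPPL 0.38.8 or later, AbstractMCMC 5.10 or later, and MCMCChains are installed. The names `samples`, `chain`, `roundtrip`, and `first_params` are hypothetical, and the two-argument `ParamsWithStats(params, stats)` call mirrors how the extension in this commit constructs the struct.

```julia
using DynamicPPL: DynamicPPL, ParamsWithStats, VarName, @varname
using AbstractMCMC: AbstractMCMC
using MCMCChains: MCMCChains  # loading MCMCChains activates the DynamicPPL extension

# OrderedCollections is reached through DynamicPPL, as the extension in this commit does.
const OrderedDict = DynamicPPL.OrderedCollections.OrderedDict

# Hypothetical samples: five draws of one scalar parameter `x`, each with an `lp` stat.
samples = [
    ParamsWithStats(OrderedDict{VarName,Any}(@varname(x) => 0.1 * i), (lp=-1.0 * i,)) for
    i in 1:5
]

# `from_samples` expects a matrix (iterations × chains); `hcat` turns the vector into 5×1.
chain = AbstractMCMC.from_samples(MCMCChains.Chains, hcat(samples))

# Convert back to a matrix of ParamsWithStats, one entry per (iteration, chain).
roundtrip = AbstractMCMC.to_samples(ParamsWithStats, chain)
first_params = roundtrip[1, 1].params  # OrderedDict keyed by VarName
```

Each `ps.params` obtained this way is what `predict` and `returned` in this commit pass to `DynamicPPL.InitFromParams`.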

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 120 additions & 75 deletions
@@ -1,6 +1,6 @@
 module DynamicPPLMCMCChainsExt

-using DynamicPPL: DynamicPPL, AbstractPPL
+using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
 using MCMCChains: MCMCChains

 _has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
@@ -36,6 +36,110 @@ function chain_sample_to_varname_dict(
     return d
 end

+"""
+    AbstractMCMC.from_samples(
+        ::Type{MCMCChains.Chains},
+        params_and_stats::AbstractMatrix{<:ParamsWithStats}
+    )
+
+Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.
+"""
+function AbstractMCMC.from_samples(
+    ::Type{MCMCChains.Chains},
+    params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats},
+)
+    # Handle parameters
+    all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
+    split_dicts = map(params_and_stats) do ps
+        # Separate into individual VarNames.
+        vn_leaves_and_vals = if isempty(ps.params)
+            Tuple{DynamicPPL.VarName,Any}[]
+        else
+            iters = map(
+                AbstractPPL.varname_and_value_leaves,
+                keys(ps.params),
+                values(ps.params),
+            )
+            mapreduce(collect, vcat, iters)
+        end
+        vn_leaves = map(first, vn_leaves_and_vals)
+        vals = map(last, vn_leaves_and_vals)
+        for vn_leaf in vn_leaves
+            push!(all_vn_leaves, vn_leaf)
+        end
+        DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals))
+    end
+    vn_leaves = collect(all_vn_leaves)
+    param_vals = [
+        get(split_dicts[i, j], key, missing) for i in eachindex(axes(split_dicts, 1)),
+        key in vn_leaves, j in eachindex(axes(split_dicts, 2))
+    ]
+    param_symbols = map(Symbol, vn_leaves)
+    # Handle statistics
+    stat_keys = DynamicPPL.OrderedCollections.OrderedSet{Symbol}()
+    for ps in params_and_stats
+        for k in keys(ps.stats)
+            push!(stat_keys, k)
+        end
+    end
+    stat_keys = collect(stat_keys)
+    stat_vals = [
+        get(params_and_stats[i, j].stats, key, missing) for
+        i in eachindex(axes(params_and_stats, 1)), key in stat_keys,
+        j in eachindex(axes(params_and_stats, 2))
+    ]
+    # Construct name map and info
+    name_map = (internals=stat_keys,)
+    info = (
+        varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
+            zip(all_vn_leaves, param_symbols)
+        ),
+    )
+    # Concatenate parameter and statistic values
+    vals = cat(param_vals, stat_vals; dims=2)
+    symbols = vcat(param_symbols, stat_keys)
+    return MCMCChains.Chains(MCMCChains.concretize(vals), symbols, name_map; info=info)
+end
+
+"""
+    AbstractMCMC.to_samples(
+        ::Type{DynamicPPL.ParamsWithStats},
+        chain::MCMCChains.Chains
+    )
+
+Convert an `MCMCChains.Chains` object to an array of `DynamicPPL.ParamsWithStats`.
+
+For this to work, `chain` must contain the `varname_to_symbol` mapping in its `info` field.
+"""
+function AbstractMCMC.to_samples(
+    ::Type{DynamicPPL.ParamsWithStats}, chain::MCMCChains.Chains
+)
+    idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
+    # Get parameters
+    params_matrix = map(idxs) do (sample_idx, chain_idx)
+        d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}()
+        for vn in DynamicPPL.varnames(chain)
+            d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
+        end
+        d
+    end
+    # Statistics
+    stats_matrix = if :internals in MCMCChains.sections(chain)
+        internals_chain = MCMCChains.get_sections(chain, :internals)
+        map(idxs) do (sample_idx, chain_idx)
+            get(internals_chain[sample_idx, :, chain_idx], keys(internals_chain); flatten=true)
+        end
+    else
+        fill(NamedTuple(), size(idxs))
+    end
+    # Bundle them together
+    return map(idxs) do (sample_idx, chain_idx)
+        DynamicPPL.ParamsWithStats(
+            params_matrix[sample_idx, chain_idx], stats_matrix[sample_idx, chain_idx]
+        )
+    end
+end
+
 """
     predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

@@ -110,42 +214,24 @@ function DynamicPPL.predict(
         DynamicPPL.VarInfo(),
         (
             DynamicPPL.LogPriorAccumulator(),
-            DynamicPPL.LogJacobianAccumulator(),
             DynamicPPL.LogLikelihoodAccumulator(),
             DynamicPPL.ValuesAsInModelAccumulator(false),
         ),
     )
     _, varinfo = DynamicPPL.init!!(model, varinfo)
     varinfo = DynamicPPL.typed_varinfo(varinfo)

-    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
-    predictive_samples = map(iters) do (sample_idx, chain_idx)
-        # Extract values from the chain
-        values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
-        # Resample any variables that are not present in `values_dict`
+    params_and_stats = AbstractMCMC.to_samples(
+        DynamicPPL.ParamsWithStats, parameter_only_chain
+    )
+    predictions = map(params_and_stats) do ps
         _, varinfo = DynamicPPL.init!!(
-            rng,
-            model,
-            varinfo,
-            DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
+            rng, model, varinfo, DynamicPPL.InitFromParams(ps.params)
         )
-        vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
-        varname_vals = mapreduce(
-            collect,
-            vcat,
-            map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
-        )
-
-        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
+        DynamicPPL.ParamsWithStats(varinfo)
     end
+    chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions)

-    chain_result = reduce(
-        MCMCChains.chainscat,
-        [
-            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
-            chain_idx in 1:size(predictive_samples, 2)
-        ],
-    )
     parameter_names = if include_all
         MCMCChains.names(chain_result, :parameters)
     else
@@ -164,45 +250,6 @@ function DynamicPPL.predict(
     )
 end

-function _predictive_samples_to_arrays(predictive_samples)
-    variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
-
-    sample_dicts = map(predictive_samples) do sample
-        varname_value_pairs = sample.varname_and_values
-        varnames = map(first, varname_value_pairs)
-        values = map(last, varname_value_pairs)
-        for varname in varnames
-            push!(variable_names_set, varname)
-        end
-
-        return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
-    end
-
-    variable_names = collect(variable_names_set)
-    variable_values = [
-        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
-        key in variable_names
-    ]
-
-    return variable_names, variable_values
-end
-
-function _predictive_samples_to_chains(predictive_samples)
-    variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
-    variable_names_symbols = map(Symbol, variable_names)
-
-    internal_parameters = [:lp]
-    log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)
-
-    parameter_names = [variable_names_symbols; internal_parameters]
-    parameter_values = hcat(variable_values, log_probabilities)
-    parameter_values = MCMCChains.concretize(parameter_values)
-
-    return MCMCChains.Chains(
-        parameter_values, parameter_names, (internals=internal_parameters,)
-    )
-end
-
 """
     returned(model::Model, chain::MCMCChains.Chains)

@@ -266,17 +313,15 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
     chain = MCMCChains.get_sections(chain_full, :parameters)
     varinfo = DynamicPPL.VarInfo(model)
     iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
-    return map(iters) do (sample_idx, chain_idx)
-        # Extract values from the chain
-        values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
-        # Resample any variables that are not present in `values_dict`, and
-        # return the model's retval.
-        retval, _ = DynamicPPL.init!!(
-            model,
-            varinfo,
-            DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
+    params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
+    return map(params_with_stats) do ps
+        first(
+            DynamicPPL.init!!(
+                model,
+                varinfo,
+                DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()),
+            ),
         )
-        retval
     end
 end
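
The new `from_samples` method in this file builds an iterations × names × chains array and fills any name absent from a given sample with `missing`. The dependency-free sketch below illustrates only that layout. It is an illustration, not the extension's code: plain `Dict`s with `Symbol` keys stand in for `ParamsWithStats`, the names `samples`, `all_keys`, and `vals` are hypothetical, and the keys are sorted for determinism, whereas the extension keeps first-appearance order via an `OrderedSet`.

```julia
# Hypothetical stand-in for a matrix of samples: rows are iterations, columns are chains.
# `reshape` fills column-major, so the first two entries form chain 1.
samples = reshape(
    [
        Dict(:a => 1.0, :b => 2.0),  # iteration 1, chain 1
        Dict(:a => 3.0),             # iteration 2, chain 1
        Dict(:a => 4.0, :c => 5.0),  # iteration 1, chain 2
        Dict(:b => 6.0),             # iteration 2, chain 2
    ],
    2,
    2,
)

# Collect the union of keys across all samples, then sort for a deterministic order.
all_keys = Symbol[]
for d in samples, k in keys(d)
    k in all_keys || push!(all_keys, k)
end
sort!(all_keys)

# Three-iterator comprehension: the result is indexed as [iteration, key, chain].
vals = [
    get(samples[i, j], k, missing) for
    i in axes(samples, 1), k in all_keys, j in axes(samples, 2)
]
```

Here `vals` has size `(2, 3, 2)`, and an entry such as `vals[2, 2, 1]` is `missing` because the second iteration of chain 1 has no `:b`.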

src/DynamicPPL.jl

Lines changed: 3 additions & 1 deletion
@@ -126,6 +126,8 @@ export AbstractVarInfo,
     prefix,
     returned,
     to_submodel,
+    # Struct to hold model outputs
+    ParamsWithStats,
     # Convenience macros
     @addlogprob!,
     value_iterator_from_chain,
@@ -169,7 +171,6 @@ abstract type AbstractVarInfo <: AbstractModelTrace end

 # Necessary forward declarations
 include("utils.jl")
-include("chains.jl")
 include("contexts.jl")
 include("contexts/default.jl")
 include("contexts/init.jl")
@@ -193,6 +194,7 @@ include("logdensityfunction.jl")
 include("model_utils.jl")
 include("extract_priors.jl")
 include("values_as_in_model.jl")
+include("chains.jl")
 include("bijector.jl")

 include("debug_utils.jl")
