11module DynamicPPLMCMCChainsExt
22
3- using DynamicPPL: DynamicPPL, AbstractPPL
3+ using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
44using MCMCChains: MCMCChains
55
66_has_varname_to_symbol (info:: NamedTuple{names} ) where {names} = :varname_to_symbol in names
@@ -36,6 +36,110 @@ function chain_sample_to_varname_dict(
3636 return d
3737end
3838
39+ """
40+ AbstractMCMC.from_samples(
41+ ::Type{MCMCChains.Chains},
42+ params_and_stats::AbstractMatrix{<:ParamsWithStats}
43+ )
44+
45+ Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.
46+ """
47+ function AbstractMCMC. from_samples (
48+ :: Type{MCMCChains.Chains} ,
49+ params_and_stats:: AbstractMatrix{<:DynamicPPL.ParamsWithStats} ,
50+ )
51+ # Handle parameters
52+ all_vn_leaves = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
53+ split_dicts = map (params_and_stats) do ps
54+ # Separate into individual VarNames.
55+ vn_leaves_and_vals = if isempty (ps. params)
56+ Tuple{DynamicPPL. VarName,Any}[]
57+ else
58+ iters = map (
59+ AbstractPPL. varname_and_value_leaves,
60+ keys (ps. params),
61+ values (ps. params),
62+ )
63+ mapreduce (collect, vcat, iters)
64+ end
65+ vn_leaves = map (first, vn_leaves_and_vals)
66+ vals = map (last, vn_leaves_and_vals)
67+ for vn_leaf in vn_leaves
68+ push! (all_vn_leaves, vn_leaf)
69+ end
70+ DynamicPPL. OrderedCollections. OrderedDict (zip (vn_leaves, vals))
71+ end
72+ vn_leaves = collect (all_vn_leaves)
73+ param_vals = [
74+ get (split_dicts[i, j], key, missing ) for i in eachindex (axes (split_dicts, 1 )),
75+ key in vn_leaves, j in eachindex (axes (split_dicts, 2 ))
76+ ]
77+ param_symbols = map (Symbol, vn_leaves)
78+ # Handle statistics
79+ stat_keys = DynamicPPL. OrderedCollections. OrderedSet {Symbol} ()
80+ for ps in params_and_stats
81+ for k in keys (ps. stats)
82+ push! (stat_keys, k)
83+ end
84+ end
85+ stat_keys = collect (stat_keys)
86+ stat_vals = [
87+ get (params_and_stats[i, j]. stats, key, missing ) for
88+ i in eachindex (axes (params_and_stats, 1 )), key in stat_keys,
89+ j in eachindex (axes (params_and_stats, 2 ))
90+ ]
91+ # Construct name map and info
92+ name_map = (internals= stat_keys,)
93+ info = (
94+ varname_to_symbol= DynamicPPL. OrderedCollections. OrderedDict (
95+ zip (all_vn_leaves, param_symbols)
96+ ),
97+ )
98+ # Concatenate parameter and statistic values
99+ vals = cat (param_vals, stat_vals; dims= 2 )
100+ symbols = vcat (param_symbols, stat_keys)
101+ return MCMCChains. Chains (MCMCChains. concretize (vals), symbols, name_map; info= info)
102+ end
103+
104+ """
105+ AbstractMCMC.to_samples(
106+ ::Type{DynamicPPL.ParamsWithStats},
107+ chain::MCMCChains.Chains
108+ )
109+
110+ Convert an `MCMCChains.Chains` object to an array of `DynamicPPL.ParamsWithStats`.
111+
112+ For this to work, `chain` must contain the `varname_to_symbol` mapping in its `info` field.
113+ """
114+ function AbstractMCMC. to_samples (
115+ :: Type{DynamicPPL.ParamsWithStats} , chain:: MCMCChains.Chains
116+ )
117+ idxs = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
118+ # Get parameters
119+ params_matrix = map (idxs) do (sample_idx, chain_idx)
120+ d = DynamicPPL. OrderedCollections. OrderedDict {DynamicPPL.VarName,Any} ()
121+ for vn in DynamicPPL. varnames (chain)
122+ d[vn] = DynamicPPL. getindex_varname (chain, sample_idx, vn, chain_idx)
123+ end
124+ d
125+ end
126+ # Statistics
127+ stats_matrix = if :internals in MCMCChains. sections (chain)
128+ internals_chain = MCMCChains. get_sections (chain, :internals )
129+ map (idxs) do (sample_idx, chain_idx)
130+ get (internals_chain[sample_idx, :, chain_idx], keys (internals_chain); flatten= true )
131+ end
132+ else
133+ fill (NamedTuple (), size (idxs))
134+ end
135+ # Bundle them together
136+ return map (idxs) do (sample_idx, chain_idx)
137+ DynamicPPL. ParamsWithStats (
138+ params_matrix[sample_idx, chain_idx], stats_matrix[sample_idx, chain_idx]
139+ )
140+ end
141+ end
142+
39143"""
40144 predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
41145
@@ -110,42 +214,24 @@ function DynamicPPL.predict(
110214 DynamicPPL. VarInfo (),
111215 (
112216 DynamicPPL. LogPriorAccumulator (),
113- DynamicPPL. LogJacobianAccumulator (),
114217 DynamicPPL. LogLikelihoodAccumulator (),
115218 DynamicPPL. ValuesAsInModelAccumulator (false ),
116219 ),
117220 )
118221 _, varinfo = DynamicPPL. init!! (model, varinfo)
119222 varinfo = DynamicPPL. typed_varinfo (varinfo)
120223
121- iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
122- predictive_samples = map (iters) do (sample_idx, chain_idx)
123- # Extract values from the chain
124- values_dict = chain_sample_to_varname_dict (parameter_only_chain, sample_idx, chain_idx)
125- # Resample any variables that are not present in `values_dict`
224+ params_and_stats = AbstractMCMC. to_samples (
225+ DynamicPPL. ParamsWithStats, parameter_only_chain
226+ )
227+ predictions = map (params_and_stats) do ps
126228 _, varinfo = DynamicPPL. init!! (
127- rng,
128- model,
129- varinfo,
130- DynamicPPL. InitFromParams (values_dict, DynamicPPL. InitFromPrior ()),
229+ rng, model, varinfo, DynamicPPL. InitFromParams (ps. params)
131230 )
132- vals = DynamicPPL. getacc (varinfo, Val (:ValuesAsInModel )). values
133- varname_vals = mapreduce (
134- collect,
135- vcat,
136- map (AbstractPPL. varname_and_value_leaves, keys (vals), values (vals)),
137- )
138-
139- return (varname_and_values= varname_vals, logp= DynamicPPL. getlogjoint (varinfo))
231+ DynamicPPL. ParamsWithStats (varinfo)
140232 end
233+ chain_result = AbstractMCMC. from_samples (MCMCChains. Chains, predictions)
141234
142- chain_result = reduce (
143- MCMCChains. chainscat,
144- [
145- _predictive_samples_to_chains (predictive_samples[:, chain_idx]) for
146- chain_idx in 1 : size (predictive_samples, 2 )
147- ],
148- )
149235 parameter_names = if include_all
150236 MCMCChains. names (chain_result, :parameters )
151237 else
@@ -164,45 +250,6 @@ function DynamicPPL.predict(
164250 )
165251end
166252
167- function _predictive_samples_to_arrays (predictive_samples)
168- variable_names_set = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
169-
170- sample_dicts = map (predictive_samples) do sample
171- varname_value_pairs = sample. varname_and_values
172- varnames = map (first, varname_value_pairs)
173- values = map (last, varname_value_pairs)
174- for varname in varnames
175- push! (variable_names_set, varname)
176- end
177-
178- return DynamicPPL. OrderedCollections. OrderedDict (zip (varnames, values))
179- end
180-
181- variable_names = collect (variable_names_set)
182- variable_values = [
183- get (sample_dicts[i], key, missing ) for i in eachindex (sample_dicts),
184- key in variable_names
185- ]
186-
187- return variable_names, variable_values
188- end
189-
190- function _predictive_samples_to_chains (predictive_samples)
191- variable_names, variable_values = _predictive_samples_to_arrays (predictive_samples)
192- variable_names_symbols = map (Symbol, variable_names)
193-
194- internal_parameters = [:lp ]
195- log_probabilities = reshape ([sample. logp for sample in predictive_samples], :, 1 )
196-
197- parameter_names = [variable_names_symbols; internal_parameters]
198- parameter_values = hcat (variable_values, log_probabilities)
199- parameter_values = MCMCChains. concretize (parameter_values)
200-
201- return MCMCChains. Chains (
202- parameter_values, parameter_names, (internals= internal_parameters,)
203- )
204- end
205-
206253"""
207254 returned(model::Model, chain::MCMCChains.Chains)
208255
@@ -266,17 +313,15 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
266313 chain = MCMCChains. get_sections (chain_full, :parameters )
267314 varinfo = DynamicPPL. VarInfo (model)
268315 iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
269- return map (iters) do (sample_idx, chain_idx)
270- # Extract values from the chain
271- values_dict = chain_sample_to_varname_dict (chain, sample_idx, chain_idx)
272- # Resample any variables that are not present in `values_dict`, and
273- # return the model's retval.
274- retval, _ = DynamicPPL. init!! (
275- model,
276- varinfo,
277- DynamicPPL. InitFromParams (values_dict, DynamicPPL. InitFromPrior ()),
316+ params_with_stats = AbstractMCMC. to_samples (DynamicPPL. ParamsWithStats, chain)
317+ return map (params_with_stats) do ps
318+ first (
319+ DynamicPPL. init!! (
320+ model,
321+ varinfo,
322+ DynamicPPL. InitFromParams (ps. params, DynamicPPL. InitFromPrior ()),
323+ ),
278324 )
279- retval
280325 end
281326end
282327
0 commit comments