@@ -128,7 +128,7 @@ function DynamicPPL.predict(
128128 chain_result = reduce (
129129 MCMCChains. chainscat,
130130 [
131- _bundle_samples (predictive_samples[:, chain_idx]) for
131+ _bundle_predictive_samples (predictive_samples[:, chain_idx]) for
132132 chain_idx in 1 : size (predictive_samples, 2 )
133133 ],
134134 )
@@ -143,11 +143,11 @@ function DynamicPPL.predict(
143143 return chain_result[parameter_names]
144144end
145145
146- function _params_to_array (ts :: Vector )
146+ function _params_to_array (predictive_samples )
147147 names_set = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
148148
149- dicts = map (ts ) do t
150- nms_and_vs = t. values
149+ dicts = map (predictive_samples ) do t
150+ nms_and_vs = t[ : values]
151151 nms = map (first, nms_and_vs)
152152 vs = map (last, nms_and_vs)
153153 for nm in nms
@@ -164,11 +164,15 @@ function _params_to_array(ts::Vector)
164164 return names, vals
165165end
166166
167- function _bundle_samples (ts:: Vector{<:DynamicPPL.PredictiveSample} )
168- varnames, vals = _params_to_array (ts)
167+ function _bundle_predictive_samples (
168+ predictive_samples:: AbstractArray {
169+ <: DynamicPPL.OrderedCollections.OrderedDict{Symbol,Any}
170+ },
171+ )
172+ varnames, vals = _params_to_array (predictive_samples)
169173 varnames_symbol = map (Symbol, varnames)
170174 extra_params = [:lp ]
171- extra_values = reshape ([t. logp for t in ts ], :, 1 )
175+ extra_values = reshape ([t[ : logp] for t in predictive_samples ], :, 1 )
172176 nms = [varnames_symbol; extra_params]
173177 parray = hcat (vals, extra_values)
174178 parray = MCMCChains. concretize (parray)
0 commit comments