Skip to content

Commit 30208ec

Browse files
committed
use NamedTuple
1 parent 3dc742a commit 30208ec

File tree

2 files changed

+8
-14
lines changed

2 files changed

+8
-14
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ function _params_to_array(predictive_samples)
147147
names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
148148

149149
dicts = map(predictive_samples) do t
150-
nms_and_vs = t[:values]
151-
nms = map(first, nms_and_vs)
152-
vs = map(last, nms_and_vs)
150+
varname_and_values = t.varname_and_values
151+
nms = map(first, varname_and_values)
152+
vs = map(last, varname_and_values)
153153
for nm in nms
154154
push!(names_set, nm)
155155
end
@@ -164,15 +164,11 @@ function _params_to_array(predictive_samples)
164164
return names, vals
165165
end
166166

167-
function _bundle_predictive_samples(
168-
predictive_samples::AbstractArray{
169-
<:DynamicPPL.OrderedCollections.OrderedDict{Symbol,Any}
170-
},
171-
)
167+
function _bundle_predictive_samples(predictive_samples)
172168
varnames, vals = _params_to_array(predictive_samples)
173169
varnames_symbol = map(Symbol, varnames)
174170
extra_params = [:lp]
175-
extra_values = reshape([t[:logp] for t in predictive_samples], :, 1)
171+
extra_values = reshape([t.logp for t in predictive_samples], :, 1)
176172
nms = [varnames_symbol; extra_params]
177173
parray = hcat(vals, extra_values)
178174
parray = MCMCChains.concretize(parray)

src/model.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,15 +1223,13 @@ function predict(
12231223
varinfos::AbstractArray{<:AbstractVarInfo};
12241224
include_all=false,
12251225
)
1226-
predictive_samples = similar(varinfos, OrderedDict{Symbol,Any})
1226+
predictive_samples = similar(varinfos, NamedTuple{(:varname_and_values, :logp)})
12271227
for i in eachindex(varinfos)
12281228
model(rng, varinfos[i], SampleFromPrior())
12291229
vals = values_as_in_model(model, varinfos[i])
12301230
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
1231-
params = mapreduce(collect, vcat, iters)
1232-
predictive_samples[i] = OrderedDict(
1233-
:values => params, :logp => getlogp(varinfos[i])
1234-
)
1231+
params = mapreduce(collect, vcat, iters) # returns a vector of tuples (varname, value)
1232+
predictive_samples[i] = (varname_and_values=params, logp=getlogp(varinfos[i]))
12351233
end
12361234
return predictive_samples
12371235
end

0 commit comments

Comments
 (0)