Skip to content

Commit fcd7c3d

Browse files
committed
stop using PredictiveSample type
1 parent 53b6749 commit fcd7c3d

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function DynamicPPL.predict(
128128
chain_result = reduce(
129129
MCMCChains.chainscat,
130130
[
131-
_bundle_samples(predictive_samples[:, chain_idx]) for
131+
_bundle_predictive_samples(predictive_samples[:, chain_idx]) for
132132
chain_idx in 1:size(predictive_samples, 2)
133133
],
134134
)
@@ -143,11 +143,11 @@ function DynamicPPL.predict(
143143
return chain_result[parameter_names]
144144
end
145145

146-
function _params_to_array(ts::Vector)
146+
function _params_to_array(predictive_samples)
147147
names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
148148

149-
dicts = map(ts) do t
150-
nms_and_vs = t.values
149+
dicts = map(predictive_samples) do t
150+
nms_and_vs = t[:values]
151151
nms = map(first, nms_and_vs)
152152
vs = map(last, nms_and_vs)
153153
for nm in nms
@@ -164,11 +164,15 @@ function _params_to_array(ts::Vector)
164164
return names, vals
165165
end
166166

167-
function _bundle_samples(ts::Vector{<:DynamicPPL.PredictiveSample})
168-
varnames, vals = _params_to_array(ts)
167+
function _bundle_predictive_samples(
168+
predictive_samples::AbstractArray{
169+
<:DynamicPPL.OrderedCollections.OrderedDict{Symbol,Any}
170+
},
171+
)
172+
varnames, vals = _params_to_array(predictive_samples)
169173
varnames_symbol = map(Symbol, varnames)
170174
extra_params = [:lp]
171-
extra_values = reshape([t.logp for t in ts], :, 1)
175+
extra_values = reshape([t[:logp] for t in predictive_samples], :, 1)
172176
nms = [varnames_symbol; extra_params]
173177
parray = hcat(vals, extra_values)
174178
parray = MCMCChains.concretize(parray)

src/model.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,11 +1203,6 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
12031203
end
12041204
end
12051205

1206-
struct PredictiveSample{T,F}
1207-
values::T
1208-
logp::F
1209-
end
1210-
12111206
"""
12121207
predict([rng::AbstractRNG,] model::Model, chain; include_all=false)
12131208
@@ -1228,13 +1223,15 @@ function predict(
12281223
varinfos::AbstractArray{<:AbstractVarInfo};
12291224
include_all=false,
12301225
)
1231-
predictive_samples = Array{PredictiveSample}(undef, size(varinfos))
1226+
predictive_samples = similar(varinfos, OrderedDict{Symbol,Any})
12321227
for i in eachindex(varinfos)
12331228
model(rng, varinfos[i], SampleFromPrior())
12341229
vals = values_as_in_model(model, varinfos[i])
12351230
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
12361231
params = mapreduce(collect, vcat, iters)
1237-
predictive_samples[i] = PredictiveSample(params, getlogp(varinfos[i]))
1232+
predictive_samples[i] = OrderedDict(
1233+
:values => params, :logp => getlogp(varinfos[i])
1234+
)
12381235
end
12391236
return predictive_samples
12401237
end

0 commit comments

Comments
 (0)