#984 uses `init!!` to implement `predict`. However, the implementation of `include_all=false` seems a bit wasteful: it first constructs a chain containing all parameters (including the ones we don't want) and only then subsets that chain. It seems more sensible to filter the dictionary of varname => value pairs inside the loop, in each iteration, so that those variables never end up in the chain in the first place.
`DynamicPPL.jl/ext/DynamicPPLMCMCChainsExt.jl`, lines 122 to 158 in `956ed54`:
```julia
    predictive_samples = map(iters) do (sample_idx, chain_idx)
        # Extract values from the chain
        values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
        # Resample any variables that are not present in `values_dict`
        _, varinfo = DynamicPPL.init!!(
            rng,
            model,
            varinfo,
            DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
        )
        vals = DynamicPPL.values_as_in_model(model, false, varinfo)
        varname_vals = mapreduce(
            collect,
            vcat,
            map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
        )
        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
    end
    chain_result = reduce(
        MCMCChains.chainscat,
        [
            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
            chain_idx in 1:size(predictive_samples, 2)
        ],
    )
    parameter_names = if include_all
        MCMCChains.names(chain_result, :parameters)
    else
        filter(
            k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
            names(chain_result, :parameters),
        )
    end
    return chain_result[parameter_names]
end
```
Not making this change in #984 to avoid complicating matters.
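For illustration only, here is a minimal, package-free sketch of the difference between subsetting after the fact and filtering inside the loop. It assumes plain `String` names stand in for the `VarName` leaves produced by `AbstractPPL.varname_and_value_leaves`, and that `exclude` plays the role of the parameter names in `parameter_only_chain`. `sample_leaves`, `collect_then_subset`, and `filter_in_loop` are hypothetical helpers, not DynamicPPL API. Matching chain names against `VarName`s in the real extension is not addressed here.

```julia
# Hypothetical stand-in for the per-iteration (varname, value) leaves that the
# loop collects; plain strings replace DynamicPPL `VarName`s in this sketch.
function sample_leaves(i::Int)
    return ["mu" => 0.1 * i, "sigma" => 1.0 + i, "y[1]" => 2.0 * i, "y[2]" => 3.0 * i]
end

# Current shape: store every leaf for every iteration, then drop unwanted names.
function collect_then_subset(iters, exclude::Set{String})
    all_samples = [sample_leaves(i) for i in iters]
    return [filter(p -> !(first(p) in exclude), s) for s in all_samples]
end

# Proposed shape: drop unwanted names inside the loop, before anything is stored.
function filter_in_loop(iters, exclude::Set{String}; include_all::Bool=false)
    return map(iters) do i
        leaves = sample_leaves(i)
        # With include_all=true nothing is dropped, mirroring the current behaviour.
        include_all ? leaves : filter(p -> !(first(p) in exclude), leaves)
    end
end
```

Both functions return the same samples; the second simply never materializes the excluded entries, which is the saving the issue is after when the real code builds an `MCMCChains.Chains` object from them.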