improve implementation of predict(...; include_all) #1042

@penelopeysm

#984 uses `init!!` to implement `predict`. However, the `include_all=false` path is wasteful: it first builds a chain containing all parameters, including the ones we are going to discard, and only afterwards subsets that chain. It would be more sensible to filter the `varname => value` pairs inside the loop, on each iteration, so that the unwanted variables never end up in the chain to begin with. The relevant part of the current implementation:

```julia
    predictive_samples = map(iters) do (sample_idx, chain_idx)
        # Extract values from the chain
        values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
        # Resample any variables that are not present in `values_dict`
        _, varinfo = DynamicPPL.init!!(
            rng,
            model,
            varinfo,
            DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
        )
        vals = DynamicPPL.values_as_in_model(model, false, varinfo)
        varname_vals = mapreduce(
            collect,
            vcat,
            map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
        )
        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
    end
    chain_result = reduce(
        MCMCChains.chainscat,
        [
            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
            chain_idx in 1:size(predictive_samples, 2)
        ],
    )
    parameter_names = if include_all
        MCMCChains.names(chain_result, :parameters)
    else
        filter(
            k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
            names(chain_result, :parameters),
        )
    end
    return chain_result[parameter_names]
end
```
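A minimal, dependency-free sketch of the proposed approach, assuming a simplified setting: it does not use DynamicPPL or MCMCChains. `simulate` is a hypothetical stand-in for running the model via `init!!` with `InitFromParams` (it returns every variable, drawing the ones not supplied), each draw is a plain `Dict{Symbol,Float64}` rather than a chain row, and `predict_sketch` is an illustrative name only. The point is that the filtering happens per iteration, before anything is assembled into a chain.

```julia
using Random

# Hypothetical stand-in for running the model: given the parameter values
# taken from the chain, return every variable the model produces, drawing
# the ones that are missing (here, a single predicted variable `:y`).
function simulate(rng::AbstractRNG, params::Dict{Symbol,Float64})
    y = params[:mu] + randn(rng)
    return Dict(:mu => params[:mu], :y => y)
end

# Hypothetical sketch: drop the chain's own variables inside the loop,
# so they are never stored, instead of subsetting a full chain afterwards.
function predict_sketch(rng::AbstractRNG, draws::Vector{Dict{Symbol,Float64}}; include_all::Bool=false)
    # Names already present in the input chain; computed once, outside the loop.
    chain_names = Set(keys(first(draws)))
    return map(draws) do params
        vals = simulate(rng, params)
        # Filter the varname => value pairs for this iteration only.
        include_all ? vals : filter(kv -> !(kv.first in chain_names), vals)
    end
end

rng = MersenneTwister(1)
draws = [Dict(:mu => 0.0), Dict(:mu => 1.0)]
preds = predict_sketch(rng, draws)                    # each Dict holds only :y
full = predict_sketch(rng, draws; include_all=true)   # each Dict holds :mu and :y
```

Because the set of excluded names is the same for every iteration, it can be built once up front, leaving only a set lookup per variable inside the loop.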

Not making this change in #984 to avoid complicating matters.
