Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# DynamicPPL Changelog

## 0.38.8

Added a new exported struct, `DynamicPPL.ParamsWithStats`.
This can broadly be used to represent the output of a model: it consists of an `OrderedDict` of `VarName` parameters and their values, along with a `stats` NamedTuple which can hold arbitrary data, such as (but not limited to) log-probabilities.

Implemented the functions `AbstractMCMC.to_samples` and `AbstractMCMC.from_samples`, which convert between an `MCMCChains.Chains` object and a matrix of `DynamicPPL.ParamsWithStats` objects.

## 0.38.7

Made a small tweak to DynamicPPL's compiler output to avoid potential undefined variables when resuming model functions midway through (e.g. with Libtask in Turing's SMC/PG samplers).
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.38.7"
version = "0.38.8"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]

[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractMCMC = "5.10"
AbstractPPL = "0.13.1"
Accessors = "0.1"
BangBang = "0.4.1"
Expand Down
3 changes: 2 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -24,6 +25,6 @@ FillArrays = "0.13, 1"
ForwardDiff = "0.10, 1"
JET = "0.9, 0.10, 0.11"
LogDensityProblems = "2"
MarginalLogDensities = "0.4"
MCMCChains = "5, 6, 7"
MarginalLogDensities = "0.4"
StableRNGs = "1"
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Distributions
using DocumenterMermaid
# load MCMCChains package extension to make `predict` available
using MCMCChains
using AbstractMCMC: AbstractMCMC
using MarginalLogDensities: MarginalLogDensities

# Need this to document a method which uses a type inside the extension...
Expand Down
26 changes: 26 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,29 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
DynamicPPL.Experimental.determine_suitable_varinfo
DynamicPPL.Experimental.is_suitable_varinfo
```

### Converting VarInfos to/from chains

It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis.

This can be accomplished by first converting each `VarInfo` into a `ParamsWithStats` object:

```@docs
DynamicPPL.ParamsWithStats
```

Once you have a **matrix** of these, you can convert them into a chains object using:

```@docs
AbstractMCMC.from_samples(::Type{MCMCChains.Chains}, ::AbstractMatrix{<:DynamicPPL.ParamsWithStats})
```

If you only have a vector you can use `hcat` to convert it into an `N×1` matrix first.

Furthermore, one can convert chains back into a collection of parameter dictionaries and/or stats with:

```@docs
AbstractMCMC.to_samples(::Type{DynamicPPL.ParamsWithStats}, ::MCMCChains.Chains)
```

With these, you can (for example) extract the parameter dictionaries and use `InitFromParams` to re-evaluate a model at each point in the chain.
195 changes: 120 additions & 75 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module DynamicPPLMCMCChainsExt

using DynamicPPL: DynamicPPL, AbstractPPL
using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
using MCMCChains: MCMCChains

_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
Expand Down Expand Up @@ -36,6 +36,110 @@ function chain_sample_to_varname_dict(
return d
end

"""
AbstractMCMC.from_samples(
::Type{MCMCChains.Chains},
params_and_stats::AbstractMatrix{<:ParamsWithStats}
)

Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.
"""
function AbstractMCMC.from_samples(
::Type{MCMCChains.Chains},
params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats},
)
# Handle parameters
all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
split_dicts = map(params_and_stats) do ps
# Separate into individual VarNames.
vn_leaves_and_vals = if isempty(ps.params)
Tuple{DynamicPPL.VarName,Any}[]
else
iters = map(
AbstractPPL.varname_and_value_leaves,
keys(ps.params),
values(ps.params),
)
mapreduce(collect, vcat, iters)
end
vn_leaves = map(first, vn_leaves_and_vals)
vals = map(last, vn_leaves_and_vals)
for vn_leaf in vn_leaves
push!(all_vn_leaves, vn_leaf)
end
DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals))
end
vn_leaves = collect(all_vn_leaves)
param_vals = [
get(split_dicts[i, j], key, missing) for i in eachindex(axes(split_dicts, 1)),
key in vn_leaves, j in eachindex(axes(split_dicts, 2))
]
param_symbols = map(Symbol, vn_leaves)
# Handle statistics
stat_keys = DynamicPPL.OrderedCollections.OrderedSet{Symbol}()
for ps in params_and_stats
for k in keys(ps.stats)
push!(stat_keys, k)
end
end
stat_keys = collect(stat_keys)
stat_vals = [
get(params_and_stats[i, j].stats, key, missing) for
i in eachindex(axes(params_and_stats, 1)), key in stat_keys,
j in eachindex(axes(params_and_stats, 2))
]
# Construct name map and info
name_map = (internals=stat_keys,)
info = (
varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
zip(all_vn_leaves, param_symbols)
),
)
# Concatenate parameter and statistic values
vals = cat(param_vals, stat_vals; dims=2)
symbols = vcat(param_symbols, stat_keys)
return MCMCChains.Chains(MCMCChains.concretize(vals), symbols, name_map; info=info)
end

"""
AbstractMCMC.to_samples(
::Type{DynamicPPL.ParamsWithStats},
chain::MCMCChains.Chains
)

Convert an `MCMCChains.Chains` object to an array of `DynamicPPL.ParamsWithStats`.

For this to work, `chain` must contain the `varname_to_symbol` mapping in its `info` field.
"""
function AbstractMCMC.to_samples(
::Type{DynamicPPL.ParamsWithStats}, chain::MCMCChains.Chains
)
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
# Get parameters
params_matrix = map(idxs) do (sample_idx, chain_idx)
d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}()
for vn in DynamicPPL.varnames(chain)
d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
end
d
end
# Statistics
stats_matrix = if :internals in MCMCChains.sections(chain)
internals_chain = MCMCChains.get_sections(chain, :internals)
map(idxs) do (sample_idx, chain_idx)
get(internals_chain[sample_idx, :, chain_idx], keys(internals_chain); flatten=true)
end
else
fill(NamedTuple(), size(idxs))
end
# Bundle them together
return map(idxs) do (sample_idx, chain_idx)
DynamicPPL.ParamsWithStats(
params_matrix[sample_idx, chain_idx], stats_matrix[sample_idx, chain_idx]
)
end
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Expand Down Expand Up @@ -110,42 +214,24 @@ function DynamicPPL.predict(
DynamicPPL.VarInfo(),
(
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogJacobianAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.ValuesAsInModelAccumulator(false),
),
)
_, varinfo = DynamicPPL.init!!(model, varinfo)
varinfo = DynamicPPL.typed_varinfo(varinfo)

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`
params_and_stats = AbstractMCMC.to_samples(
DynamicPPL.ParamsWithStats, parameter_only_chain
)
predictions = map(params_and_stats) do ps
_, varinfo = DynamicPPL.init!!(
rng,
model,
varinfo,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
rng, model, varinfo, DynamicPPL.InitFromParams(ps.params)
)
vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
varname_vals = mapreduce(
collect,
vcat,
map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
)

return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
DynamicPPL.ParamsWithStats(varinfo)
end
chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions)

chain_result = reduce(
MCMCChains.chainscat,
[
_predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
chain_idx in 1:size(predictive_samples, 2)
],
)
parameter_names = if include_all
MCMCChains.names(chain_result, :parameters)
else
Expand All @@ -164,45 +250,6 @@ function DynamicPPL.predict(
)
end

function _predictive_samples_to_arrays(predictive_samples)
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

sample_dicts = map(predictive_samples) do sample
varname_value_pairs = sample.varname_and_values
varnames = map(first, varname_value_pairs)
values = map(last, varname_value_pairs)
for varname in varnames
push!(variable_names_set, varname)
end

return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
end

variable_names = collect(variable_names_set)
variable_values = [
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
key in variable_names
]

return variable_names, variable_values
end

function _predictive_samples_to_chains(predictive_samples)
variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
variable_names_symbols = map(Symbol, variable_names)

internal_parameters = [:lp]
log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)

parameter_names = [variable_names_symbols; internal_parameters]
parameter_values = hcat(variable_values, log_probabilities)
parameter_values = MCMCChains.concretize(parameter_values)

return MCMCChains.Chains(
parameter_values, parameter_names, (internals=internal_parameters,)
)
end

"""
returned(model::Model, chain::MCMCChains.Chains)

Expand Down Expand Up @@ -266,17 +313,15 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`, and
# return the model's retval.
retval, _ = DynamicPPL.init!!(
model,
varinfo,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
return map(params_with_stats) do ps
first(
DynamicPPL.init!!(
model,
varinfo,
DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()),
),
)
retval
end
end

Expand Down
4 changes: 3 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ export AbstractVarInfo,
prefix,
returned,
to_submodel,
# Struct to hold model outputs
ParamsWithStats,
# Convenience macros
@addlogprob!,
value_iterator_from_chain,
Expand Down Expand Up @@ -169,7 +171,6 @@ abstract type AbstractVarInfo <: AbstractModelTrace end

# Necessary forward declarations
include("utils.jl")
include("chains.jl")
include("contexts.jl")
include("contexts/default.jl")
include("contexts/init.jl")
Expand All @@ -193,6 +194,7 @@ include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
include("chains.jl")
include("bijector.jl")

include("debug_utils.jl")
Expand Down
Loading