Skip to content
Merged
6 changes: 4 additions & 2 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus t
The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.
`loadstate` is exported from DynamicPPL.

### Change of default keytype of `pointwise_logdensities`
### Change of output type for `pointwise_logdensities`

The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` return dictionaries for which the keys are model variables, and the key type is either `VarName` or `String`. This release changes the default from `String` to `VarName`.
The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` when called on `MCMCChains.Chains` objects, now return new `MCMCChains.Chains` objects by default, instead of dictionaries of matrices.

If you want the old behaviour, you can pass `OrderedDict` as the third argument, i.e., `pointwise_logdensities(model, chain, OrderedDict)`.

**Other changes**

Expand Down
286 changes: 286 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,4 +292,290 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
end
end

"""
DynamicPPL.pointwise_logdensities(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Type{Tout}=MCMCChains.Chains
::Val{whichlogprob}=Val(:both),
)

Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where
the log-density of each variable at each sample is stored (rather than its value).

`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or
`:likelihood`.

You can pass `Tout=OrderedDict` to get the result as an `OrderedDict{VarName,
Matrix{Float64}}` instead.

See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref),
[`DynamicPPL.pointwise_prior_logdensities`](@ref).

# Examples

```jldoctest pointwise-logdensities-chains; setup=:(using Distributions)
julia> using MCMCChains

julia> @model function demo(xs, y)
s ~ InverseGamma(2, 3)
m ~ Normal(0, √s)
for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end
y ~ Normal(m, √s)
end
demo (generic function with 2 methods)

julia> # Example observations.
model = demo([1.0, 2.0, 3.0], [4.0]);

julia> # A chain with 3 iterations.
chain = Chains(
reshape(1.:6., 3, 2),
[:s, :m];
info=(varname_to_symbol=Dict(
@varname(s) => :s,
@varname(m) => :m,
),),
);

julia> plds = pointwise_logdensities(model, chain)
Chains MCMC chain (3×6×1 Array{Float64, 3}):

Iterations = 1:1:3
Number of chains = 1
Samples per chain = 3
parameters = s, m, xs[1], xs[2], xs[3], y
[...]

julia> plds[:s]
2-dimensional AxisArray{Float64,2,...} with axes:
:iter, 1:1:3
:chain, 1:1
And data, a 3×1 Matrix{Float64}:
-0.8027754226637804
-1.3822169643436162
-2.0986122886681096

julia> # The above is the same as:
logpdf.(InverseGamma(2, 3), chain[:s])
3×1 Matrix{Float64}:
-0.8027754226637804
-1.3822169643436162
-2.0986122886681096
```

julia> # Alternatively:
plds_dict = pointwise_logdensities(model, chain, OrderedDict)
OrderedDict{VarName, Matrix{Float64}} with 6 entries:
s => [-0.802775; -1.38222; -2.09861;;]
m => [-8.91894; -7.51551; -7.46824;;]
xs[1] => [-5.41894; -5.26551; -5.63491;;]
xs[2] => [-2.91894; -3.51551; -4.13491;;]
xs[3] => [-1.41894; -2.26551; -2.96824;;]
y => [-0.918939; -1.51551; -2.13491;;]
"""
function DynamicPPL.pointwise_logdensities(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Type{Tout}=MCMCChains.Chains,
::Val{whichlogprob}=Val(:both),
) where {whichlogprob,Tout}
vi = DynamicPPL.VarInfo(model)
acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}()
accname = DynamicPPL.accumulator_name(acc)
vi = DynamicPPL.setaccs!!(vi, (acc,))
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
pointwise_logps = map(iters) do (sample_idx, chain_idx)
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Re-evaluate the model
_, vi = DynamicPPL.init!!(
model,
vi,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
)
DynamicPPL.getacc(vi, Val(accname)).logps
end

# pointwise_logps is a matrix of OrderedDicts
all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
for d in pointwise_logps
union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d)))
end
# this is a 3D array: (iterations, variables, chains)
new_data = [
get(pointwise_logps[iter, chain], k, missing) for
iter in 1:size(pointwise_logps, 1), k in all_keys,
chain in 1:size(pointwise_logps, 2)
]

if Tout == MCMCChains.Chains
return MCMCChains.Chains(new_data, Symbol.(collect(all_keys)))
elseif Tout <: AbstractDict
return Tout{DynamicPPL.VarName,Matrix{Float64}}(
k => new_data[:, i, :] for (i, k) in enumerate(all_keys)
)
end
end

"""
DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Type{Tout}=MCMCChains.Chains
)

Compute the pointwise log-likelihoods of the model given the chain. This is the same as
`pointwise_logdensities(model, chain)`, but only including the likelihood terms.

See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref).
"""
function DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains
) where {Tout}
return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:likelihood))
end

"""
DynamicPPL.pointwise_prior_logdensities(
model::DynamicPPL.Model,
chain::MCMCChains.Chains
)

Compute the pointwise log-prior-densities of the model given the chain. This is the same as
`pointwise_logdensities(model, chain)`, but only including the prior terms.

See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref).
"""
function DynamicPPL.pointwise_prior_logdensities(
model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains
) where {Tout}
return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:prior))
end

"""
logjoint(model::Model, chain::MCMCChains.Chains)

Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # Construct a chain of samples using MCMCChains.
# This sets s = 0.5 and m = 1.0 for all three samples.
chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]);

julia> logjoint(demo_model([1., 2.]), chain)
3×1 Matrix{Float64}:
-5.440428709758045
-5.440428709758045
-5.440428709758045
```
"""
function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.logjoint(model, argvals_dict)
end
end

"""
loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)

Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # Construct a chain of samples using MCMCChains.
# This sets s = 0.5 and m = 1.0 for all three samples.
chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]);

julia> loglikelihood(demo_model([1., 2.]), chain)
3×1 Matrix{Float64}:
-2.1447298858494
-2.1447298858494
-2.1447298858494
```
"""
function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.loglikelihood(model, argvals_dict)
end
end

"""
logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)

Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # Construct a chain of samples using MCMCChains.
# This sets s = 0.5 and m = 1.0 for all three samples.
chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]);

julia> logprior(demo_model([1., 2.]), chain)
3×1 Matrix{Float64}:
-3.2956988239086447
-3.2956988239086447
-3.2956988239086447
```
"""
function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.logprior(model, argvals_dict)
end
end

end
Loading
Loading