Skip to content
Merged
6 changes: 4 additions & 2 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus t
The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.
`loadstate` is exported from DynamicPPL.

### Change of default keytype of `pointwise_logdensities`
### Change of output type for `pointwise_logdensities`

The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` return dictionaries for which the keys are model variables, and the key type is either `VarName` or `String`. This release changes the default from `String` to `VarName`.
The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` when called on `MCMCChains.Chains` objects, now return new `MCMCChains.Chains` objects by default, instead of dictionaries of matrices.

If you want the old behaviour, you can pass `OrderedDict` as the third argument, i.e., `pointwise_logdensities(model, chain, OrderedDict)`.

**Other changes**

Expand Down
272 changes: 272 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,4 +292,276 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
end
end

"""
DynamicPPL.pointwise_logdensities(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Type{Tout}=MCMCChains.Chains
::Val{whichlogprob}=Val(:both),
)

Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where
the log-density of each variable at each sample is stored (rather than its value).

`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or
`:likelihood`.

You can pass `Tout=OrderedDict` to get the result as an `OrderedDict{VarName,
Matrix{Float64}}` instead.

See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref),
[`DynamicPPL.pointwise_prior_logdensities`](@ref).

# Examples

```jldoctest pointwise-logdensities-chains; setup=:(using Distributions)
julia> using MCMCChains

julia> @model function demo(xs, y)
s ~ InverseGamma(2, 3)
m ~ Normal(0, √s)
for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end
y ~ Normal(m, √s)
end
demo (generic function with 2 methods)

julia> # Example observations.
model = demo([1.0, 2.0, 3.0], [4.0]);

julia> # A chain with 3 iterations.
chain = Chains(
reshape(1.:6., 3, 2),
[:s, :m];
info=(varname_to_symbol=Dict(
@varname(s) => :s,
@varname(m) => :m,
),),
);

julia> plds = pointwise_logdensities(model, chain)
Chains MCMC chain (3×6×1 Array{Float64, 3}):

Iterations = 1:1:3
Number of chains = 1
Samples per chain = 3
parameters = s, m, xs[1], xs[2], xs[3], y
[...]

julia> plds[:s]
2-dimensional AxisArray{Float64,2,...} with axes:
:iter, 1:1:3
:chain, 1:1
And data, a 3×1 Matrix{Float64}:
-0.8027754226637804
-1.3822169643436162
-2.0986122886681096

julia> # The above is the same as:
logpdf.(InverseGamma(2, 3), chain[:s])
3×1 Matrix{Float64}:
-0.8027754226637804
-1.3822169643436162
-2.0986122886681096
```

julia> # Alternatively:
plds_dict = pointwise_logdensities(model, chain, OrderedDict)
OrderedDict{VarName, Matrix{Float64}} with 6 entries:
s => [-0.802775; -1.38222; -2.09861;;]
m => [-8.91894; -7.51551; -7.46824;;]
xs[1] => [-5.41894; -5.26551; -5.63491;;]
xs[2] => [-2.91894; -3.51551; -4.13491;;]
xs[3] => [-1.41894; -2.26551; -2.96824;;]
y => [-0.918939; -1.51551; -2.13491;;]
"""
function DynamicPPL.pointwise_logdensities(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Type{Tout}=MCMCChains.Chains,
::Val{whichlogprob}=Val(:both),
) where {whichlogprob,Tout}
vi = DynamicPPL.VarInfo(model)
acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}()
accname = DynamicPPL.accumulator_name(acc)
vi = DynamicPPL.setaccs!!(vi, (acc,))
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
pointwise_logps = map(iters) do (sample_idx, chain_idx)
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Re-evaluate the model
_, vi = DynamicPPL.init!!(
model,
vi,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
)
DynamicPPL.getacc(vi, Val(accname)).logps
end

# pointwise_logps is a matrix of OrderedDicts
all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
for d in pointwise_logps
union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d)))
end
# this is a 3D array: (iterations, variables, chains)
new_data = [
get(pointwise_logps[iter, chain], k, missing) for
iter in 1:size(pointwise_logps, 1), k in all_keys,
chain in 1:size(pointwise_logps, 2)
]

if Tout == MCMCChains.Chains
return MCMCChains.Chains(new_data, Symbol.(collect(all_keys)))
elseif Tout <: AbstractDict
return Tout{DynamicPPL.VarName,Matrix{Float64}}(
k => new_data[:, i, :] for (i, k) in enumerate(all_keys)
)
end
end

"""
DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Type{Tout}=MCMCChains.Chains
)

Compute the pointwise log-likelihoods of the model given the chain. This is the same as
`pointwise_logdensities(model, chain)`, but only including the likelihood terms.

See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref).
"""
function DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains
) where {Tout}
return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:likelihood))
end

"""
DynamicPPL.pointwise_prior_logdensities(
model::DynamicPPL.Model,
chain::MCMCChains.Chains
)

Compute the pointwise log-prior-densities of the model given the chain. This is the same as
`pointwise_logdensities(model, chain)`, but only including the prior terms.

See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref).
"""
function DynamicPPL.pointwise_prior_logdensities(
model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains
) where {Tout}
return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:prior))
end

"""
logjoint(model::Model, chain::MCMCChains.Chains)

Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> logjoint(demo_model([1., 2.]), chain);
```
"""
function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.logjoint(model, argvals_dict)
end
end

"""
loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)

Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
n
# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> loglikelihood(demo_model([1., 2.]), chain);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it intentional that this doctest doesn't check the output at all?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Frankly, I'm not sure. I just copy-pasted it. Very happy to change it now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the output! Also changed the rand() to some fixed values to avoid making it too fragile.

```
"""
function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.loglikelihood(model, argvals_dict)
end
end

"""
logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)

Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> logprior(demo_model([1., 2.]), chain);
```
"""
function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.logprior(model, argvals_dict)
end
end

end
Loading
Loading