Skip to content
Merged
6 changes: 4 additions & 2 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus t
The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.
`loadstate` is exported from DynamicPPL.

### Change of default keytype of `pointwise_logdensities`
### Change of output type for `pointwise_logdensities`

The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` return dictionaries for which the keys are model variables, and the key type is either `VarName` or `String`. This release changes the default from `String` to `VarName`.
The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` when called on `MCMCChains.Chains` objects, now return new `MCMCChains.Chains` objects, instead of dictionaries of matrices.
This also means that you can no longer specify the output type.
If you want to extract the matrices, you can do so by indexing into the returned `Chains` object.

**Other changes**

Expand Down
241 changes: 241 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,4 +292,245 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
end
end

"""
pointwise_logdensities(
model::Model,
chain::Chains,
::Val{whichlogprob}=Val(:both),
)

Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where
the log-density of each variable at each sample is stored (rather than its value).

`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or
`:likelihood`.

See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_prior_logdensities`](@ref).

# Examples

```jldoctest pointwise-logdensities-chains; setup=:(using Distributions)
julia> using MCMCChains

julia> @model function demo(xs, y)
s ~ InverseGamma(2, 3)
m ~ Normal(0, √s)
for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end
y ~ Normal(m, √s)
end
demo (generic function with 2 methods)

julia> # Example observations.
model = demo([1.0, 2.0, 3.0], [4.0]);

julia> # A chain with 3 iterations.
chain = Chains(
reshape(1.:6., 3, 2),
[:s, :m];
info=(varname_to_symbol=Dict(
@varname(s) => :s,
@varname(m) => :m,
),),
);

julia> plds = pointwise_logdensities(model, chain)
Chains MCMC chain (3×6×1 Array{Float64, 3}):

Iterations = 1:1:3
Number of chains = 1
Samples per chain = 3
parameters = s, m, xs[1], xs[2], xs[3], y
[...]

julia> plds[:s]
2-dimensional AxisArray{Float64,2,...} with axes:
:iter, 1:1:3
:chain, 1:1
And data, a 3×1 Matrix{Float64}:
-0.8027754226637804
-1.3822169643436162
-2.0986122886681096

julia> # The above is the same as:
logpdf.(InverseGamma(2, 3), chain[:s])
3×1 Matrix{Float64}:
-0.8027754226637804
-1.3822169643436162
-2.0986122886681096
```
"""
function DynamicPPL.pointwise_logdensities(
model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Val{whichlogprob}=Val(:both)
) where {whichlogprob}
vi = DynamicPPL.VarInfo(model)
acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}()
accname = DynamicPPL.accumulator_name(acc)
vi = DynamicPPL.setaccs!!(vi, (acc,))

parameter_only_chain = MCMCChains.get_sections(chain, :parameters)

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
pointwise_logps = map(iters) do (sample_idx, chain_idx)
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Re-evaluate the model
_, vi = DynamicPPL.init!!(
model,
vi,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
)
DynamicPPL.getacc(vi, Val(accname)).logps
end

# pointwise_logps is a matrix of OrderedDicts -- we just need to convert to a Chains
all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
for d in pointwise_logps
union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d)))
end
new_data = [
get(pointwise_logps[iter, chain], k, missing) for
iter in 1:size(pointwise_logps, 1), k in all_keys,
chain in 1:size(pointwise_logps, 2)
]
return MCMCChains.Chains(new_data, Symbol.(collect(all_keys)))
end

"""
pointwise_loglikelihoods(model, chain, ::Val{whichlogprob}=Val(:both))

Compute the pointwise log-likelihoods of the model given the chain. This is the same as
`pointwise_logdensities(model, chain)`, but only including the likelihood terms.

See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref).
"""
function DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model, chain::MCMCChains.Chains
)
return DynamicPPL.pointwise_logdensities(model, chain, Val(:likelihood))
end

"""
pointwise_prior_logdensities(model, chain, ::Val{whichlogprob}=Val(:both))

Compute the pointwise log-prior-densities of the model given the chain. This is the same as
`pointwise_logdensities(model, chain)`, but only including the prior terms.

See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref).
"""
function DynamicPPL.pointwise_prior_logdensities(
model::DynamicPPL.Model, chain::MCMCChains.Chains
)
return DynamicPPL.pointwise_logdensities(model, chain, Val(:prior))
end

"""
logjoint(model::Model, chain::MCMCChains.Chains)

Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> logjoint(demo_model([1., 2.]), chain);
```
"""
function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict{VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.logjoint(model, argvals_dict)
end
end

"""
loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)

Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
n
# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> loglikelihood(demo_model([1., 2.]), chain);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it intentional that this doctest doesn't check the output at all?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Frankly, I'm not sure. I just copy-pasted it. Very happy to change it now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the output! Also changed the rand() to some fixed values to avoid making it too fragile.

```
"""
function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.loglikelihood(model, argvals_dict)
end
end

"""
logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)

Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> logprior(demo_model([1., 2.]), chain);
```
"""
function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict{VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.logprior(model, argvals_dict)
end
end

end
108 changes: 0 additions & 108 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1058,42 +1058,6 @@ function logjoint(model::Model, varinfo::AbstractVarInfo)
return getlogjoint(last(evaluate!!(model, varinfo)))
end

"""
logjoint(model::Model, chain::AbstractMCMC.AbstractChains)

Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> logjoint(demo_model([1., 2.]), chain);
```
"""
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict{VarName,Any}(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
logjoint(model, argvals_dict)
end
end

"""
logprior(model::Model, varinfo::AbstractVarInfo)

Expand All @@ -1116,42 +1080,6 @@ function logprior(model::Model, varinfo::AbstractVarInfo)
return getlogprior(last(evaluate!!(model, varinfo)))
end

"""
logprior(model::Model, chain::AbstractMCMC.AbstractChains)

Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> logprior(demo_model([1., 2.]), chain);
```
"""
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict{VarName,Any}(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
logprior(model, argvals_dict)
end
end

"""
loglikelihood(model::Model, varinfo::AbstractVarInfo)

Expand All @@ -1170,42 +1098,6 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
return getloglikelihood(last(evaluate!!(model, varinfo)))
end

"""
loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)

Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> loglikelihood(demo_model([1., 2.]), chain);
```
"""
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict{VarName,Any}(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
loglikelihood(model, argvals_dict)
end
end

# Implemented & documented in DynamicPPLMCMCChainsExt
function predict end

Expand Down
Loading
Loading