-
Notifications
You must be signed in to change notification settings - Fork 36
Fix DynamicPPL / MCMCChains methods #1076
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
7e37432
4002b08
b444419
e1489ba
2e69e0b
fc393bc
f2a83b3
0dec616
f41106e
0b1d6a6
94432eb
57e099a
802e38f
7688065
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -292,4 +292,245 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha | |
end | ||
end | ||
|
||
""" | ||
pointwise_logdensities( | ||
model::Model, | ||
chain::Chains, | ||
::Val{whichlogprob}=Val(:both), | ||
) | ||
|
||
Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where | ||
the log-density of each variable at each sample is stored (rather than its value). | ||
|
||
`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or | ||
`:likelihood`. | ||
|
||
See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_prior_logdensities`](@ref). | ||
|
||
# Examples | ||
|
||
```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) | ||
julia> using MCMCChains | ||
|
||
julia> @model function demo(xs, y) | ||
s ~ InverseGamma(2, 3) | ||
m ~ Normal(0, √s) | ||
for i in eachindex(xs) | ||
xs[i] ~ Normal(m, √s) | ||
end | ||
y ~ Normal(m, √s) | ||
end | ||
demo (generic function with 2 methods) | ||
|
||
julia> # Example observations. | ||
model = demo([1.0, 2.0, 3.0], [4.0]); | ||
|
||
julia> # A chain with 3 iterations. | ||
chain = Chains( | ||
reshape(1.:6., 3, 2), | ||
[:s, :m]; | ||
info=(varname_to_symbol=Dict( | ||
@varname(s) => :s, | ||
@varname(m) => :m, | ||
),), | ||
); | ||
|
||
julia> plds = pointwise_logdensities(model, chain) | ||
Chains MCMC chain (3×6×1 Array{Float64, 3}): | ||
|
||
Iterations = 1:1:3 | ||
Number of chains = 1 | ||
Samples per chain = 3 | ||
parameters = s, m, xs[1], xs[2], xs[3], y | ||
[...] | ||
|
||
julia> plds[:s] | ||
2-dimensional AxisArray{Float64,2,...} with axes: | ||
:iter, 1:1:3 | ||
:chain, 1:1 | ||
And data, a 3×1 Matrix{Float64}: | ||
-0.8027754226637804 | ||
-1.3822169643436162 | ||
-2.0986122886681096 | ||
|
||
julia> # The above is the same as: | ||
logpdf.(InverseGamma(2, 3), chain[:s]) | ||
3×1 Matrix{Float64}: | ||
-0.8027754226637804 | ||
-1.3822169643436162 | ||
-2.0986122886681096 | ||
``` | ||
""" | ||
function DynamicPPL.pointwise_logdensities( | ||
model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Val{whichlogprob}=Val(:both) | ||
) where {whichlogprob} | ||
vi = DynamicPPL.VarInfo(model) | ||
acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() | ||
accname = DynamicPPL.accumulator_name(acc) | ||
vi = DynamicPPL.setaccs!!(vi, (acc,)) | ||
|
||
parameter_only_chain = MCMCChains.get_sections(chain, :parameters) | ||
|
||
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) | ||
pointwise_logps = map(iters) do (sample_idx, chain_idx) | ||
# Extract values from the chain | ||
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) | ||
# Re-evaluate the model | ||
_, vi = DynamicPPL.init!!( | ||
model, | ||
vi, | ||
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), | ||
) | ||
DynamicPPL.getacc(vi, Val(accname)).logps | ||
end | ||
|
||
# pointwise_logps is a matrix of OrderedDicts -- we just need to convert to a Chains | ||
all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() | ||
for d in pointwise_logps | ||
union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d))) | ||
end | ||
new_data = [ | ||
get(pointwise_logps[iter, chain], k, missing) for | ||
iter in 1:size(pointwise_logps, 1), k in all_keys, | ||
chain in 1:size(pointwise_logps, 2) | ||
] | ||
return MCMCChains.Chains(new_data, Symbol.(collect(all_keys))) | ||
end | ||
|
||
""" | ||
pointwise_loglikelihoods(model, chain, ::Val{whichlogprob}=Val(:both)) | ||
|
||
Compute the pointwise log-likelihoods of the model given the chain. This is the same as | ||
`pointwise_logdensities(model, chain)`, but only including the likelihood terms. | ||
|
||
See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). | ||
""" | ||
function DynamicPPL.pointwise_loglikelihoods( | ||
model::DynamicPPL.Model, chain::MCMCChains.Chains | ||
) | ||
return DynamicPPL.pointwise_logdensities(model, chain, Val(:likelihood)) | ||
end | ||
|
||
""" | ||
pointwise_prior_logdensities(model, chain, ::Val{whichlogprob}=Val(:both)) | ||
|
||
Compute the pointwise log-prior-densities of the model given the chain. This is the same as | ||
`pointwise_logdensities(model, chain)`, but only including the prior terms. | ||
|
||
See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). | ||
""" | ||
function DynamicPPL.pointwise_prior_logdensities( | ||
model::DynamicPPL.Model, chain::MCMCChains.Chains | ||
) | ||
return DynamicPPL.pointwise_logdensities(model, chain, Val(:prior)) | ||
end | ||
|
||
""" | ||
logjoint(model::Model, chain::MCMCChains.Chains) | ||
|
||
Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`. | ||
|
||
# Examples | ||
|
||
```jldoctest | ||
julia> using MCMCChains, Distributions | ||
|
||
julia> @model function demo_model(x) | ||
s ~ InverseGamma(2, 3) | ||
m ~ Normal(0, sqrt(s)) | ||
for i in eachindex(x) | ||
x[i] ~ Normal(m, sqrt(s)) | ||
end | ||
end; | ||
|
||
julia> # construct a chain of samples using MCMCChains | ||
chain = Chains(rand(10, 2, 3), [:s, :m]); | ||
|
||
julia> logjoint(demo_model([1., 2.]), chain); | ||
``` | ||
""" | ||
function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) | ||
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model | ||
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) | ||
argvals_dict = OrderedDict{VarName,Any}( | ||
vn_parent => DynamicPPL.values_from_chain( | ||
var_info, vn_parent, chain, chain_idx, iteration_idx | ||
) for vn_parent in keys(var_info) | ||
) | ||
DynamicPPL.logjoint(model, argvals_dict) | ||
end | ||
end | ||
|
||
""" | ||
loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) | ||
|
||
Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. | ||
n | ||
# Examples | ||
|
||
```jldoctest | ||
julia> using MCMCChains, Distributions | ||
|
||
julia> @model function demo_model(x) | ||
s ~ InverseGamma(2, 3) | ||
m ~ Normal(0, sqrt(s)) | ||
for i in eachindex(x) | ||
x[i] ~ Normal(m, sqrt(s)) | ||
end | ||
end; | ||
|
||
julia> # construct a chain of samples using MCMCChains | ||
chain = Chains(rand(10, 2, 3), [:s, :m]); | ||
|
||
julia> loglikelihood(demo_model([1., 2.]), chain); | ||
|
||
``` | ||
""" | ||
function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) | ||
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model | ||
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) | ||
argvals_dict = OrderedDict{DynamicPPL.VarName,Any}( | ||
vn_parent => DynamicPPL.values_from_chain( | ||
var_info, vn_parent, chain, chain_idx, iteration_idx | ||
) for vn_parent in keys(var_info) | ||
) | ||
DynamicPPL.loglikelihood(model, argvals_dict) | ||
end | ||
end | ||
|
||
""" | ||
logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) | ||
|
||
Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`. | ||
|
||
# Examples | ||
|
||
```jldoctest | ||
julia> using MCMCChains, Distributions | ||
|
||
julia> @model function demo_model(x) | ||
s ~ InverseGamma(2, 3) | ||
m ~ Normal(0, sqrt(s)) | ||
for i in eachindex(x) | ||
x[i] ~ Normal(m, sqrt(s)) | ||
end | ||
end; | ||
|
||
julia> # construct a chain of samples using MCMCChains | ||
chain = Chains(rand(10, 2, 3), [:s, :m]); | ||
|
||
julia> logprior(demo_model([1., 2.]), chain); | ||
``` | ||
""" | ||
function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) | ||
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model | ||
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) | ||
argvals_dict = OrderedDict{VarName,Any}( | ||
vn_parent => DynamicPPL.values_from_chain( | ||
var_info, vn_parent, chain, chain_idx, iteration_idx | ||
) for vn_parent in keys(var_info) | ||
) | ||
DynamicPPL.logprior(model, argvals_dict) | ||
end | ||
end | ||
|
||
end |
Uh oh!
There was an error while loading. Please reload this page.