|
# Context version
"""
    ElementwiseLikelihoodContext{A, Ctx}

A context that records the log-likelihood of each *observe* statement
individually, keyed by the observation's `VarName`, while delegating the
actual tilde evaluation to a wrapped child context.
"""
struct ElementwiseLikelihoodContext{A, Ctx} <: AbstractContext
    # Per-variable accumulator; in this file either `Dict{VarName, Vector{Float64}}`
    # (one value per executed sample) or `Dict{VarName, Float64}` (single run).
    loglikelihoods::A
    # The wrapped context that assume/observe statements are forwarded to.
    ctx::Ctx
end
| 6 | + |
"""
    ElementwiseLikelihoodContext([likelihoods, ctx])

Construct an `ElementwiseLikelihoodContext`, defaulting to an empty
`Dict{VarName, Vector{Float64}}` accumulator wrapping a `LikelihoodContext`.
"""
function ElementwiseLikelihoodContext(
    likelihoods = Dict{VarName, Vector{Float64}}(),
    ctx::AbstractContext = LikelihoodContext()
)
    A, C = typeof(likelihoods), typeof(ctx)
    return ElementwiseLikelihoodContext{A, C}(likelihoods, ctx)
end
| 13 | + |
"""
    push!(ctx::ElementwiseLikelihoodContext{Dict{VarName, Vector{Float64}}}, vn, logp)

Append the log-likelihood `logp` observed for variable `vn` to its running
vector of per-sample values, creating the vector on first use. Returns `ctx`,
following the `push!` convention of returning the mutated collection.
"""
function Base.push!(
    ctx::ElementwiseLikelihoodContext{Dict{VarName, Vector{Float64}}},
    vn::VarName,
    logp::Real
)
    # The thunk form of `get!` avoids allocating the default `Float64[]`
    # when `vn` is already present in the dictionary.
    ℓ = get!(() -> Float64[], ctx.loglikelihoods, vn)
    push!(ℓ, logp)
    return ctx
end
| 23 | + |
"""
    push!(ctx::ElementwiseLikelihoodContext{Dict{VarName, Float64}}, vn, logp)

Store the single log-likelihood value `logp` for variable `vn`, overwriting
any previous value. Returns `ctx`, following the `push!` convention of
returning the mutated collection (the original implicitly returned `logp`).
"""
function Base.push!(
    ctx::ElementwiseLikelihoodContext{Dict{VarName, Float64}},
    vn::VarName,
    logp::Real
)
    ctx.loglikelihoods[vn] = logp
    return ctx
end
| 31 | + |
| 32 | + |
# Assume statements are not tracked per-element; simply delegate to the
# wrapped context.
tilde_assume(rng, ctx::ElementwiseLikelihoodContext, sampler, right, vn, inds, vi) =
    tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi)
| 36 | + |
# Dotted assume statements are not tracked per-element: evaluate through the
# wrapped context, accumulate the log density into `vi`, and return the value.
function dot_tilde_assume(rng, ctx::ElementwiseLikelihoodContext, sampler, right, left, vn, inds, vi)
    sampled, assume_logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi)
    acclogp!(vi, assume_logp)
    return sampled
end
| 42 | + |
| 43 | + |
function tilde_observe(ctx::ElementwiseLikelihoodContext, sampler, right, left, vname, vinds, vi)
    # This is slightly unfortunate since it is not completely generic...
    # Ideally we would call `tilde_observe` recursively, but then the
    # individual loglikelihood value would not be available to us.
    obs_logp = tilde(ctx.ctx, sampler, right, left, vi)
    acclogp!(vi, obs_logp)

    # Record this observation's contribution under its variable name.
    push!(ctx, vname, obs_logp)

    return left
end
| 56 | + |
| 57 | + |
"""
    elementwise_loglikelihoods(model::Model, chain::Chains)

Runs `model` on each sample in `chain`, returning a `Dict` mapping each
observed variable to a vector of log-likelihood values, whose i-th entry is
the likelihood of that observation under the i-th sample in `chain`.

# Notes
Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an
*observation*) statements can be implemented in two ways:
```julia
for i in eachindex(y)
    y[i] ~ Normal(μ, σ)
end
```
or
```julia
y ~ MvNormal(fill(μ, n), fill(σ, n))
```
Unfortunately, just by looking at the latter statement, it's impossible to tell whether or
not this is one *single* observation which is `n` dimensional OR if we have *multiple*
1-dimensional observations. Therefore, `elementwise_loglikelihoods` will only work with
the first example.

# Examples
```julia-repl
julia> using DynamicPPL, Turing

julia> @model function demo(xs, y)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, √s)
           for i in eachindex(xs)
               xs[i] ~ Normal(m, √s)
           end

           y ~ Normal(m, √s)
       end
demo (generic function with 1 method)

julia> model = demo(randn(3), randn());

julia> chain = sample(model, MH(), 10);

julia> DynamicPPL.elementwise_loglikelihoods(model, chain)
Dict{VarName,Array{Float64,1}} with 4 entries:
  xs[3] => [-1.02616, -1.26931, -1.05003, -5.05458, -1.33825, -1.02904, -1.23761, -1.30128, -1.04872, -2.03716]
  xs[1] => [-2.08205, -2.51387, -3.03175, -2.5981, -2.31322, -2.62284, -2.70874, -1.18617, -1.36281, -4.39839]
  xs[2] => [-2.20604, -2.63495, -3.22802, -2.48785, -2.40941, -2.78791, -2.85013, -1.24081, -1.46019, -4.59025]
  y => [-1.36627, -1.21964, -1.03342, -7.46617, -1.3234, -1.14536, -1.14781, -2.48912, -2.23705, -1.26267]
```
"""
function elementwise_loglikelihoods(model::Model, chain)
    # Context that records every observe statement's loglikelihood; values are
    # appended across executions, so each variable ends up with one value per
    # sample in `chain`.
    ctx = ElementwiseLikelihoodContext()
    spl = SampleFromPrior()

    # Get the data by executing the model once.
    vi = VarInfo(model)

    # Visit every (sample, chain) pair in the `Chains` object.
    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
    for (sample_idx, chain_idx) in iters
        # Update the variable values in `vi` from the current sample.
        setval!(vi, chain, sample_idx, chain_idx)

        # Execute the model, accumulating loglikelihoods into `ctx`.
        model(vi, spl, ctx)
    end
    return ctx.loglikelihoods
end
| 126 | + |
# Single-evaluation variant: run the model once on `varinfo` and return a
# `Dict{VarName, Float64}` holding one loglikelihood value per observation.
function elementwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
    context = ElementwiseLikelihoodContext(Dict{VarName, Float64}())
    model(varinfo, SampleFromPrior(), context)
    return context.loglikelihoods
end
0 commit comments