|
| 1 | +struct LatentHandlingContext{Ctx<:AbstractContext} <: AbstractContext |
| 2 | + context::Ctx |
| 3 | +end |
| 4 | + |
| 5 | +LatentHandlingContext() = LatentHandlingContext(DefaultContext()) |
| 6 | + |
| 7 | +NodeTrait(context::LatentHandlingContext) = IsParent() |
| 8 | +childcontext(context::LatentHandlingContext) = context.context |
| 9 | +function setchildcontext(context::LatentHandlingContext, child::AbstractContext) |
| 10 | + return LatentHandlingContext(child) |
| 11 | +end |
| 12 | + |
| 13 | +""" |
| 14 | + latent(dist) |
| 15 | +
|
| 16 | +Return a distribution for the latent parameters of `dist`. |
| 17 | +""" |
| 18 | +function latent end |
| 19 | + |
| 20 | +""" |
| 21 | + conditional(dist, latents) |
| 22 | +
|
| 23 | +Return the distribution of emissions with the latent parameters of `dist` set to `latents`. |
| 24 | +""" |
| 25 | +function conditional end |
| 26 | + |
| 27 | +""" |
| 28 | + marginalize(dist) |
| 29 | +
|
| 30 | +Return the `dist` with the latent parameters marginalized out. |
| 31 | +""" |
| 32 | +function marginalize end |
| 33 | + |
| 34 | +""" |
| 35 | + has_latents(dist) |
| 36 | +
|
| 37 | +Return `true` if the distribution `dist` has latent parameters, otherwise `false`. |
| 38 | +
|
| 39 | +Note that if `has_latents(dist) = true`, then `dist` is assumed to implement the following methods: |
| 40 | +1. `latent(dist)`: Return the latent parameters of the distribution. |
| 41 | +2. `conditional(dist, latents)`: Return a new distribution with the latent parameters set to `latents`. |
| 42 | +3. `marginalize(dist)`: Return a new distribution with the latent parameters marginalized out. |
| 43 | +""" |
| 44 | +has_latents(dist) = false |
| 45 | + |
| 46 | +# Overload the tilde-statements to handle latent parameters. |
| 47 | +function suffix_varname(vn::VarName{sym}, ::Val{suffix}) where {sym,suffix} |
| 48 | + return VarName{Symbol(sym, ".", suffix)}(vn.optic) |
| 49 | +end |
| 50 | + |
| 51 | +# Cand dispatch on `dist` to choose different suffixes for latent variables. |
| 52 | +suffix_latent_varname(dist, vn) = suffix_varname(vn, Val{:latent}()) |
| 53 | + |
| 54 | +# `tilde_assume` |
| 55 | +function tilde_assume(context::LatentHandlingContext, right, vn, vi) |
| 56 | + has_latents(right) || return tilde_assume(childcontext(context), right, vn, vi) |
| 57 | + # Execute `tilde_assume` for the latent variables first. |
| 58 | + right_latent = latent(right) |
| 59 | + value_latent, logp_marginal, vi = tilde_assume( |
| 60 | + childcontext(context), right_latent, suffix_latent_varname(right, vn), vi |
| 61 | + ) |
| 62 | + # Now execute the conditional on the latent variables. |
| 63 | + right_conditional = conditional(right, value_latent) |
| 64 | + value_conditional, logp_conditional, vi = tilde_assume( |
| 65 | + childcontext(context), right_conditional, vn, vi |
| 66 | + ) |
| 67 | + # Return as usual. |
| 68 | + return value_conditional, logp_marginal + logp_conditional, vi |
| 69 | +end |
| 70 | +function tilde_assume( |
| 71 | + rng::Random.AbstractRNG, context::LatentHandlingContext, sampler, right, vn, vi |
| 72 | +) |
| 73 | + if !has_latents(right) |
| 74 | + return tilde_assume(rng, childcontext(context), sampler, right, vn, vi) |
| 75 | + end |
| 76 | + # Execute `tilde_assume` for the latent variables first. |
| 77 | + right_latent = latent(right) |
| 78 | + value_latent, logp_marginal, vi = tilde_assume( |
| 79 | + rng, |
| 80 | + childcontext(context), |
| 81 | + sampler, |
| 82 | + right_latent, |
| 83 | + suffix_latent_varname(right, vn), |
| 84 | + vi, |
| 85 | + ) |
| 86 | + # Now execute the conditional on the latent variables. |
| 87 | + right_conditional = conditional(right, value_latent) |
| 88 | + value_conditional, logp_conditional, vi = tilde_assume( |
| 89 | + rng, childcontext(context), sampler, right_conditional, vn, vi |
| 90 | + ) |
| 91 | + # Return as usual. |
| 92 | + return value_conditional, logp_marginal + logp_conditional, vi |
| 93 | +end |
| 94 | +# `tilde_observe` |
| 95 | +function tilde_observe(context::LatentHandlingContext, right, left, vi) |
| 96 | + has_latents(right) || return tilde_observe(childcontext(context), right, left, vi) |
| 97 | + # When used as `observe`, we want to use the marginalized version. |
| 98 | + right_marginal = marginalize(right) |
| 99 | + return tilde_observe(childcontext(context), right_marginal, left, vi) |
| 100 | +end |
| 101 | +function tilde_observe(context::LatentHandlingContext, sampler, right, left, vi) |
| 102 | + if !has_latents(right) |
| 103 | + return tilde_observe(childcontext(context), sampler, right, left, vi) |
| 104 | + end |
| 105 | + # When used as `observe`, we want to use the marginalized version. |
| 106 | + right_marginal = marginalize(right) |
| 107 | + return tilde_observe(childcontext(context), sampler, right_marginal, left, vi) |
| 108 | +end |
0 commit comments