Skip to content

Commit 78877b4

Browse files
committed
initial work on handling latent parameters
1 parent 0639681 commit 78877b4

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ include("logdensityfunction.jl")
181181
include("model_utils.jl")
182182
include("extract_priors.jl")
183183
include("values_as_in_model.jl")
184+
include("latent_handling.jl")
184185

185186
include("debug_utils.jl")
186187
using .DebugUtils

src/latent_handling.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
struct LatentHandlingContext{Ctx<:AbstractContext} <: AbstractContext
2+
context::Ctx
3+
end
4+
5+
LatentHandlingContext() = LatentHandlingContext(DefaultContext())
6+
7+
NodeTrait(context::LatentHandlingContext) = IsParent()
8+
childcontext(context::LatentHandlingContext) = context.context
9+
function setchildcontext(context::LatentHandlingContext, child::AbstractContext)
10+
return LatentHandlingContext(child)
11+
end
12+
13+
"""
14+
latent(dist)
15+
16+
Return a distribution for the latent parameters of `dist`.
17+
"""
18+
function latent end
19+
20+
"""
21+
conditional(dist, latents)
22+
23+
Return the distribution of emissions with the latent parameters of `dist` set to `latents`.
24+
"""
25+
function conditional end
26+
27+
"""
28+
marginalize(dist)
29+
30+
Return the `dist` with the latent parameters marginalized out.
31+
"""
32+
function marginalize end
33+
34+
"""
35+
has_latents(dist)
36+
37+
Return `true` if the distribution `dist` has latent parameters, otherwise `false`.
38+
39+
Note that if `has_latents(dist) = true`, then `dist` is assumed to implement the following methods:
40+
1. `latent(dist)`: Return the latent parameters of the distribution.
41+
2. `conditional(dist, latents)`: Return a new distribution with the latent parameters set to `latents`.
42+
3. `marginalize(dist)`: Return a new distribution with the latent parameters marginalized out.
43+
"""
44+
has_latents(dist) = false
45+
46+
# Overload the tilde-statements to handle latent parameters.
47+
function suffix_varname(vn::VarName{sym}, ::Val{suffix}) where {sym,suffix}
48+
return VarName{Symbol(sym, ".", suffix)}(vn.optic)
49+
end
50+
51+
# Cand dispatch on `dist` to choose different suffixes for latent variables.
52+
suffix_latent_varname(dist, vn) = suffix_varname(vn, Val{:latent}())
53+
54+
# `tilde_assume`
55+
function tilde_assume(context::LatentHandlingContext, right, vn, vi)
56+
has_latents(right) || return tilde_assume(childcontext(context), right, vn, vi)
57+
# Execute `tilde_assume` for the latent variables first.
58+
right_latent = latent(right)
59+
value_latent, logp_marginal, vi = tilde_assume(
60+
childcontext(context), right_latent, suffix_latent_varname(right, vn), vi
61+
)
62+
# Now execute the conditional on the latent variables.
63+
right_conditional = conditional(right, value_latent)
64+
value_conditional, logp_conditional, vi = tilde_assume(
65+
childcontext(context), right_conditional, vn, vi
66+
)
67+
# Return as usual.
68+
return value_conditional, logp_marginal + logp_conditional, vi
69+
end
70+
function tilde_assume(
71+
rng::Random.AbstractRNG, context::LatentHandlingContext, sampler, right, vn, vi
72+
)
73+
if !has_latents(right)
74+
return tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
75+
end
76+
# Execute `tilde_assume` for the latent variables first.
77+
right_latent = latent(right)
78+
value_latent, logp_marginal, vi = tilde_assume(
79+
rng,
80+
childcontext(context),
81+
sampler,
82+
right_latent,
83+
suffix_latent_varname(right, vn),
84+
vi,
85+
)
86+
# Now execute the conditional on the latent variables.
87+
right_conditional = conditional(right, value_latent)
88+
value_conditional, logp_conditional, vi = tilde_assume(
89+
rng, childcontext(context), sampler, right_conditional, vn, vi
90+
)
91+
# Return as usual.
92+
return value_conditional, logp_marginal + logp_conditional, vi
93+
end
94+
# `tilde_observe`
95+
function tilde_observe(context::LatentHandlingContext, right, left, vi)
96+
has_latents(right) || return tilde_observe(childcontext(context), right, left, vi)
97+
# When used as `observe`, we want to use the marginalized version.
98+
right_marginal = marginalize(right)
99+
return tilde_observe(childcontext(context), right_marginal, left, vi)
100+
end
101+
function tilde_observe(context::LatentHandlingContext, sampler, right, left, vi)
102+
if !has_latents(right)
103+
return tilde_observe(childcontext(context), sampler, right, left, vi)
104+
end
105+
# When used as `observe`, we want to use the marginalized version.
106+
right_marginal = marginalize(right)
107+
return tilde_observe(childcontext(context), sampler, right_marginal, left, vi)
108+
end

0 commit comments

Comments
 (0)