Skip to content

Commit 6c776e9

Browse files
committed
Tidy up tilde-pipeline methods and docstrings
1 parent d7c4033 commit 6c776e9

File tree

2 files changed

+92
-28
lines changed

2 files changed

+92
-28
lines changed

src/context_implementations.jl

Lines changed: 85 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,34 @@
1-
# assume
2-
function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi)
1+
"""
2+
DynamicPPL.tilde_assume!!(
3+
context::AbstractContext,
4+
right::Distribution,
5+
vn::VarName,
6+
vi::AbstractVarInfo
7+
)
8+
9+
Handle assumed variables, i.e. anything which is not observed (see
10+
[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the
11+
sampled value and updated `vi`.
12+
13+
`vn` is the VarName on the left-hand side of the tilde statement.
14+
"""
15+
function tilde_assume!!(
16+
context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
17+
)
318
return tilde_assume!!(childcontext(context), right, vn, vi)
419
end
5-
function tilde_assume!!(::DefaultContext, right::Distribution, vn, vi)
20+
function tilde_assume!!(
21+
::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
22+
)
623
y = getindex_internal(vi, vn)
724
f = from_maybe_linked_internal_transform(vi, vn, right)
825
x, inv_logjac = with_logabsdet_jacobian(f, y)
926
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right)
1027
return x, vi
1128
end
12-
function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi)
29+
function tilde_assume!!(
30+
context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
31+
)
1332
# Note that we can't use something like this here:
1433
# new_vn = prefix(context, vn)
1534
# return tilde_assume!!(childcontext(context), right, new_vn, vi)
@@ -22,24 +41,62 @@ function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi)
2241
new_vn, new_context = prefix_and_strip_contexts(context, vn)
2342
return tilde_assume!!(new_context, right, new_vn, vi)
2443
end
25-
2644
"""
27-
tilde_assume!!(context, right, vn, vi)
45+
DynamicPPL.tilde_assume!!(
46+
context::AbstractContext,
47+
right::DynamicPPL.Submodel,
48+
vn::VarName,
49+
vi::AbstractVarInfo
50+
)
2851
29-
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
30-
accumulate the log probability, and return the sampled value and updated `vi`.
52+
Evaluate the submodel with the given context.
3153
"""
32-
function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi)
54+
function tilde_assume!!(
55+
context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo
56+
)
3357
return _evaluate!!(right, vi, context, vn)
3458
end
3559

36-
# observe
37-
function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
60+
"""
61+
tilde_observe!!(
62+
context::AbstractContext,
63+
right::Distribution,
64+
left,
65+
vn::Union{VarName, Nothing},
66+
vi::AbstractVarInfo
67+
)
68+
69+
This function handles observed variables, which may be:
70+
71+
- literals on the left-hand side, e.g., `3.0 ~ Normal()`
72+
- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end`
73+
- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`.
74+
75+
The relevant log-probability associated with the observation is computed and accumulated in
76+
the VarInfo object `vi` (except for fixed variables, which do not contribute to the
77+
log-probability).
78+
79+
`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the
80+
left-hand side, or `nothing` if the left-hand side is a literal value.
81+
82+
Observations of submodels are not yet supported in DynamicPPL.
83+
"""
84+
function tilde_observe!!(
85+
context::AbstractContext,
86+
right::Distribution,
87+
left,
88+
vn::Union{VarName,Nothing},
89+
vi::AbstractVarInfo,
90+
)
3891
return tilde_observe!!(childcontext(context), right, left, vn, vi)
3992
end
40-
41-
# `PrefixContext`
42-
function tilde_observe!!(context::PrefixContext, right, left, vn, vi)
93+
function tilde_observe!!(
94+
context::PrefixContext,
95+
right::Distribution,
96+
left,
97+
vn::Union{VarName,Nothing},
98+
vi::AbstractVarInfo,
99+
)
43100
# In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal
44101
# value. For the need for prefix_and_strip_contexts rather than just prefix, see the
45102
# comment in `tilde_assume!!`.
@@ -50,21 +107,22 @@ function tilde_observe!!(context::PrefixContext, right, left, vn, vi)
50107
end
51108
return tilde_observe!!(new_context, right, left, new_vn, vi)
52109
end
53-
54-
"""
55-
tilde_observe!!(context, right, left, vn, vi)
56-
57-
Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
58-
accumulate the log probability, and return the observed value and updated `vi`.
59-
60-
Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name
61-
and indices; if needed, these can be accessed through this function, though.
62-
"""
63-
function tilde_observe!!(::DefaultContext, right::Distribution, left, vn, vi)
110+
function tilde_observe!!(
111+
::DefaultContext,
112+
right::Distribution,
113+
left,
114+
vn::Union{VarName,Nothing},
115+
vi::AbstractVarInfo,
116+
)
64117
vi = accumulate_observe!!(vi, right, left, vn)
65118
return left, vi
66119
end
67-
68-
function tilde_observe!!(::DefaultContext, ::DynamicPPL.Submodel, left, vn, vi)
120+
function tilde_observe!!(
121+
::AbstractContext,
122+
::DynamicPPL.Submodel,
123+
left,
124+
vn::Union{VarName,Nothing},
125+
::AbstractVarInfo,
126+
)
69127
throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed"))
70128
end

src/contexts/init.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,12 @@ function tilde_assume!!(
191191
return x, vi
192192
end
193193

194-
function tilde_observe!!(::InitContext, right, left, vn, vi)
194+
function tilde_observe!!(
195+
::InitContext,
196+
right::Distribution,
197+
left,
198+
vn::Union{VarName,Nothing},
199+
vi::AbstractVarInfo,
200+
)
195201
return tilde_observe!!(DefaultContext(), right, left, vn, vi)
196202
end

0 commit comments

Comments
 (0)