Skip to content

Commit 75b5c51

Browse files
committed
Tidy up tilde-pipeline methods and docstrings
1 parent d7c4033 commit 75b5c51

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

src/context_implementations.jl

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
# assume
1+
"""
2+
tilde_assume!!(context, right::Distribution, vn, vi)
3+
4+
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
5+
accumulate the log probability, and return the sampled value and updated `vi`.
6+
7+
tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi)
8+
9+
Evaluate the submodel with the given context.
10+
"""
211
function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi)
312
return tilde_assume!!(childcontext(context), right, vn, vi)
413
end
@@ -22,24 +31,23 @@ function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi)
2231
new_vn, new_context = prefix_and_strip_contexts(context, vn)
2332
return tilde_assume!!(new_context, right, new_vn, vi)
2433
end
34+
function tilde_assume!!(context::AbstractContext, right::DynamicPPL.Submodel, vn, vi)
35+
return _evaluate!!(right, vi, context, vn)
36+
end
2537

2638
"""
27-
tilde_assume!!(context, right, vn, vi)
39+
tilde_observe!!(context, right::Distribution, left, vn, vi)
2840
29-
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
30-
accumulate the log probability, and return the sampled value and updated `vi`.
31-
"""
32-
function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi)
33-
return _evaluate!!(right, vi, context, vn)
34-
end
41+
Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
42+
accumulate the log probability, and return the observed value and updated `vi`.
3543
36-
# observe
37-
function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
44+
Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about
45+
variable name and indices; if needed, these can be accessed through this function, though.
46+
"""
47+
function tilde_observe!!(context::AbstractContext, right::Distribution, left, vn, vi)
3848
return tilde_observe!!(childcontext(context), right, left, vn, vi)
3949
end
40-
41-
# `PrefixContext`
42-
function tilde_observe!!(context::PrefixContext, right, left, vn, vi)
50+
function tilde_observe!!(context::PrefixContext, right::Distribution, left, vn, vi)
4351
# In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal
4452
# value. For the need for prefix_and_strip_contexts rather than just prefix, see the
4553
# comment in `tilde_assume!!`.
@@ -50,21 +58,10 @@ function tilde_observe!!(context::PrefixContext, right, left, vn, vi)
5058
end
5159
return tilde_observe!!(new_context, right, left, new_vn, vi)
5260
end
53-
54-
"""
55-
tilde_observe!!(context, right, left, vn, vi)
56-
57-
Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
58-
accumulate the log probability, and return the observed value and updated `vi`.
59-
60-
Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name
61-
and indices; if needed, these can be accessed through this function, though.
62-
"""
6361
function tilde_observe!!(::DefaultContext, right::Distribution, left, vn, vi)
6462
vi = accumulate_observe!!(vi, right, left, vn)
6563
return left, vi
6664
end
67-
68-
function tilde_observe!!(::DefaultContext, ::DynamicPPL.Submodel, left, vn, vi)
65+
function tilde_observe!!(::AbstractContext, ::DynamicPPL.Submodel, left, vn, vi)
6966
throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed"))
7067
end

src/contexts/init.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,12 @@ function tilde_assume!!(
191191
return x, vi
192192
end
193193

194-
function tilde_observe!!(::InitContext, right, left, vn, vi)
194+
function tilde_observe!!(
195+
::InitContext,
196+
right::Distribution,
197+
left,
198+
vn::Union{VarName,Nothing},
199+
vi::AbstractVarInfo,
200+
)
195201
return tilde_observe!!(DefaultContext(), right, left, vn, vi)
196202
end

0 commit comments

Comments
 (0)