|
1 | 1 | # assume
|
2 |
| -function tilde_assume(context::AbstractContext, args...) |
3 |
| - return tilde_assume(childcontext(context), args...) |
| 2 | +function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi) |
| 3 | + return tilde_assume!!(childcontext(context), right, vn, vi) |
4 | 4 | end
|
5 |
| -function tilde_assume(::DefaultContext, right, vn, vi) |
| 5 | +function tilde_assume!!(::DefaultContext, right::Distribution, vn, vi) |
6 | 6 | y = getindex_internal(vi, vn)
|
7 | 7 | f = from_maybe_linked_internal_transform(vi, vn, right)
|
8 | 8 | x, logjac = with_logabsdet_jacobian(f, y)
|
9 | 9 | vi = accumulate_assume!!(vi, x, logjac, vn, right)
|
10 | 10 | return x, vi
|
11 | 11 | end
|
12 |
| -function tilde_assume(context::PrefixContext, right, vn, vi) |
| 12 | +function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi) |
13 | 13 | # Note that we can't use something like this here:
|
14 | 14 | # new_vn = prefix(context, vn)
|
15 |
| - # return tilde_assume(childcontext(context), right, new_vn, vi) |
| 15 | + # return tilde_assume!!(childcontext(context), right, new_vn, vi) |
16 | 16 | # This is because `prefix` applies _all_ prefixes in a given context to a
|
17 | 17 | # variable name. Thus, if we had two levels of nested prefixes e.g.
|
18 | 18 | # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the
|
19 | 19 | # first call would apply the prefix `a.b._`, and the recursive call
|
20 | 20 | # would apply the prefix `b._`, resulting in `b.a.b._`.
|
21 | 21 | # This is why we need a special function, `prefix_and_strip_contexts`.
|
22 | 22 | new_vn, new_context = prefix_and_strip_contexts(context, vn)
|
23 |
| - return tilde_assume(new_context, right, new_vn, vi) |
| 23 | + return tilde_assume!!(new_context, right, new_vn, vi) |
24 | 24 | end
|
25 | 25 |
|
26 | 26 | """
|
27 | 27 | tilde_assume!!(context, right, vn, vi)
|
28 | 28 |
|
29 | 29 | Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
|
30 | 30 | accumulate the log probability, and return the sampled value and updated `vi`.
|
31 |
| -
|
32 |
| -By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log |
33 |
| -probability of `vi` with the returned value. |
34 | 31 | """
|
35 |
| -function tilde_assume!!(context, right, vn, vi) |
36 |
| - return if right isa DynamicPPL.Submodel |
37 |
| - _evaluate!!(right, vi, context, vn) |
38 |
| - else |
39 |
| - tilde_assume(context, right, vn, vi) |
40 |
| - end |
| 32 | +function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi) |
| 33 | + return _evaluate!!(right, vi, context, vn) |
41 | 34 | end
|
42 | 35 |
|
43 | 36 | # observe
|
|
0 commit comments