1
- # assume
2
- function tilde_assume!! (context:: AbstractContext , right:: Distribution , vn, vi)
1
+ """
2
+ DynamicPPL.tilde_assume!!(
3
+ context::AbstractContext,
4
+ right::Distribution,
5
+ vn::VarName,
6
+ vi::AbstractVarInfo
7
+ )
8
+
9
+ Handle assumed variables, i.e. anything which is not observed (see
10
+ [`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the
11
+ sampled value and updated `vi`.
12
+
13
+ `vn` is the VarName on the left-hand side of the tilde statement.
14
+ """
15
+ function tilde_assume!! (
16
+ context:: AbstractContext , right:: Distribution , vn:: VarName , vi:: AbstractVarInfo
17
+ )
3
18
return tilde_assume!! (childcontext (context), right, vn, vi)
4
19
end
5
- function tilde_assume!! (:: DefaultContext , right:: Distribution , vn, vi)
20
+ function tilde_assume!! (
21
+ :: DefaultContext , right:: Distribution , vn:: VarName , vi:: AbstractVarInfo
22
+ )
6
23
y = getindex_internal (vi, vn)
7
24
f = from_maybe_linked_internal_transform (vi, vn, right)
8
25
x, inv_logjac = with_logabsdet_jacobian (f, y)
9
26
vi = accumulate_assume!! (vi, x, - inv_logjac, vn, right)
10
27
return x, vi
11
28
end
12
- function tilde_assume!! (context:: PrefixContext , right:: Distribution , vn, vi)
29
+ function tilde_assume!! (
30
+ context:: PrefixContext , right:: Distribution , vn:: VarName , vi:: AbstractVarInfo
31
+ )
13
32
# Note that we can't use something like this here:
14
33
# new_vn = prefix(context, vn)
15
34
# return tilde_assume!!(childcontext(context), right, new_vn, vi)
@@ -22,24 +41,62 @@ function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi)
22
41
new_vn, new_context = prefix_and_strip_contexts (context, vn)
23
42
return tilde_assume!! (new_context, right, new_vn, vi)
24
43
end
25
-
26
44
"""
27
- tilde_assume!!(context, right, vn, vi)
45
+ DynamicPPL.tilde_assume!!(
46
+ context::AbstractContext,
47
+ right::DynamicPPL.Submodel,
48
+ vn::VarName,
49
+ vi::AbstractVarInfo
50
+ )
28
51
29
- Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
30
- accumulate the log probability, and return the sampled value and updated `vi`.
52
+ Evaluate the submodel with the given context.
31
53
"""
32
- function tilde_assume!! (context, right:: DynamicPPL.Submodel , vn, vi)
54
+ function tilde_assume!! (
55
+ context:: AbstractContext , right:: DynamicPPL.Submodel , vn:: VarName , vi:: AbstractVarInfo
56
+ )
33
57
return _evaluate!! (right, vi, context, vn)
34
58
end
35
59
36
- # observe
37
- function tilde_observe!! (context:: AbstractContext , right, left, vn, vi)
60
+ """
61
+ tilde_observe!!(
62
+ context::AbstractContext,
63
+ right::Distribution,
64
+ left,
65
+ vn::Union{VarName, Nothing},
66
+ vi::AbstractVarInfo
67
+ )
68
+
69
+ This function handles observed variables, which may be:
70
+
71
+ - literals on the left-hand side, e.g., `3.0 ~ Normal()`
72
+ - a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end`
73
+ - a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`.
74
+
75
+ The relevant log-probability associated with the observation is computed and accumulated in
76
+ the VarInfo object `vi` (except for fixed variables, which do not contribute to the
77
+ log-probability).
78
+
79
+ `left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the
80
+ left-hand side, or `nothing` if the left-hand side is a literal value.
81
+
82
+ Observations of submodels are not yet supported in DynamicPPL.
83
+ """
84
+ function tilde_observe!! (
85
+ context:: AbstractContext ,
86
+ right:: Distribution ,
87
+ left,
88
+ vn:: Union{VarName,Nothing} ,
89
+ vi:: AbstractVarInfo ,
90
+ )
38
91
return tilde_observe!! (childcontext (context), right, left, vn, vi)
39
92
end
40
-
41
- # `PrefixContext`
42
- function tilde_observe!! (context:: PrefixContext , right, left, vn, vi)
93
+ function tilde_observe!! (
94
+ context:: PrefixContext ,
95
+ right:: Distribution ,
96
+ left,
97
+ vn:: Union{VarName,Nothing} ,
98
+ vi:: AbstractVarInfo ,
99
+ )
43
100
# In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal
44
101
# value. For the need for prefix_and_strip_contexts rather than just prefix, see the
45
102
# comment in `tilde_assume!!`.
@@ -50,21 +107,22 @@ function tilde_observe!!(context::PrefixContext, right, left, vn, vi)
50
107
end
51
108
return tilde_observe!! (new_context, right, left, new_vn, vi)
52
109
end
53
-
54
- """
55
- tilde_observe!!(context, right, left, vn, vi)
56
-
57
- Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
58
- accumulate the log probability, and return the observed value and updated `vi`.
59
-
60
- Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name
61
- and indices; if needed, these can be accessed through this function, though.
62
- """
63
- function tilde_observe!! (:: DefaultContext , right:: Distribution , left, vn, vi)
110
+ function tilde_observe!! (
111
+ :: DefaultContext ,
112
+ right:: Distribution ,
113
+ left,
114
+ vn:: Union{VarName,Nothing} ,
115
+ vi:: AbstractVarInfo ,
116
+ )
64
117
vi = accumulate_observe!! (vi, right, left, vn)
65
118
return left, vi
66
119
end
67
-
68
- function tilde_observe!! (:: DefaultContext , :: DynamicPPL.Submodel , left, vn, vi)
120
+ function tilde_observe!! (
121
+ :: AbstractContext ,
122
+ :: DynamicPPL.Submodel ,
123
+ left,
124
+ vn:: Union{VarName,Nothing} ,
125
+ :: AbstractVarInfo ,
126
+ )
69
127
throw (ArgumentError (" `x ~ to_submodel(...)` is not supported when `x` is observed" ))
70
128
end
0 commit comments