1
1
# assume
2
- """
3
- tilde_assume(context::SamplingContext, right, vn, vi)
4
-
5
- Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6
- accumulate the log probability, and return the sampled value with a context associated
7
- with a sampler.
8
-
9
- Falls back to
10
- ```julia
11
- tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12
- ```
13
- """
14
- function tilde_assume (context:: SamplingContext , right, vn, vi)
15
- return tilde_assume (context. rng, context. context, context. sampler, right, vn, vi)
16
- end
17
-
18
2
function tilde_assume (context:: AbstractContext , args... )
19
3
return tilde_assume (childcontext (context), args... )
20
4
end
21
5
function tilde_assume (:: DefaultContext , right, vn, vi)
22
- return assume (right, vn, vi)
23
- end
24
-
25
- function tilde_assume (rng:: Random.AbstractRNG , context:: AbstractContext , args... )
26
- return tilde_assume (rng, childcontext (context), args... )
27
- end
28
- function tilde_assume (rng:: Random.AbstractRNG , :: DefaultContext , sampler, right, vn, vi)
29
- return assume (rng, sampler, right, vn, vi)
30
- end
31
- function tilde_assume (:: DefaultContext , sampler, right, vn, vi)
32
- # same as above but no rng
33
- return assume (Random. default_rng (), sampler, right, vn, vi)
6
+ y = getindex_internal (vi, vn)
7
+ f = from_maybe_linked_internal_transform (vi, vn, right)
8
+ x, inv_logjac = with_logabsdet_jacobian (f, y)
9
+ vi = accumulate_assume!! (vi, x, - inv_logjac, vn, right)
10
+ return x, vi
34
11
end
35
-
36
12
function tilde_assume (context:: PrefixContext , right, vn, vi)
37
13
# Note that we can't use something like this here:
38
14
# new_vn = prefix(context, vn)
@@ -46,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
46
22
new_vn, new_context = prefix_and_strip_contexts (context, vn)
47
23
return tilde_assume (new_context, right, new_vn, vi)
48
24
end
49
- function tilde_assume (
50
- rng:: Random.AbstractRNG , context:: PrefixContext , sampler, right, vn, vi
51
- )
52
- new_vn, new_context = prefix_and_strip_contexts (context, vn)
53
- return tilde_assume (rng, new_context, sampler, right, new_vn, vi)
54
- end
55
25
56
26
"""
57
27
tilde_assume!!(context, right, vn, vi)
@@ -71,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
71
41
end
72
42
73
43
# observe
74
- """
75
- tilde_observe!!(context::SamplingContext, right, left, vi)
76
-
77
- Handle observed constants with a `context` associated with a sampler.
78
-
79
- Falls back to `tilde_observe!!(context.context, right, left, vi)`.
80
- """
81
- function tilde_observe!! (context:: SamplingContext , right, left, vn, vi)
82
- return tilde_observe!! (context. context, right, left, vn, vi)
83
- end
84
-
85
44
function tilde_observe!! (context:: AbstractContext , right, left, vn, vi)
86
45
return tilde_observe!! (childcontext (context), right, left, vn, vi)
87
46
end
@@ -114,58 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
114
73
vi = accumulate_observe!! (vi, right, left, vn)
115
74
return left, vi
116
75
end
117
-
118
- function assume (:: Random.AbstractRNG , spl:: Sampler , dist)
119
- return error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " )
120
- end
121
-
122
- # fallback without sampler
123
- function assume (dist:: Distribution , vn:: VarName , vi)
124
- y = getindex_internal (vi, vn)
125
- f = from_maybe_linked_internal_transform (vi, vn, dist)
126
- x, inv_logjac = with_logabsdet_jacobian (f, y)
127
- vi = accumulate_assume!! (vi, x, - inv_logjac, vn, dist)
128
- return x, vi
129
- end
130
-
131
- # TODO : Remove this thing.
132
- # SampleFromPrior and SampleFromUniform
133
- function assume (
134
- rng:: Random.AbstractRNG ,
135
- sampler:: Union{SampleFromPrior,SampleFromUniform} ,
136
- dist:: Distribution ,
137
- vn:: VarName ,
138
- vi:: VarInfoOrThreadSafeVarInfo ,
139
- )
140
- if haskey (vi, vn)
141
- # Always overwrite the parameters with new ones for `SampleFromUniform`.
142
- if sampler isa SampleFromUniform || is_flagged (vi, vn, " del" )
143
- # TODO (mhauru) Is it important to unset the flag here? The `true` allows us
144
- # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
145
- # if that's okay.
146
- unset_flag! (vi, vn, " del" , true )
147
- r = init (rng, dist, sampler)
148
- f = to_maybe_linked_internal_transform (vi, vn, dist)
149
- # TODO (mhauru) This should probably be call a function called setindex_internal!
150
- vi = BangBang. setindex!! (vi, f (r), vn)
151
- else
152
- # Otherwise we just extract it.
153
- r = vi[vn, dist]
154
- end
155
- else
156
- r = init (rng, dist, sampler)
157
- if istrans (vi)
158
- f = to_linked_internal_transform (vi, vn, dist)
159
- vi = push!! (vi, vn, f (r), dist)
160
- # By default `push!!` sets the transformed flag to `false`.
161
- vi = settrans!! (vi, true , vn)
162
- else
163
- vi = push!! (vi, vn, r, dist)
164
- end
165
- end
166
-
167
- # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
168
- logjac = logabsdetjac (istrans (vi, vn) ? link_transform (dist) : identity, r)
169
- vi = accumulate_assume!! (vi, r, logjac, vn, dist)
170
- return r, vi
171
- end
0 commit comments