1
1
# assume
2
- """
3
- tilde_assume(context::SamplingContext, right, vn, vi)
4
-
5
- Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6
- accumulate the log probability, and return the sampled value with a context associated
7
- with a sampler.
8
-
9
- Falls back to
10
- ```julia
11
- tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12
- ```
13
- """
14
- function tilde_assume (context:: SamplingContext , right, vn, vi)
15
- return tilde_assume (context. rng, context. context, context. sampler, right, vn, vi)
16
- end
17
-
18
2
function tilde_assume (context:: AbstractContext , args... )
19
3
return tilde_assume (childcontext (context), args... )
20
4
end
21
5
function tilde_assume (:: DefaultContext , right, vn, vi)
22
- return assume (right, vn, vi)
23
- end
24
-
25
- function tilde_assume (rng:: Random.AbstractRNG , context:: AbstractContext , args... )
26
- return tilde_assume (rng, childcontext (context), args... )
27
- end
28
- function tilde_assume (rng:: Random.AbstractRNG , :: DefaultContext , sampler, right, vn, vi)
29
- return assume (rng, sampler, right, vn, vi)
30
- end
31
- function tilde_assume (rng:: Random.AbstractRNG , :: InitContext , sampler, right, vn, vi)
32
- @warn (
33
- " Encountered SamplingContext->InitContext. This method will be removed in the next PR." ,
34
- )
35
- # just pretend the `InitContext` isn't there for now.
36
- return assume (rng, sampler, right, vn, vi)
37
- end
38
- function tilde_assume (:: DefaultContext , sampler, right, vn, vi)
39
- # same as above but no rng
40
- return assume (Random. default_rng (), sampler, right, vn, vi)
6
+ y = getindex_internal (vi, vn)
7
+ f = from_maybe_linked_internal_transform (vi, vn, right)
8
+ x, inv_logjac = with_logabsdet_jacobian (f, y)
9
+ vi = accumulate_assume!! (vi, x, - inv_logjac, vn, right)
10
+ return x, vi
41
11
end
42
-
43
12
function tilde_assume (context:: PrefixContext , right, vn, vi)
44
13
# Note that we can't use something like this here:
45
14
# new_vn = prefix(context, vn)
@@ -53,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
53
22
new_vn, new_context = prefix_and_strip_contexts (context, vn)
54
23
return tilde_assume (new_context, right, new_vn, vi)
55
24
end
56
- function tilde_assume (
57
- rng:: Random.AbstractRNG , context:: PrefixContext , sampler, right, vn, vi
58
- )
59
- new_vn, new_context = prefix_and_strip_contexts (context, vn)
60
- return tilde_assume (rng, new_context, sampler, right, new_vn, vi)
61
- end
62
25
63
26
"""
64
27
tilde_assume!!(context, right, vn, vi)
@@ -78,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
78
41
end
79
42
80
43
# observe
81
- """
82
- tilde_observe!!(context::SamplingContext, right, left, vi)
83
-
84
- Handle observed constants with a `context` associated with a sampler.
85
-
86
- Falls back to `tilde_observe!!(context.context, right, left, vi)`.
87
- """
88
- function tilde_observe!! (context:: SamplingContext , right, left, vn, vi)
89
- return tilde_observe!! (context. context, right, left, vn, vi)
90
- end
91
-
92
44
function tilde_observe!! (context:: AbstractContext , right, left, vn, vi)
93
45
return tilde_observe!! (childcontext (context), right, left, vn, vi)
94
46
end
@@ -121,58 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
121
73
vi = accumulate_observe!! (vi, right, left, vn)
122
74
return left, vi
123
75
end
124
-
125
- function assume (:: Random.AbstractRNG , spl:: Sampler , dist)
126
- return error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " )
127
- end
128
-
129
- # fallback without sampler
130
- function assume (dist:: Distribution , vn:: VarName , vi)
131
- y = getindex_internal (vi, vn)
132
- f = from_maybe_linked_internal_transform (vi, vn, dist)
133
- x, inv_logjac = with_logabsdet_jacobian (f, y)
134
- vi = accumulate_assume!! (vi, x, - inv_logjac, vn, dist)
135
- return x, vi
136
- end
137
-
138
- # TODO : Remove this thing.
139
- # SampleFromPrior and SampleFromUniform
140
- function assume (
141
- rng:: Random.AbstractRNG ,
142
- sampler:: Union{SampleFromPrior,SampleFromUniform} ,
143
- dist:: Distribution ,
144
- vn:: VarName ,
145
- vi:: VarInfoOrThreadSafeVarInfo ,
146
- )
147
- if haskey (vi, vn)
148
- # Always overwrite the parameters with new ones for `SampleFromUniform`.
149
- if sampler isa SampleFromUniform || is_flagged (vi, vn, " del" )
150
- # TODO (mhauru) Is it important to unset the flag here? The `true` allows us
151
- # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
152
- # if that's okay.
153
- unset_flag! (vi, vn, " del" , true )
154
- r = init (rng, dist, sampler)
155
- f = to_maybe_linked_internal_transform (vi, vn, dist)
156
- # TODO (mhauru) This should probably be call a function called setindex_internal!
157
- vi = BangBang. setindex!! (vi, f (r), vn)
158
- else
159
- # Otherwise we just extract it.
160
- r = vi[vn, dist]
161
- end
162
- else
163
- r = init (rng, dist, sampler)
164
- if istrans (vi)
165
- f = to_linked_internal_transform (vi, vn, dist)
166
- vi = push!! (vi, vn, f (r), dist)
167
- # By default `push!!` sets the transformed flag to `false`.
168
- vi = settrans!! (vi, true , vn)
169
- else
170
- vi = push!! (vi, vn, r, dist)
171
- end
172
- end
173
-
174
- # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
175
- logjac = logabsdetjac (istrans (vi, vn) ? link_transform (dist) : identity, r)
176
- vi = accumulate_assume!! (vi, r, logjac, vn, dist)
177
- return r, vi
178
- end
0 commit comments