@@ -19,47 +19,47 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
19
19
_getindex (x, inds:: Tuple{} ) = x
20
20
21
21
# assume
22
- function tilde (ctx:: DefaultContext , sampler, right, vn:: VarName , _, vi)
23
- return _tilde (sampler, right, vn, vi)
22
+ function tilde (rng, ctx:: DefaultContext , sampler, right, vn:: VarName , _, vi)
23
+ return _tilde (rng, sampler, right, vn, vi)
24
24
end
25
- function tilde (ctx:: PriorContext , sampler, right, vn:: VarName , inds, vi)
25
+ function tilde (rng, ctx:: PriorContext , sampler, right, vn:: VarName , inds, vi)
26
26
if ctx. vars != = nothing
27
27
vi[vn] = vectorize (right, _getindex (getfield (ctx. vars, getsym (vn)), inds))
28
28
settrans! (vi, false , vn)
29
29
end
30
- return _tilde (sampler, right, vn, vi)
30
+ return _tilde (rng, sampler, right, vn, vi)
31
31
end
32
- function tilde (ctx:: LikelihoodContext , sampler, right, vn:: VarName , inds, vi)
32
+ function tilde (rng, ctx:: LikelihoodContext , sampler, right, vn:: VarName , inds, vi)
33
33
if ctx. vars != = nothing
34
34
vi[vn] = vectorize (right, _getindex (getfield (ctx. vars, getsym (vn)), inds))
35
35
settrans! (vi, false , vn)
36
36
end
37
- return _tilde (sampler, NoDist (right), vn, vi)
37
+ return _tilde (rng, sampler, NoDist (right), vn, vi)
38
38
end
39
- function tilde (ctx:: MiniBatchContext , sampler, right, left:: VarName , inds, vi)
40
- return tilde (ctx. ctx, sampler, right, left, inds, vi)
39
+ function tilde (rng, ctx:: MiniBatchContext , sampler, right, left:: VarName , inds, vi)
40
+ return tilde (rng, ctx. ctx, sampler, right, left, inds, vi)
41
41
end
42
42
43
43
"""
44
- tilde_assume(ctx, sampler, right, vn, inds, vi)
44
+ tilde_assume(rng, ctx, sampler, right, vn, inds, vi)
45
45
46
46
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
47
47
accumulate the log probability, and return the sampled value.
48
48
49
- Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
49
+ Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`.
50
50
"""
51
- function tilde_assume (ctx, sampler, right, vn, inds, vi)
52
- value, logp = tilde (ctx, sampler, right, vn, inds, vi)
51
+ function tilde_assume (rng, ctx, sampler, right, vn, inds, vi)
52
+ value, logp = tilde (rng, ctx, sampler, right, vn, inds, vi)
53
53
acclogp! (vi, logp)
54
54
return value
55
55
end
56
56
57
57
58
- function _tilde (sampler, right, vn:: VarName , vi)
59
- return assume (sampler, right, vn, vi)
58
+ function _tilde (rng, sampler, right, vn:: VarName , vi)
59
+ return assume (rng, sampler, right, vn, vi)
60
60
end
61
- function _tilde (sampler, right:: NamedDist , vn:: VarName , vi)
62
- return _tilde (sampler, right. dist, right. name, vi)
61
+ function _tilde (rng, sampler, right:: NamedDist , vn:: VarName , vi)
62
+ return _tilde (rng, sampler, right. dist, right. name, vi)
63
63
end
64
64
65
65
# observe
108
108
109
109
_tilde (sampler, right, left, vi) = observe (sampler, right, left, vi)
110
110
111
- function assume (spl:: Sampler , dist)
111
+ function assume (rng, spl:: Sampler , dist)
112
112
error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " )
113
113
end
114
114
@@ -117,6 +117,7 @@ function observe(spl::Sampler, weight)
117
117
end
118
118
119
119
function assume (
120
+ rng,
120
121
spl:: Union{SampleFromPrior,SampleFromUniform} ,
121
122
dist:: Distribution ,
122
123
vn:: VarName ,
@@ -126,15 +127,15 @@ function assume(
126
127
# Always overwrite the parameters with new ones for `SampleFromUniform`.
127
128
if spl isa SampleFromUniform || is_flagged (vi, vn, " del" )
128
129
unset_flag! (vi, vn, " del" )
129
- r = init (dist, spl)
130
+ r = init (rng, dist, spl)
130
131
vi[vn] = vectorize (dist, r)
131
132
settrans! (vi, false , vn)
132
133
setorder! (vi, vn, get_num_produce (vi))
133
134
else
134
135
r = vi[vn]
135
136
end
136
137
else
137
- r = init (dist, spl)
138
+ r = init (rng, dist, spl)
138
139
push! (vi, vn, r, dist, spl)
139
140
settrans! (vi, false , vn)
140
141
end
@@ -154,11 +155,12 @@ end
154
155
# .~ functions
155
156
156
157
# assume
157
- function dot_tilde (ctx:: DefaultContext , sampler, right, left, vn:: VarName , _, vi)
158
+ function dot_tilde (rng, ctx:: DefaultContext , sampler, right, left, vn:: VarName , _, vi)
158
159
vns, dist = get_vns_and_dist (right, left, vn)
159
- return _dot_tilde (sampler, dist, left, vns, vi)
160
+ return _dot_tilde (rng, sampler, dist, left, vns, vi)
160
161
end
161
162
function dot_tilde (
163
+ rng,
162
164
ctx:: LikelihoodContext ,
163
165
sampler,
164
166
right,
@@ -175,12 +177,13 @@ function dot_tilde(
175
177
else
176
178
vns, dist = get_vns_and_dist (right, left, vn)
177
179
end
178
- return _dot_tilde (sampler, NoDist (dist), left, vns, vi)
180
+ return _dot_tilde (rng, sampler, NoDist (dist), left, vns, vi)
179
181
end
180
- function dot_tilde (ctx:: MiniBatchContext , sampler, right, left, vn:: VarName , inds, vi)
181
- return dot_tilde (ctx. ctx, sampler, right, left, vn, inds, vi)
182
+ function dot_tilde (rng, ctx:: MiniBatchContext , sampler, right, left, vn:: VarName , inds, vi)
183
+ return dot_tilde (rng, ctx. ctx, sampler, right, left, vn, inds, vi)
182
184
end
183
185
function dot_tilde (
186
+ rng,
184
187
ctx:: PriorContext ,
185
188
sampler,
186
189
right,
@@ -197,19 +200,19 @@ function dot_tilde(
197
200
else
198
201
vns, dist = get_vns_and_dist (right, left, vn)
199
202
end
200
- return _dot_tilde (sampler, dist, left, vns, vi)
203
+ return _dot_tilde (rng, sampler, dist, left, vns, vi)
201
204
end
202
205
203
206
"""
204
- dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
207
+ dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)
205
208
206
209
Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
207
210
model inputs), accumulate the log probability, and return the sampled value.
208
211
209
- Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
212
+ Falls back to `dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi)`.
210
213
"""
211
- function dot_tilde_assume (ctx, sampler, right, left, vn, inds, vi)
212
- value, logp = dot_tilde (ctx, sampler, right, left, vn, inds, vi)
214
+ function dot_tilde_assume (rng, ctx, sampler, right, left, vn, inds, vi)
215
+ value, logp = dot_tilde (rng, ctx, sampler, right, left, vn, inds, vi)
213
216
acclogp! (vi, logp)
214
217
return value
215
218
end
@@ -232,12 +235,13 @@ function get_vns_and_dist(
232
235
return getvn .(CartesianIndices (var)), dist
233
236
end
234
237
235
- function _dot_tilde (sampler, right, left, vns:: AbstractArray{<:VarName} , vi)
236
- return dot_assume (sampler, right, vns, left, vi)
238
+ function _dot_tilde (rng, sampler, right, left, vns:: AbstractArray{<:VarName} , vi)
239
+ return dot_assume (rng, sampler, right, vns, left, vi)
237
240
end
238
241
239
242
# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics
240
243
function _dot_tilde (
244
+ rng,
241
245
sampler:: AbstractSampler ,
242
246
right:: Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}} ,
243
247
left:: AbstractMatrix{>:AbstractVector} ,
@@ -248,32 +252,35 @@ function _dot_tilde(
248
252
end
249
253
250
254
function dot_assume (
255
+ rng,
251
256
spl:: Union{SampleFromPrior, SampleFromUniform} ,
252
257
dist:: MultivariateDistribution ,
253
258
vns:: AbstractVector{<:VarName} ,
254
259
var:: AbstractMatrix ,
255
260
vi,
256
261
)
257
262
@assert length (dist) == size (var, 1 )
258
- r = get_and_set_val! (vi, vns, dist, spl)
263
+ r = get_and_set_val! (rng, vi, vns, dist, spl)
259
264
lp = sum (Bijectors. logpdf_with_trans (dist, r, istrans (vi, vns[1 ])))
260
265
var .= r
261
266
return var, lp
262
267
end
263
268
function dot_assume (
269
+ rng,
264
270
spl:: Union{SampleFromPrior, SampleFromUniform} ,
265
271
dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
266
272
vns:: AbstractArray{<:VarName} ,
267
273
var:: AbstractArray ,
268
274
vi,
269
275
)
270
- r = get_and_set_val! (vi, vns, dists, spl)
276
+ r = get_and_set_val! (rng, vi, vns, dists, spl)
271
277
# Make sure `r` is not a matrix for multivariate distributions
272
278
lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans (vi, vns[1 ])))
273
279
var .= r
274
280
return var, lp
275
281
end
276
282
function dot_assume (
283
+ rng,
277
284
spl:: Sampler ,
278
285
:: Any ,
279
286
:: AbstractArray{<:VarName} ,
@@ -284,6 +291,7 @@ function dot_assume(
284
291
end
285
292
286
293
function get_and_set_val! (
294
+ rng,
287
295
vi,
288
296
vns:: AbstractVector{<:VarName} ,
289
297
dist:: MultivariateDistribution ,
@@ -294,7 +302,7 @@ function get_and_set_val!(
294
302
# Always overwrite the parameters with new ones for `SampleFromUniform`.
295
303
if spl isa SampleFromUniform || is_flagged (vi, vns[1 ], " del" )
296
304
unset_flag! (vi, vns[1 ], " del" )
297
- r = init (dist, spl, n)
305
+ r = init (rng, dist, spl, n)
298
306
for i in 1 : n
299
307
vn = vns[i]
300
308
vi[vn] = vectorize (dist, r[:, i])
@@ -305,7 +313,7 @@ function get_and_set_val!(
305
313
r = vi[vns]
306
314
end
307
315
else
308
- r = init (dist, spl, n)
316
+ r = init (rng, dist, spl, n)
309
317
for i in 1 : n
310
318
vn = vns[i]
311
319
push! (vi, vn, r[:,i], dist, spl)
@@ -316,6 +324,7 @@ function get_and_set_val!(
316
324
end
317
325
318
326
function get_and_set_val! (
327
+ rng,
319
328
vi,
320
329
vns:: AbstractArray{<:VarName} ,
321
330
dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
@@ -325,7 +334,7 @@ function get_and_set_val!(
325
334
# Always overwrite the parameters with new ones for `SampleFromUniform`.
326
335
if spl isa SampleFromUniform || is_flagged (vi, vns[1 ], " del" )
327
336
unset_flag! (vi, vns[1 ], " del" )
328
- f = (vn, dist) -> init (dist, spl)
337
+ f = (vn, dist) -> init (rng, dist, spl)
329
338
r = f .(vns, dists)
330
339
for i in eachindex (vns)
331
340
vn = vns[i]
@@ -338,7 +347,7 @@ function get_and_set_val!(
338
347
r = reshape (vi[vec (vns)], size (vns))
339
348
end
340
349
else
341
- f = (vn, dist) -> init (dist, spl)
350
+ f = (vn, dist) -> init (rng, dist, spl)
342
351
r = f .(vns, dists)
343
352
push! .(Ref (vi), vns, r, dists, Ref (spl))
344
353
settrans! .(Ref (vi), false , vns)
0 commit comments