@@ -36,17 +36,16 @@ function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
36
36
end
37
37
38
38
"""
39
- tilde_assume(ctx, sampler, right, vn, inds, vi, logps )
39
+ tilde_assume(ctx, sampler, right, vn, inds, vi)
40
40
41
41
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
42
- accumulate the log probability in `logps` (separately for each thread), and return the
43
- sampled value.
42
+ accumulate the log probability, and return the sampled value.
44
43
45
44
Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
46
45
"""
47
- function tilde_assume (ctx, sampler, right, vn, inds, vi, logps )
46
+ function tilde_assume (ctx, sampler, right, vn, inds, vi)
48
47
value, logp = tilde (ctx, sampler, right, vn, inds, vi)
49
- logps[Threads . threadid ()] += logp
48
+ acclogp! (vi, logp)
50
49
return value
51
50
end
52
51
76
75
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
77
76
78
77
Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
79
- accumulate the log probability in `logps` (separately for each thread), and return the
80
- observed value.
78
+ accumulate the log probability, and return the observed value.
81
79
82
80
Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name
83
81
and indices; if needed, these can be accessed through this function, though.
84
82
"""
85
- function tilde_observe (ctx, sampler, right, left, vname, vinds, vi, logps )
83
+ function tilde_observe (ctx, sampler, right, left, vname, vinds, vi)
86
84
logp = tilde (ctx, sampler, right, left, vi)
87
- logps[Threads . threadid ()] += logp
85
+ acclogp! (vi, logp)
88
86
return left
89
87
end
90
88
91
89
"""
92
- tilde_observe(ctx, sampler, right, left, vi, logps )
90
+ tilde_observe(ctx, sampler, right, left, vi)
93
91
94
- Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability in `logps`
95
- (separately for each thread), and return the observed value.
92
+ Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and
93
+ return the observed value.
96
94
97
95
Falls back to `tilde(ctx, sampler, right, left, vi)`.
98
96
"""
99
- function tilde_observe (ctx, sampler, right, left, vi, logps )
97
+ function tilde_observe (ctx, sampler, right, left, vi)
100
98
logp = tilde (ctx, sampler, right, left, vi)
101
- logps[Threads . threadid ()] += logp
99
+ acclogp! (vi, logp)
102
100
return left
103
101
end
104
102
@@ -117,7 +115,7 @@ function assume(
117
115
spl:: Union{SampleFromPrior,SampleFromUniform} ,
118
116
dist:: Distribution ,
119
117
vn:: VarName ,
120
- vi:: VarInfo ,
118
+ vi,
121
119
)
122
120
if haskey (vi, vn)
123
121
# Always overwrite the parameters with new ones for `SampleFromUniform`.
@@ -142,7 +140,7 @@ function observe(
142
140
spl:: Union{SampleFromPrior, SampleFromUniform} ,
143
141
dist:: Distribution ,
144
142
value,
145
- vi:: VarInfo ,
143
+ vi,
146
144
)
147
145
increment_num_produce! (vi)
148
146
return Distributions. logpdf (dist, value)
@@ -201,14 +199,13 @@ end
201
199
dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
202
200
203
201
Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
204
- model inputs), accumulate the log probability in `logps` (separately for each thread), and
205
- return the sampled value.
202
+ model inputs), accumulate the log probability, and return the sampled value.
206
203
207
204
Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
208
205
"""
209
- function dot_tilde_assume (ctx, sampler, right, left, vn, inds, vi, logps )
206
+ function dot_tilde_assume (ctx, sampler, right, left, vn, inds, vi)
210
207
value, logp = dot_tilde (ctx, sampler, right, left, vn, inds, vi)
211
- logps[Threads . threadid ()] += logp
208
+ acclogp! (vi, logp)
212
209
return value
213
210
end
214
211
@@ -240,7 +237,7 @@ function _dot_tilde(
240
237
right:: Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}} ,
241
238
left:: AbstractMatrix{>:AbstractVector} ,
242
239
vn:: AbstractVector{<:VarName} ,
243
- vi:: VarInfo ,
240
+ vi,
244
241
)
245
242
throw (ambiguity_error_msg ())
246
243
end
@@ -250,7 +247,7 @@ function dot_assume(
250
247
dist:: MultivariateDistribution ,
251
248
vns:: AbstractVector{<:VarName} ,
252
249
var:: AbstractMatrix ,
253
- vi:: VarInfo ,
250
+ vi,
254
251
)
255
252
@assert length (dist) == size (var, 1 )
256
253
r = get_and_set_val! (vi, vns, dist, spl)
@@ -263,7 +260,7 @@ function dot_assume(
263
260
dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
264
261
vns:: AbstractArray{<:VarName} ,
265
262
var:: AbstractArray ,
266
- vi:: VarInfo ,
263
+ vi,
267
264
)
268
265
r = get_and_set_val! (vi, vns, dists, spl)
269
266
# Make sure `r` is not a matrix for multivariate distributions
@@ -276,13 +273,13 @@ function dot_assume(
276
273
:: Any ,
277
274
:: AbstractArray{<:VarName} ,
278
275
:: Any ,
279
- :: VarInfo
276
+ :: Any ,
280
277
)
281
278
error (" [DynamicPPL] $(alg_str (spl)) doesn't support vectorizing assume statement" )
282
279
end
283
280
284
281
function get_and_set_val! (
285
- vi:: VarInfo ,
282
+ vi,
286
283
vns:: AbstractVector{<:VarName} ,
287
284
dist:: MultivariateDistribution ,
288
285
spl:: Union{SampleFromPrior,SampleFromUniform} ,
@@ -313,7 +310,7 @@ function get_and_set_val!(
313
310
end
314
311
315
312
function get_and_set_val! (
316
- vi:: VarInfo ,
313
+ vi,
317
314
vns:: AbstractArray{<:VarName} ,
318
315
dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
319
316
spl:: Union{SampleFromPrior,SampleFromUniform} ,
@@ -344,7 +341,7 @@ function get_and_set_val!(
344
341
end
345
342
346
343
function set_val! (
347
- vi:: VarInfo ,
344
+ vi,
348
345
vns:: AbstractVector{<:VarName} ,
349
346
dist:: MultivariateDistribution ,
350
347
val:: AbstractMatrix ,
@@ -356,7 +353,7 @@ function set_val!(
356
353
return val
357
354
end
358
355
function set_val! (
359
- vi:: VarInfo ,
356
+ vi,
360
357
vns:: AbstractArray{<:VarName} ,
361
358
dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
362
359
val:: AbstractArray ,
@@ -384,36 +381,34 @@ function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
384
381
end
385
382
386
383
"""
387
- dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps )
384
+ dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
388
385
389
386
Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs),
390
- accumulate the log probability in `logps` (separately for each thread), and return the
391
- observed value.
387
+ accumulate the log probability, and return the observed value.
392
388
393
389
Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
394
390
name and indices; if needed, these can be accessed through this function, though.
395
391
"""
396
- function dot_tilde_observe (ctx, sampler, right, left, vn, inds, vi, logps )
392
+ function dot_tilde_observe (ctx, sampler, right, left, vn, inds, vi)
397
393
logp = dot_tilde (ctx, sampler, right, left, vi)
398
- logps[Threads . threadid ()] += logp
394
+ acclogp! (vi, logp)
399
395
return left
400
396
end
401
397
402
398
"""
403
- dot_tilde_observe(ctx, sampler, right, left, vi, logps )
399
+ dot_tilde_observe(ctx, sampler, right, left, vi)
404
400
405
401
Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log
406
- probability in `logps` (separately for each thread) , and return the observed value.
402
+ probability, and return the observed value.
407
403
408
404
Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
409
405
"""
410
- function dot_tilde_observe (ctx, sampler, right, left, vi, logps )
406
+ function dot_tilde_observe (ctx, sampler, right, left, vi)
411
407
logp = dot_tilde (ctx, sampler, right, left, vi)
412
- logps[Threads . threadid ()] += logp
408
+ acclogp! (vi, logp)
413
409
return left
414
410
end
415
411
416
-
417
412
function _dot_tilde (sampler, right, left:: AbstractArray , vi)
418
413
return dot_observe (sampler, right, left, vi)
419
414
end
@@ -422,7 +417,7 @@ function _dot_tilde(
422
417
sampler:: AbstractSampler ,
423
418
right:: Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}} ,
424
419
left:: AbstractMatrix{>:AbstractVector} ,
425
- vi:: VarInfo ,
420
+ vi,
426
421
)
427
422
throw (ambiguity_error_msg ())
428
423
end
@@ -431,7 +426,7 @@ function dot_observe(
431
426
spl:: Union{SampleFromPrior, SampleFromUniform} ,
432
427
dist:: MultivariateDistribution ,
433
428
value:: AbstractMatrix ,
434
- vi:: VarInfo ,
429
+ vi,
435
430
)
436
431
increment_num_produce! (vi)
437
432
DynamicPPL. DEBUG && @debug " dist = $dist "
@@ -442,7 +437,7 @@ function dot_observe(
442
437
spl:: Union{SampleFromPrior, SampleFromUniform} ,
443
438
dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
444
439
value:: AbstractArray ,
445
- vi:: VarInfo ,
440
+ vi,
446
441
)
447
442
increment_num_produce! (vi)
448
443
DynamicPPL. DEBUG && @debug " dists = $dists "
@@ -453,7 +448,7 @@ function dot_observe(
453
448
spl:: Sampler ,
454
449
:: Any ,
455
450
:: Any ,
456
- :: VarInfo ,
451
+ :: Any ,
457
452
)
458
453
error (" [DynamicPPL] $(alg_str (spl)) doesn't support vectorizing observe statement" )
459
454
end
0 commit comments