55
55
function tilde_assume (context:: PriorContext{<:NamedTuple} , right, vn, vi)
56
56
if haskey (context. vars, getsym (vn))
57
57
vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
58
- settrans! (vi, false , vn)
58
+ settrans!! (vi, false , vn)
59
59
end
60
60
return tilde_assume (PriorContext (), right, vn, vi)
61
61
end
@@ -64,15 +64,15 @@ function tilde_assume(
64
64
)
65
65
if haskey (context. vars, getsym (vn))
66
66
vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
67
- settrans! (vi, false , vn)
67
+ settrans!! (vi, false , vn)
68
68
end
69
69
return tilde_assume (rng, PriorContext (), sampler, right, vn, vi)
70
70
end
71
71
72
72
function tilde_assume (context:: LikelihoodContext{<:NamedTuple} , right, vn, vi)
73
73
if haskey (context. vars, getsym (vn))
74
74
vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
75
- settrans! (vi, false , vn)
75
+ settrans!! (vi, false , vn)
76
76
end
77
77
return tilde_assume (LikelihoodContext (), right, vn, vi)
78
78
end
@@ -86,7 +86,7 @@ function tilde_assume(
86
86
)
87
87
if haskey (context. vars, getsym (vn))
88
88
vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
89
- settrans! (vi, false , vn)
89
+ settrans!! (vi, false , vn)
90
90
end
91
91
return tilde_assume (rng, LikelihoodContext (), sampler, right, vn, vi)
92
92
end
194
194
195
195
# fallback without sampler
196
196
function assume (dist:: Distribution , vn:: VarName , vi)
197
- r = vi[vn]
197
+ r = vi[vn, dist ]
198
198
return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn)), vi
199
199
end
200
200
@@ -211,16 +211,21 @@ function assume(
211
211
if sampler isa SampleFromUniform || is_flagged (vi, vn, " del" )
212
212
unset_flag! (vi, vn, " del" )
213
213
r = init (rng, dist, sampler)
214
- vi[vn] = vectorize (dist, r)
215
- settrans! (vi, false , vn)
214
+ vi[vn] = vectorize (dist, maybe_link (vi, vn, dist, r))
216
215
setorder! (vi, vn, get_num_produce (vi))
217
216
else
218
- r = vi[vn]
217
+ # Otherwise we just extract it.
218
+ r = vi[vn, dist]
219
219
end
220
220
else
221
221
r = init (rng, dist, sampler)
222
- push!! (vi, vn, r, dist, sampler)
223
- settrans! (vi, false , vn)
222
+ if istrans (vi)
223
+ push!! (vi, vn, link (dist, r), dist, sampler)
224
+ # By default `push!!` sets the transformed flag to `false`.
225
+ settrans!! (vi, true , vn)
226
+ else
227
+ push!! (vi, vn, r, dist, sampler)
228
+ end
224
229
end
225
230
226
231
return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn)), vi
@@ -286,7 +291,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left,
286
291
var = get (context. vars, vn)
287
292
_right, _left, _vns = unwrap_right_left_vns (right, var, vn)
288
293
set_val! (vi, _vns, _right, _left)
289
- settrans! .( Ref (vi), false , _vns)
294
+ settrans!! .( (vi, ), false , _vns)
290
295
dot_tilde_assume (LikelihoodContext (), _right, _left, _vns, vi)
291
296
else
292
297
dot_tilde_assume (LikelihoodContext (), right, left, vn, vi)
@@ -305,19 +310,20 @@ function dot_tilde_assume(
305
310
var = get (context. vars, vn)
306
311
_right, _left, _vns = unwrap_right_left_vns (right, var, vn)
307
312
set_val! (vi, _vns, _right, _left)
308
- settrans! .( Ref (vi), false , _vns)
313
+ settrans!! .( (vi, ), false , _vns)
309
314
dot_tilde_assume (rng, LikelihoodContext (), sampler, _right, _left, _vns, vi)
310
315
else
311
316
dot_tilde_assume (rng, LikelihoodContext (), sampler, right, left, vn, vi)
312
317
end
313
318
end
319
+
314
320
function dot_tilde_assume (context:: LikelihoodContext , right, left, vn, vi)
315
- return dot_assume (NoDist . (right), left, vn, vi)
321
+ return dot_assume (nodist (right), left, vn, vi)
316
322
end
317
323
function dot_tilde_assume (
318
324
rng:: Random.AbstractRNG , context:: LikelihoodContext , sampler, right, left, vn, vi
319
325
)
320
- return dot_assume (rng, sampler, NoDist . (right), vn, left, vi)
326
+ return dot_assume (rng, sampler, nodist (right), vn, left, vi)
321
327
end
322
328
323
329
# `PriorContext`
@@ -326,7 +332,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn,
326
332
var = get (context. vars, vn)
327
333
_right, _left, _vns = unwrap_right_left_vns (right, var, vn)
328
334
set_val! (vi, _vns, _right, _left)
329
- settrans! .( Ref (vi), false , _vns)
335
+ settrans!! .( (vi, ), false , _vns)
330
336
dot_tilde_assume (PriorContext (), _right, _left, _vns, vi)
331
337
else
332
338
dot_tilde_assume (PriorContext (), right, left, vn, vi)
@@ -345,7 +351,7 @@ function dot_tilde_assume(
345
351
var = get (context. vars, vn)
346
352
_right, _left, _vns = unwrap_right_left_vns (right, var, vn)
347
353
set_val! (vi, _vns, _right, _left)
348
- settrans! .( Ref (vi), false , _vns)
354
+ settrans!! .( (vi, ), false , _vns)
349
355
dot_tilde_assume (rng, PriorContext (), sampler, _right, _left, _vns, vi)
350
356
else
351
357
dot_tilde_assume (rng, PriorContext (), sampler, right, left, vn, vi)
@@ -383,14 +389,14 @@ function dot_assume(
383
389
vns:: AbstractVector{<:VarName} ,
384
390
vi:: AbstractVarInfo ,
385
391
)
386
- @assert length (dist) == size (var, 1 )
392
+ @assert length (dist) == size (var, 1 ) " dimensionality of `var` ( $( size (var, 1 )) ) is incompatible with dimensionality of `dist` $( length (dist)) "
387
393
# NOTE: We cannot work with `var` here because we might have a model of the form
388
394
#
389
395
# m = Vector{Float64}(undef, n)
390
396
# m .~ Normal()
391
397
#
392
398
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
393
- r = vi[vns]
399
+ r = vi[vns, dist ]
394
400
lp = sum (zip (vns, eachcol (r))) do (vn, ri)
395
401
return Bijectors. logpdf_with_trans (dist, ri, istrans (vi, vn))
396
402
end
@@ -412,19 +418,21 @@ function dot_assume(
412
418
end
413
419
414
420
function dot_assume (
415
- dists:: Union{Distribution,AbstractArray{<:Distribution}} ,
421
+ dist:: Distribution , var:: AbstractArray , vns:: AbstractArray{<:VarName} , vi
422
+ )
423
+ r = getindex .((vi,), vns, (dist,))
424
+ lp = sum (Bijectors. logpdf_with_trans .((dist,), r, istrans .((vi,), vns)))
425
+ return r, lp, vi
426
+ end
427
+
428
+ function dot_assume (
429
+ dists:: AbstractArray{<:Distribution} ,
416
430
var:: AbstractArray ,
417
431
vns:: AbstractArray{<:VarName} ,
418
432
vi,
419
433
)
420
- # NOTE: We cannot work with `var` here because we might have a model of the form
421
- #
422
- # m = Vector{Float64}(undef, n)
423
- # m .~ Normal()
424
- #
425
- # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
426
- r = reshape (vi[vec (vns)], size (vns))
427
- lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans (vi, vns[1 ])))
434
+ r = getindex .((vi,), vns, dists)
435
+ lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans .((vi,), vns)))
428
436
return r, lp, vi
429
437
end
430
438
@@ -438,7 +446,7 @@ function dot_assume(
438
446
)
439
447
r = get_and_set_val! (rng, vi, vns, dists, spl)
440
448
# Make sure `r` is not a matrix for multivariate distributions
441
- lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans ( vi, vns[ 1 ] )))
449
+ lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans .(( vi,), vns)))
442
450
return r, lp, vi
443
451
end
444
452
function dot_assume (rng, spl:: Sampler , :: Any , :: AbstractArray{<:VarName} , :: Any , :: Any )
@@ -462,19 +470,23 @@ function get_and_set_val!(
462
470
r = init (rng, dist, spl, n)
463
471
for i in 1 : n
464
472
vn = vns[i]
465
- vi[vn] = vectorize (dist, r[:, i])
466
- settrans! (vi, false , vn)
473
+ vi[vn] = vectorize (dist, maybe_link (vi, vn, dist, r[:, i]))
467
474
setorder! (vi, vn, get_num_produce (vi))
468
475
end
469
476
else
470
- r = vi[vns]
477
+ r = vi[vns, dist ]
471
478
end
472
479
else
473
480
r = init (rng, dist, spl, n)
474
481
for i in 1 : n
475
482
vn = vns[i]
476
- push!! (vi, vn, r[:, i], dist, spl)
477
- settrans! (vi, false , vn)
483
+ if istrans (vi)
484
+ push!! (vi, vn, Bijectors. link (dist, r[:, i]), dist, spl)
485
+ # `push!!` sets the trans-flag to `false` by default.
486
+ settrans!! (vi, true , vn)
487
+ else
488
+ push!! (vi, vn, r[:, i], dist, spl)
489
+ end
478
490
end
479
491
end
480
492
return r
@@ -496,12 +508,13 @@ function get_and_set_val!(
496
508
for i in eachindex (vns)
497
509
vn = vns[i]
498
510
dist = dists isa AbstractArray ? dists[i] : dists
499
- vi[vn] = vectorize (dist, r[i])
500
- settrans! (vi, false , vn)
511
+ vi[vn] = vectorize (dist, maybe_link (vi, vn, dist, r[i]))
501
512
setorder! (vi, vn, get_num_produce (vi))
502
513
end
503
514
else
504
- r = reshape (vi[vec (vns)], size (vns))
515
+ # r = reshape(vi[vec(vns)], size(vns))
516
+ r_raw = getindex_raw (vi, vec (vns))
517
+ r = maybe_invlink .((vi,), vns, dists, reshape (r_raw, size (vns)))
505
518
end
506
519
else
507
520
f = (vn, dist) -> init (rng, dist, spl)
@@ -511,8 +524,13 @@ function get_and_set_val!(
511
524
# 1. Figure out the broadcast size and use a `foreach`.
512
525
# 2. Define an anonymous function which returns `nothing`, which
513
526
# we then broadcast. This will allocate a vector of `nothing` though.
514
- push!! .(Ref (vi), vns, r, dists, Ref (spl))
515
- settrans! .(Ref (vi), false , vns)
527
+ if istrans (vi)
528
+ push!! .((vi,), vns, link .((vi,), vns, dists, r), dists, (spl,))
529
+ # `push!!` sets the trans-flag to `false` by default.
530
+ settrans!! .((vi,), true , vns)
531
+ else
532
+ push!! .((vi,), vns, r, dists, (spl,))
533
+ end
516
534
end
517
535
return r
518
536
end
0 commit comments