@@ -442,42 +442,45 @@ end
442
442
# ###
443
443
# ### Compiler interface, i.e. tilde operators.
444
444
# ###
445
- function DynamicPPL. assume (rng, spl:: Sampler{<:MH} , dist:: Distribution , vn:: VarName , vi)
445
+ function DynamicPPL. assume (
446
+ rng:: Random.AbstractRNG , spl:: Sampler{<:MH} , dist:: Distribution , vn:: VarName , vi
447
+ )
448
+ # Just defer to `SampleFromPrior`.
449
+ retval = DynamicPPL. assume (rng, SampleFromPrior (), dist, vn, vi)
450
+ # Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
446
451
DynamicPPL. updategid! (vi, vn, spl)
447
- r = vi[vn]
448
- return r, logpdf_with_trans (dist, r, istrans (vi, vn)), vi
452
+ # Return.
453
+ return retval
449
454
end
450
455
451
456
function DynamicPPL. dot_assume (
452
457
rng,
453
458
spl:: Sampler{<:MH} ,
454
459
dist:: MultivariateDistribution ,
455
- vn :: VarName ,
460
+ vns :: AbstractVector{<: VarName} ,
456
461
var:: AbstractMatrix ,
457
- vi,
462
+ vi:: AbstractVarInfo ,
458
463
)
459
- @assert dim (dist) == size (var, 1 )
460
- getvn = i -> VarName (vn, vn. indexing * " [:,$i ]" )
461
- vns = getvn .(1 : size (var, 2 ))
462
- DynamicPPL. updategid! .(Ref (vi), vns, Ref (spl))
463
- r = vi[vns]
464
- var .= r
465
- return var, sum (logpdf_with_trans (dist, r, istrans (vi, vns[1 ]))), vi
464
+ # Just defer to `SampleFromPrior`.
465
+ retval = DynamicPPL. dot_assume (rng, SampleFromPrior (), dist, vns[1 ], var, vi)
466
+ # Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
467
+ DynamicPPL. updategid! .((vi,), vns, (spl,))
468
+ # Return.
469
+ return retval
466
470
end
467
471
function DynamicPPL. dot_assume (
468
472
rng,
469
473
spl:: Sampler{<:MH} ,
470
474
dists:: Union{Distribution,AbstractArray{<:Distribution}} ,
471
- vn :: VarName ,
475
+ vns :: AbstractArray{<: VarName} ,
472
476
var:: AbstractArray ,
473
- vi,
477
+ vi:: AbstractVarInfo ,
474
478
)
475
- getvn = ind -> VarName (vn, vn. indexing * " [" * join (Tuple (ind), " ," ) * " ]" )
476
- vns = getvn .(CartesianIndices (var))
477
- DynamicPPL. updategid! .(Ref (vi), vns, Ref (spl))
478
- r = reshape (vi[vec (vns)], size (var))
479
- var .= r
480
- return var, sum (logpdf_with_trans .(dists, r, istrans (vi, vns[1 ]))), vi
479
+ # Just defer to `SampleFromPrior`.
480
+ retval = DynamicPPL. dot_assume (rng, SampleFromPrior (), dists, vns, var, vi)
481
+ # Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
482
+ DynamicPPL. updategid! .((vi,), vns, (spl,))
483
+ return retval
481
484
end
482
485
483
486
function DynamicPPL. observe (spl:: Sampler{<:MH} , d:: Distribution , value, vi)
0 commit comments