@@ -425,15 +425,6 @@ function dot_assume(
425
425
var:: AbstractMatrix ,
426
426
vns:: AbstractVector{<:VarName} ,
427
427
vi:: AbstractVarInfo ,
428
- )
429
- r, lp, vi = dot_assume_vec (dist, var, vns, vi)
430
- return r, sum (lp), vi
431
- end
432
- function dot_assume_vec (
433
- dist:: MultivariateDistribution ,
434
- var:: AbstractMatrix ,
435
- vns:: AbstractVector{<:VarName} ,
436
- vi:: AbstractVarInfo ,
437
428
)
438
429
@assert length (dist) == size (var, 1 ) " dimensionality of `var` ($(size (var, 1 )) ) is incompatible with dimensionality of `dist` $(length (dist)) "
439
430
# NOTE: We cannot work with `var` here because we might have a model of the form
@@ -443,7 +434,7 @@ function dot_assume_vec(
443
434
#
444
435
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
445
436
r = vi[vns, dist]
446
- lp = map (zip (vns, eachcol (r))) do (vn, ri)
437
+ lp = sum (zip (vns, eachcol (r))) do (vn, ri)
447
438
return Bijectors. logpdf_with_trans (dist, ri, istrans (vi, vn))
448
439
end
449
440
return r, lp, vi
@@ -464,29 +455,21 @@ function dot_assume(
464
455
end
465
456
466
457
function dot_assume (
467
- dist:: Union{Distribution,AbstractArray{<:Distribution}} , var:: AbstractArray , vns:: AbstractArray{<:VarName} , vi
468
- )
469
- # possibility to acesss the single logpriors
470
- r, lp, vi = dot_assume_vec (dist, var, vns, vi)
471
- return r, sum (lp), vi
472
- end
473
-
474
- function dot_assume_vec (
475
458
dist:: Distribution , var:: AbstractArray , vns:: AbstractArray{<:VarName} , vi
476
459
)
477
460
r = getindex .((vi,), vns, (dist,))
478
- lp = Bijectors. logpdf_with_trans .((dist,), r, istrans .((vi,), vns))
461
+ lp = sum ( Bijectors. logpdf_with_trans .((dist,), r, istrans .((vi,), vns) ))
479
462
return r, lp, vi
480
463
end
481
464
482
- function dot_assume_vec (
465
+ function dot_assume (
483
466
dists:: AbstractArray{<:Distribution} ,
484
467
var:: AbstractArray ,
485
468
vns:: AbstractArray{<:VarName} ,
486
469
vi,
487
470
)
488
471
r = getindex .((vi,), vns, dists)
489
- lp = Bijectors. logpdf_with_trans .(dists, r, istrans .((vi,), vns))
472
+ lp = sum ( Bijectors. logpdf_with_trans .(dists, r, istrans .((vi,), vns) ))
490
473
return r, lp, vi
491
474
end
492
475
0 commit comments