Skip to content

Commit 216d50c

Browse files
committed
revert dot_assume to not explicitly resolve components of sum
1 parent c6653b9 commit 216d50c

File tree

1 file changed

+4
-21
lines changed

1 file changed

+4
-21
lines changed

src/context_implementations.jl

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -425,15 +425,6 @@ function dot_assume(
425425
var::AbstractMatrix,
426426
vns::AbstractVector{<:VarName},
427427
vi::AbstractVarInfo,
428-
)
429-
r, lp, vi = dot_assume_vec(dist, var, vns, vi)
430-
return r, sum(lp), vi
431-
end
432-
function dot_assume_vec(
433-
dist::MultivariateDistribution,
434-
var::AbstractMatrix,
435-
vns::AbstractVector{<:VarName},
436-
vi::AbstractVarInfo,
437428
)
438429
@assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))"
439430
# NOTE: We cannot work with `var` here because we might have a model of the form
@@ -443,7 +434,7 @@ function dot_assume_vec(
443434
#
444435
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
445436
r = vi[vns, dist]
446-
lp = map(zip(vns, eachcol(r))) do (vn, ri)
437+
lp = sum(zip(vns, eachcol(r))) do (vn, ri)
447438
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
448439
end
449440
return r, lp, vi
@@ -464,29 +455,21 @@ function dot_assume(
464455
end
465456

466457
function dot_assume(
467-
dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vns::AbstractArray{<:VarName}, vi
468-
)
469-
# possibility to acesss the single logpriors
470-
r, lp, vi = dot_assume_vec(dist, var, vns, vi)
471-
return r, sum(lp), vi
472-
end
473-
474-
function dot_assume_vec(
475458
dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi
476459
)
477460
r = getindex.((vi,), vns, (dist,))
478-
lp = Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))
461+
lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns)))
479462
return r, lp, vi
480463
end
481464

482-
function dot_assume_vec(
465+
function dot_assume(
483466
dists::AbstractArray{<:Distribution},
484467
var::AbstractArray,
485468
vns::AbstractArray{<:VarName},
486469
vi,
487470
)
488471
r = getindex.((vi,), vns, dists)
489-
lp = Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))
472+
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
490473
return r, lp, vi
491474
end
492475

0 commit comments

Comments
 (0)