Skip to content

Commit f74f95d

Browse files
committed
use TuringScalMvNormal when x is tracked
1 parent 105e5e8 commit f74f95d

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/multivariate.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix)
7575
return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2
7676
end
7777

78+
for T in (:TrackedVector, :TrackedMatrix)
79+
@eval function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T)
80+
logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x)
81+
end
82+
end
83+
7884
import StatsBase: entropy
7985
function entropy(d::TuringDiagMvNormal)
8086
T = eltype(d.σ)

0 commit comments

Comments
 (0)