Skip to content

Commit e055f7a

Browse files
authored
do not convert covariance/precision matrices of AbstractMvNormal subtypes to dense Matrix (#1373)
1 parent d666acb commit e055f7a

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

src/multivariate/mvnormal.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,10 @@ function kldivergence(p::AbstractMvNormal, q::AbstractMvNormal)
106106
# This is the generic implementation for AbstractMvNormal, you might need to specialize for your type
107107
length(p) == length(q) ||
108108
throw(DimensionMismatch("Distributions p and q have different dimensions $(length(p)) and $(length(q))"))
109-
# logdetcov is used separately from _cov for any potential optimization done there
110-
return (tr(_cov(q) \ _cov(p)) + sqmahal(q, mean(p)) - length(p) + logdetcov(q) - logdetcov(p)) / 2
109+
# logdetcov is used for any potential optimization done there
110+
return (tr(cov(q) \ cov(p)) + sqmahal(q, mean(p)) - length(p) + logdetcov(q) - logdetcov(p)) / 2
111111
end
112112

113-
# This is a workaround to take advantage of the PDMats objects for MvNormal and avoid copies as Matrix
114-
# TODO: Remove this once `cov(::MvNormal)` returns the PDMats object
115-
_cov(d::AbstractMvNormal) = cov(d)
116-
117113
"""
118114
invcov(d::AbstractMvNormal)
119115
@@ -256,10 +252,9 @@ params(d::MvNormal) = (d.μ, d.Σ)
256252
@inline partype(d::MvNormal{T}) where {T<:Real} = T
257253

258254
var(d::MvNormal) = diag(d.Σ)
259-
cov(d::MvNormal) = Matrix(d.Σ)
260-
_cov(d::MvNormal) = d.Σ
255+
cov(d::MvNormal) = d.Σ
261256

262-
invcov(d::MvNormal) = Matrix(inv(d.Σ))
257+
invcov(d::MvNormal) = inv(d.Σ)
263258
logdetcov(d::MvNormal) = logdet(d.Σ)
264259

265260
### Evaluation

src/multivariate/mvnormalcanon.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ params(d::MvNormalCanon) = (d.μ, d.h, d.J)
157157
Base.eltype(::Type{<:MvNormalCanon{T}}) where {T} = T
158158

159159
var(d::MvNormalCanon) = diag(inv(d.J))
160-
cov(d::MvNormalCanon) = Matrix(inv(d.J))
161-
invcov(d::MvNormalCanon) = Matrix(d.J)
160+
cov(d::MvNormalCanon) = inv(d.J)
161+
invcov(d::MvNormalCanon) = d.J
162162
logdetcov(d::MvNormalCanon) = -logdet(d.J)
163163

164164

0 commit comments

Comments
 (0)