diff --git a/Project.toml b/Project.toml index 033930d7c..5adbe274c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.25.120" +version = "0.25.121" [deps] AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 5806c00f3..fa6414236 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -261,6 +261,25 @@ logdetcov(d::MvNormal) = logdet(d.Σ) sqmahal(d::MvNormal, x::AbstractVector) = invquad(d.Σ, x .- d.μ) +function sqmahal(d::DiagNormal, x::AbstractVector) + # Faster than above as this avoids calculating (x .- d.µ) + T = promote_type(partype(d), eltype(x)) + sum = zero(T) + for i in eachindex(x) + @inbounds sum += abs2(x[i] - d.μ[i]) / d.Σ[i, i] + end + return sum +end + +function sqmahal(d::IsoNormal, x::AbstractVector) + T = promote_type(partype(d), eltype(x)) + sum = zero(T) + for i in eachindex(x) + @inbounds sum += abs2(x[i] - d.μ[i]) + end + return sum / d.Σ[1, 1] +end + sqmahal!(r::AbstractVector, d::MvNormal, x::AbstractMatrix) = invquad!(r, d.Σ, x .- d.μ)