Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
789ad0a
test and code for frule
matbesancon Apr 18, 2022
24861ca
added tests
matbesancon Apr 24, 2022
6e6c029
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Apr 25, 2022
2bcc217
diff on MvNormal only for now
matbesancon Apr 25, 2022
1e9571b
unthunk when only one element
matbesancon May 23, 2022
ac0995e
Update src/multivariate/mvnormal.jl
matbesancon May 24, 2022
522c13e
Update src/multivariate/mvnormal.jl
matbesancon May 24, 2022
6529a4a
no backing
matbesancon May 24, 2022
4e4d982
conflict
matbesancon May 24, 2022
d398140
fix inference
matbesancon May 24, 2022
661de16
unthunk common computation
matbesancon May 26, 2022
3c5007c
unthunk common computation
matbesancon May 26, 2022
1f18e67
remove rules for logpdf
matbesancon May 27, 2022
b60147d
revert changes
matbesancon Jul 30, 2022
375cad8
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
c34a3ed
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
2d96680
revert changes
matbesancon Jul 30, 2022
f32c223
Merge branch 'cr-mvnormal' of github.com:JuliaStats/Distributions.jl …
matbesancon Jul 30, 2022
b225298
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
cf1242a
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Jul 31, 2022
e2846e8
Merge branch 'cr-mvnormal' of github.com:JuliaStats/Distributions.jl …
matbesancon Jul 31, 2022
6bc85da
avoid materializing Matrix
matbesancon Jul 31, 2022
e303416
revert
matbesancon Jul 31, 2022
e21506a
fix revert
matbesancon Jul 31, 2022
fd272ae
no alloc
matbesancon Jul 31, 2022
8b7d451
revert fdiff
matbesancon Jul 31, 2022
53601fe
Update src/multivariate/mvnormal.jl
matbesancon Jul 31, 2022
2c75061
fix op
matbesancon Jul 31, 2022
6d88e4d
fix op bis, revert to invcov
matbesancon Jul 31, 2022
1f00a3b
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Aug 20, 2022
8ebf419
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Oct 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 67 additions & 13 deletions src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ Base.show(io::IO, d::MvNormal) =
length(d::MvNormal) = length(d.μ)
mean(d::MvNormal) = d.μ
params(d::MvNormal) = (d.μ, d.Σ)
@inline partype(d::MvNormal{T}) where {T<:Real} = T
@inline partype(::MvNormal{T}) where {T<:Real} = T

var(d::MvNormal) = diag(d.Σ)
cov(d::MvNormal) = Matrix(d.Σ)
Expand Down Expand Up @@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats
tw::Float64 # total sample weight
end

function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64})
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe move unrelated changes to a separate PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there were all relatively minor things (unused variables)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you revert non-CR changes? It seems not only unused names were removed but also some types and dispatches changed, creating slight inconsistencies and related to the open issue about type parameters in fit. IMO it would br much cleaner to avoid these additonal changes in this PR here and instead fix the dispatches (and names) in a separate PR in a consistent way.

d = size(x, 1)
n = size(x, 2)
s = vec(sum(x, dims=2))
Expand All @@ -382,7 +382,7 @@ function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64})
MvNormalStats(s, m, s2, Float64(n))
end

function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
d = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions."))
Expand Down Expand Up @@ -410,13 +410,13 @@ end
# each kind of covariance
#

fit_mle(D::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss)
fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x)
fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w)
fit_mle(::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss)
fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x)
fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w)

fit_mle(D::Type{FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw))
fit_mle(::Type{<:FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw))

function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64})
function fit_mle(::Type{FullNormal}, x::AbstractMatrix{Float64})
n = size(x, 2)
mu = vec(mean(x, dims=2))
z = x .- mu
Expand All @@ -425,7 +425,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64})
MvNormal(mu, PDMat(C))
end

function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function fit_mle(::Type{<:FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
m = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions"))
Expand All @@ -445,7 +445,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVec
MvNormal(mu, PDMat(C))
end

function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64})
function fit_mle(::Type{DiagNormal}, x::AbstractMatrix{Float64})
m = size(x, 1)
n = size(x, 2)

Expand All @@ -460,7 +460,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64})
MvNormal(mu, PDiagMat(va))
end

function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function fit_mle(::Type{<:DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
m = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions"))
Expand All @@ -479,7 +479,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVec
MvNormal(mu, PDiagMat(va))
end

function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64})
function fit_mle(::Type{IsoNormal}, x::AbstractMatrix{Float64})
m = size(x, 1)
n = size(x, 2)

Expand All @@ -495,7 +495,7 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64})
MvNormal(mu, ScalMat(m, va / (m * n)))
end

function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function fit_mle(::Type{<:IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
m = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions"))
Expand All @@ -515,3 +515,57 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVect
end
MvNormal(mu, ScalMat(m, va / (m * sw)))
end

## Differentiation

function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal)
y = mvnormal_c0(d)
Δd = ChainRulesCore.unthunk(Δd)
Δy = -dot(Δd.Σ, invcov(d)) / 2
return y, Δy
end

function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal)
y = mvnormal_c0(d)
function mvnormal_c0_pullback(dy)
dy = ChainRulesCore.unthunk(dy)
∂Σ = -dy/2 * invcov(d)
∂d = ChainRulesCore.Tangent{typeof(d)}(μ = ChainRulesCore.ZeroTangent(), Σ = ∂Σ)
return ChainRulesCore.NoTangent(), ∂d
end
return y, mvnormal_c0_pullback
end

function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::MvNormal, x::AbstractVector)
y = sqmahal(d, x)
(_, Δd, Δx) = dargs
Δd = ChainRulesCore.unthunk(Δd)
Δx = ChainRulesCore.unthunk(Δx)
Σinv = invcov(d)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we avoid computing the inverse?

# TODO optimize
dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ')
dx = 2 * dot(Σinv * (x - d.μ), Δx)
dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ)
Δy = dΣ + dx + dμ
return (y, Δy)
end

function ChainRulesCore.rrule(::typeof(sqmahal), d::MvNormal, x::AbstractVector)
y = sqmahal(d, x)
function sqmahal_pullback(dy)
Σinv = invcov(d)
dy = ChainRulesCore.unthunk(dy)
∂x = ChainRulesCore.@thunk(begin
2dy * Σinv * (x - d.μ)
end)
∂d = ChainRulesCore.@thunk(begin
cx = x - d.μ
∂μ = -2dy * Σinv * cx
∂J = dy * cx * cx'
∂Σ = - Σinv * ∂J * Σinv
ChainRulesCore.Tangent{typeof(d)}(μ = ∂μ, Σ = ∂Σ)
end)
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
end
return y, sqmahal_pullback
end
53 changes: 53 additions & 0 deletions test/mvnormal.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Tests on Multivariate Normal distributions

import PDMats
import PDMats: ScalMat, PDiagMat, PDMat
if isdefined(PDMats, :PDSparseMat)
import PDMats: PDSparseMat
Expand All @@ -9,6 +10,8 @@ using Distributions
using LinearAlgebra, Random, Test
using SparseArrays
using FillArrays
using ChainRulesCore
using ChainRulesTestUtils

###### General Testing

Expand Down Expand Up @@ -302,3 +305,53 @@ end
x = rand(d)
@test logpdf(d, x) ≈ logpdf(Normal(), x[1]) + logpdf(Normal(), x[2])
end

@testset "MvNormal differentiation rules" begin
for n in (3, 10)
for _ in 1:10
A = Symmetric(rand(n,n)) .+ 4 * Matrix(I, n, n)
@assert isposdef(A)
d = MvNormal(randn(n), A)
# make ΔΣ symmetric, such that Σ ± ΔΣ is PSD
t = 0.001 * ChainRulesTestUtils.rand_tangent(d)
t.Σ .+= t.Σ'
if eigmin(t.Σ) < 0
while eigmin(d.Σ + t.Σ) < 0
t.Σ .*= 0.8
end
end
if eigmax(t.Σ) > 0
while eigmin(d.Σ - t.Σ) < 0
t.Σ .*= 0.8
end
end
# mvnormal_c0
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d)
y_r, c0_pullback = @inferred ChainRulesCore.rrule(Distributions.mvnormal_c0, d)
@test y_r ≈ y
y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ))
@test unthunk(Δy) ≈ y2 - y atol= n * 1e-4
y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ))
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
(_, ∇c0) = @inferred c0_pullback(1.0)
∇c0 = ChainRulesCore.unthunk(∇c0)
@test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4
@test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4
# sqmahal
x = randn(n)
Δx = 0.0001 * randn(n)
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x)
(yr, sqmahal_pullback) = @inferred ChainRulesCore.rrule(sqmahal, d, x)
(_, ∇s_d, ∇s_x) = @inferred sqmahal_pullback(1.0)
∇s_d = ChainRulesCore.unthunk(∇s_d)
∇s_x = ChainRulesCore.unthunk(∇s_x)
@test yr ≈ y
y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx)
y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx)
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4
end
end
end