- 
                Notifications
    You must be signed in to change notification settings 
- Fork 432
Differentiating mvnormal #1554
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Differentiating mvnormal #1554
Changes from 4 commits
789ad0a
              24861ca
              6e6c029
              2bcc217
              1e9571b
              ac0995e
              522c13e
              6529a4a
              4e4d982
              d398140
              661de16
              3c5007c
              1f18e67
              b60147d
              375cad8
              c34a3ed
              2d96680
              f32c223
              b225298
              cf1242a
              e2846e8
              6bc85da
              e303416
              e21506a
              fd272ae
              8b7d451
              53601fe
              2c75061
              6d88e4d
              1f00a3b
              8ebf419
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|  | @@ -253,7 +253,7 @@ Base.show(io::IO, d::MvNormal) = | |||||||||||
| length(d::MvNormal) = length(d.μ) | ||||||||||||
| mean(d::MvNormal) = d.μ | ||||||||||||
| params(d::MvNormal) = (d.μ, d.Σ) | ||||||||||||
| @inline partype(d::MvNormal{T}) where {T<:Real} = T | ||||||||||||
| @inline partype(::MvNormal{T}) where {T<:Real} = T | ||||||||||||
|  | ||||||||||||
| var(d::MvNormal) = diag(d.Σ) | ||||||||||||
| cov(d::MvNormal) = Matrix(d.Σ) | ||||||||||||
|  | @@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats | |||||||||||
| tw::Float64 # total sample weight | ||||||||||||
| end | ||||||||||||
|  | ||||||||||||
| function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
| function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
|          | ||||||||||||
| d = size(x, 1) | ||||||||||||
| n = size(x, 2) | ||||||||||||
| s = vec(sum(x, dims=2)) | ||||||||||||
|  | @@ -382,7 +382,7 @@ function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}) | |||||||||||
| MvNormalStats(s, m, s2, Float64(n)) | ||||||||||||
| end | ||||||||||||
|  | ||||||||||||
| function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
| function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
| d = size(x, 1) | ||||||||||||
| n = size(x, 2) | ||||||||||||
| length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions.")) | ||||||||||||
|  | @@ -410,13 +410,13 @@ end | |||||||||||
| # each kind of covariance | ||||||||||||
| # | ||||||||||||
|  | ||||||||||||
| fit_mle(D::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss) | ||||||||||||
| fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x) | ||||||||||||
| fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w) | ||||||||||||
| fit_mle(::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss) | ||||||||||||
| fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x) | ||||||||||||
| fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w) | ||||||||||||
|  | ||||||||||||
| fit_mle(D::Type{FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) | ||||||||||||
| fit_mle(::Type{<:FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) | ||||||||||||
|  | ||||||||||||
| function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
| function fit_mle(::Type{FullNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
| n = size(x, 2) | ||||||||||||
| mu = vec(mean(x, dims=2)) | ||||||||||||
| z = x .- mu | ||||||||||||
|  | @@ -425,7 +425,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}) | |||||||||||
| MvNormal(mu, PDMat(C)) | ||||||||||||
| end | ||||||||||||
|  | ||||||||||||
| function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
| function fit_mle(::Type{<:FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
| m = size(x, 1) | ||||||||||||
| n = size(x, 2) | ||||||||||||
| length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) | ||||||||||||
|  | @@ -445,7 +445,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVec | |||||||||||
| MvNormal(mu, PDMat(C)) | ||||||||||||
| end | ||||||||||||
|  | ||||||||||||
| function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
| function fit_mle(::Type{DiagNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
| m = size(x, 1) | ||||||||||||
| n = size(x, 2) | ||||||||||||
|  | ||||||||||||
|  | @@ -460,7 +460,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}) | |||||||||||
| MvNormal(mu, PDiagMat(va)) | ||||||||||||
| end | ||||||||||||
|  | ||||||||||||
| function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
| function fit_mle(::Type{<:DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
| m = size(x, 1) | ||||||||||||
| n = size(x, 2) | ||||||||||||
| length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) | ||||||||||||
|  | @@ -479,7 +479,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVec | |||||||||||
| MvNormal(mu, PDiagMat(va)) | ||||||||||||
| end | ||||||||||||
|  | ||||||||||||
| function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
| function fit_mle(::Type{IsoNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
| m = size(x, 1) | ||||||||||||
| n = size(x, 2) | ||||||||||||
|  | ||||||||||||
|  | @@ -495,7 +495,7 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}) | |||||||||||
| MvNormal(mu, ScalMat(m, va / (m * n))) | ||||||||||||
| end | ||||||||||||
|  | ||||||||||||
| function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
| function fit_mle(::Type{<:IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
| m = size(x, 1) | ||||||||||||
| n = size(x, 2) | ||||||||||||
| length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) | ||||||||||||
|  | @@ -515,3 +515,95 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVect | |||||||||||
| end | ||||||||||||
| MvNormal(mu, ScalMat(m, va / (m * sw))) | ||||||||||||
| end | ||||||||||||
|  | ||||||||||||
| ## Differentiation | ||||||||||||
|  | ||||||||||||
| function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector) | ||||||||||||
| c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d) | ||||||||||||
| sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x) | ||||||||||||
|          | ||||||||||||
| return c0 - sq/2, ChainRulesCore.@thunk(begin | ||||||||||||
| Δc0 = ChainRulesCore.unthunk(Δc0) | ||||||||||||
| Δsq = ChainRulesCore.unthunk(Δsq) | ||||||||||||
| Δc0 - Δsq/2 | ||||||||||||
| end) | ||||||||||||
|          | ||||||||||||
| end | ||||||||||||
|  | ||||||||||||
| function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) | ||||||||||||
| c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d) | ||||||||||||
| sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x) | ||||||||||||
|          | ||||||||||||
| function logpdf_MvNormal_pullback(dy) | ||||||||||||
| dy = ChainRulesCore.unthunk(dy) | ||||||||||||
| (_, ∂d_c0) = c0_pullback(dy) | ||||||||||||
| ∂d_c0 = ChainRulesCore.unthunk(∂d_c0) | ||||||||||||
| (_, ∂d_sq, ∂x_sq) = sq_pullback(dy) | ||||||||||||
| ∂d_sq = ChainRulesCore.unthunk(∂d_sq) | ||||||||||||
| ∂x_sq = ChainRulesCore.unthunk(∂x_sq) | ||||||||||||
| backing = NamedTuple{(:μ, :Σ), Tuple{typeof(∂d_sq.μ), typeof(∂d_sq.Σ)}}(( | ||||||||||||
| (∂d_c0.μ - 0.5 * ∂d_sq.μ), | ||||||||||||
| (∂d_c0.Σ - 0.5 * ∂d_sq.Σ), | ||||||||||||
| )) | ||||||||||||
| ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) | ||||||||||||
|         
                  devmotion marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||||||||||||
| return ChainRulesCore.NoTangent(), ∂d, - 0.5 * ∂x_sq | ||||||||||||
|         
                  matbesancon marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||||||||||||
| end | ||||||||||||
| return c0 - 0.5 * sq, logpdf_MvNormal_pullback | ||||||||||||
|         
                  matbesancon marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||||||||||||
| end | ||||||||||||
|  | ||||||||||||
| function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal) | ||||||||||||
| y = mvnormal_c0(d) | ||||||||||||
| Δy = ChainRulesCore.@thunk(begin | ||||||||||||
| Δd = ChainRulesCore.unthunk(Δd) | ||||||||||||
| -dot(Δd.Σ, invcov(d)) / 2 | ||||||||||||
| end) | ||||||||||||
|          | ||||||||||||
| Δy = ChainRulesCore.@thunk(begin | |
| Δd = ChainRulesCore.unthunk(Δd) | |
| -dot(Δd.Σ, invcov(d)) / 2 | |
| end) | |
| Δy = -dot(Δd.Σ, invcov(d)) / 2 | 
        
          
              
                Outdated
          
        
      There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No thunk 🙂
        
          
              
                  matbesancon marked this conversation as resolved.
              
              
                Outdated
          
            Show resolved
            Hide resolved
        
              
          
              
                  devmotion marked this conversation as resolved.
              
              
                Outdated
          
            Show resolved
            Hide resolved
        
              
          
              
                  matbesancon marked this conversation as resolved.
              
              
                Outdated
          
            Show resolved
            Hide resolved
        
      | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| # Tests on Multivariate Normal distributions | ||
|  | ||
| import PDMats | ||
| import PDMats: ScalMat, PDiagMat, PDMat | ||
| if isdefined(PDMats, :PDSparseMat) | ||
| import PDMats: PDSparseMat | ||
|  | @@ -9,6 +10,8 @@ using Distributions | |
| using LinearAlgebra, Random, Test | ||
| using SparseArrays | ||
| using FillArrays | ||
| using ChainRulesCore | ||
| using ChainRulesTestUtils | ||
|  | ||
| ###### General Testing | ||
|  | ||
|  | @@ -302,3 +305,67 @@ end | |
| x = rand(d) | ||
| @test logpdf(d, x) ≈ logpdf(Normal(), x[1]) + logpdf(Normal(), x[2]) | ||
| end | ||
|  | ||
| @testset "MvNormal differentiation rules" begin | ||
| for n in (3, 10) | ||
| for _ in 1:10 | ||
| A = Symmetric(rand(n,n)) .+ 4 * Matrix(I, n, n) | ||
| @assert isposdef(A) | ||
| d = MvNormal(randn(n), A) | ||
| # make ΔΣ symmetric, such that Σ ± ΔΣ is PSD | ||
| t = 0.001 * ChainRulesTestUtils.rand_tangent(d) | ||
| t.Σ .+= t.Σ' | ||
| if eigmin(t.Σ) < 0 | ||
| while eigmin(d.Σ + t.Σ) < 0 | ||
| t.Σ .*= 0.8 | ||
| end | ||
| end | ||
| if eigmax(t.Σ) > 0 | ||
| while eigmin(d.Σ - t.Σ) < 0 | ||
| t.Σ .*= 0.8 | ||
| end | ||
| end | ||
| # mvnormal_c0 | ||
| (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d) | ||
| y_r, c0_pullback = @inferred ChainRulesCore.rrule(Distributions.mvnormal_c0, d) | ||
| @test y_r ≈ y | ||
| y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ)) | ||
| @test unthunk(Δy) ≈ y2 - y atol= n * 1e-4 | ||
| y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ)) | ||
| @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 | ||
| (_, ∇c0) = c0_pullback(1.0) | ||
| ∇c0 = ChainRulesCore.unthunk(∇c0) | ||
| @test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4 | ||
| @test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4 | ||
| # sqmahal | ||
| x = randn(n) | ||
| Δx = 0.0001 * randn(n) | ||
| (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x) | ||
| (yr, sqmahal_pullback) = @inferred ChainRulesCore.rrule(sqmahal, d, x) | ||
| (_, ∇s_d, ∇s_x) = @inferred sqmahal_pullback(1.0) | ||
| ∇s_d = ChainRulesCore.unthunk(∇s_d) | ||
| ∇s_x = ChainRulesCore.unthunk(∇s_x) | ||
| @test yr ≈ y | ||
| y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) | ||
| y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) | ||
| @test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 | ||
| @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 | ||
| @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 | ||
| @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 | ||
| # _logpdf | ||
| (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), Distributions._logpdf, d, x) | ||
| (yr, logpdf_MvNormal_pullback) = @inferred ChainRulesCore.rrule(Distributions._logpdf, d, x) | ||
| @test y ≈ yr | ||
| # inference broken | ||
| # (_, ∇s_d, ∇s_x) = @inferred logpdf_MvNormal_pullback(1.0) | ||
| (_, ∇s_d, ∇s_x) = logpdf_MvNormal_pullback(1.0) | ||
|  | ||
| y2 = Distributions._logpdf(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) | ||
| y3 = Distributions._logpdf(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) | ||
| @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 | ||
| @test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 | ||
| @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 | ||
| @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 | ||
|          | ||
| end | ||
| end | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe move unrelated changes to a separate PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These were all relatively minor things (unused variables).