diff --git a/Project.toml b/Project.toml
index 48b74fc..4a52258 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,14 +11,18 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
 
 [compat]
-ChainRulesCore = "0.9.17, 0.10"
-Distributions = "0.23, 0.24"
+ChainRules = "1"
+ChainRulesCore = "1"
+ChainRulesTestUtils = "1"
+Distributions = "0.23, 0.24, 0.25"
 FiniteDifferences = "0.11, 0.12"
 PDMats = "0.9, 0.10, 0.11"
-Zygote = "0.5.5"
+Zygote = "0.6"
 julia = "1"
 
 [extras]
+ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -27,4 +31,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Distributions", "FiniteDifferences", "Random", "SuiteSparse", "Test", "Zygote"]
+test = ["ChainRules", "ChainRulesTestUtils", "Distributions", "FiniteDifferences", "Random", "SuiteSparse", "Test", "Zygote"]
diff --git a/src/PDMatsExtras.jl b/src/PDMatsExtras.jl
index 1809b07..df1ada0 100644
--- a/src/PDMatsExtras.jl
+++ b/src/PDMatsExtras.jl
@@ -12,5 +12,6 @@ export submat
 include("psd_mat.jl")
 include("woodbury_pd_mat.jl")
 include("utils.jl")
+include("chainrules.jl")
 
 end
diff --git a/src/chainrules.jl b/src/chainrules.jl
new file mode 100644
index 0000000..80eb08b
--- /dev/null
+++ b/src/chainrules.jl
@@ -0,0 +1,82 @@
+@non_differentiable validate_woodbury_arguments(A, D, S)
+
+# Rules for Woodbury * Real and Real * Woodbury.
+# The pullbacks below pass the cotangent of the A field through unchanged and
+# scale the cotangents of the D and S fields by the scalar.
+# Ignoring Complex version for now.
+function ChainRulesCore.rrule(::typeof(*), A::WoodburyPDMat, B::Real)
+    project_A = ProjectTo(A)
+    project_B = ProjectTo(B)
+    primal = A * B
+    times_pullback(ȳ) = _times_pullback(ȳ, primal, A, B, (;A=project_A, B=project_B))
+    return primal, times_pullback
+end
+
+function ChainRulesCore.rrule(::typeof(*), A::Real, B::WoodburyPDMat)
+    project_A = ProjectTo(A)
+    project_B = ProjectTo(B)
+    primal = A * B
+    times_pullback(ȳ) = _times_pullback(ȳ, primal, A, B, (;A=project_A, B=project_B))
+    return primal, times_pullback
+end
+
+# Normalise the cotangent type, then dispatch to the 4-argument methods below.
+_times_pullback(ȳ::AbstractThunk, primal, A, B, proj) = _times_pullback(unthunk(ȳ), primal, A, B, proj)
+# A zero cotangent is propagated unchanged to both factors.
+_times_pullback(ȳ::ChainRulesCore.AbstractZero, primal, A, B, proj) = (NoTangent(), ȳ, ȳ)
+# If the cotangent is a Matrix we first need to project down, otherwise ignore
+_times_pullback(Ȳ::AbstractMatrix, primal, A, B, proj) = _times_pullback(ProjectTo(primal)(Ȳ), A, B, proj)
+_times_pullback(ȳ::Tangent, primal, A, B, proj) = _times_pullback(ȳ, A, B, proj)
+
+# Pullback for Woodbury * Real with a structural cotangent.
+function _times_pullback(Ȳ::Tangent, A::T, B::Real, proj) where {T<:WoodburyPDMat}
+    Ā = @thunk proj.A(Tangent{WoodburyPDMat}(; A = Ȳ.A, D = Ȳ.D * B, S = Ȳ.S * B))
+    B̄ = @thunk proj.B(dot(Ȳ.D, A.D) + dot(Ȳ.S, A.S))
+    return (NoTangent(), Ā, B̄)
+end
+
+# Pullback for Real * Woodbury with a structural cotangent.
+function _times_pullback(Ȳ::Tangent, A::Real, B::T, proj) where {T<:WoodburyPDMat}
+    Ā = @thunk proj.A(dot(Ȳ.D, B.D) + dot(Ȳ.S, B.S))
+    B̄ = @thunk proj.B(Tangent{WoodburyPDMat}(; A = Ȳ.A, D = Ȳ.D * A, S = Ȳ.S * A))
+    return (NoTangent(), Ā, B̄)
+end
+
+# Composite pullbacks
+# The fields of the structural cotangent map one-to-one onto the constructor arguments.
+function ChainRulesCore.rrule(
+    ::Type{T},
+    A::AbstractMatrix,
+    D::Diagonal,
+    S::Diagonal,
+    ) where {T<:WoodburyPDMat}
+    return WoodburyPDMat(A, D, S), X̄ -> WoodburyPDMat_pullback(X̄, A, D, S)
+end
+WoodburyPDMat_pullback(X̄::Tangent, A, D, S) = (NoTangent(), X̄.A, X̄.D, X̄.S)
+WoodburyPDMat_pullback(X̄::AbstractThunk, A, D, S) = WoodburyPDMat_pullback(unthunk(X̄), A, D, S)
+# A zero cotangent is propagated unchanged to every constructor argument.
+WoodburyPDMat_pullback(X̄::ChainRulesCore.AbstractZero, A, D, S) = (NoTangent(), X̄, X̄, X̄)
+
+# Store the A, D and S fields on the projector so a dense cotangent can be projected later.
+function ChainRulesCore.ProjectTo(W::T) where {T<:WoodburyPDMat}
+    fields = (A = W.A, D = W.D, S = W.S)
+    return ChainRulesCore.ProjectTo{T}(; fields...)
+end
+
+#
+# Project the differential onto the Tangent{WoodburyPDMat}.
+# This essentially computes the pullbacks for the components of the Woodbury
+# i.e. from the definition: W = ADA' + S
+#     dW = dA D A' + A dD A' + A D dA' + dS
+#  => Ā = (W̄ + W̄')AD, D̄ = A'W̄A, S̄ = W̄
+# (Ā reduces to 2W̄AD when W̄ is symmetric.)
+# More precise formulation available e.g. here:
+# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
+function (project::ProjectTo{T})(X̄::AbstractMatrix) where {T<:WoodburyPDMat}
+    Ā = ProjectTo(project.A)((X̄ + X̄') * (project.A * project.D))
+    D̄ = ProjectTo(project.D)(project.A' * (X̄) * project.A)
+    S̄ = ProjectTo(project.S)(X̄)
+    return Tangent{WoodburyPDMat}(; A = Ā, D = D̄, S = S̄)
+end
+# A structural cotangent is already in the target form; return it unchanged.
+(project::ProjectTo{T})(W::Tangent) where {T<:WoodburyPDMat} = W
diff --git a/src/woodbury_pd_mat.jl b/src/woodbury_pd_mat.jl
index 7f2e36a..bc0b407 100644
--- a/src/woodbury_pd_mat.jl
+++ b/src/woodbury_pd_mat.jl
@@ -62,8 +62,6 @@ function validate_woodbury_arguments(A, D, S)
     end
 end
 
-@non_differentiable validate_woodbury_arguments(A, D, S)
-
 function LinearAlgebra.logdet(W::WoodburyPDMat)
     C_S = cholesky(W.S)
     B = C_S.U' \ (W.A * cholesky(W.D).U')
diff --git a/test/chainrules.jl b/test/chainrules.jl
new file mode 100644
index 0000000..fa93df3
--- /dev/null
+++ b/test/chainrules.jl
@@ -0,0 +1,53 @@
+@testset "ChainRules" begin
+
+    A = randn(4, 2)
+    D = Diagonal(randn(2).^2 .+ 1)
+    S = Diagonal(randn(4).^2 .+ 1)
+
+    W = WoodburyPDMat(A, D, S)
+    R = 2.0
+    Dmat = Diagonal(rand(4,))
+
+    x = randn(size(A, 1))
+
+    @testset "Constructors" begin
+        test_rrule(WoodburyPDMat, W.A, W.D, W.S)
+        # This is a gradient, should be able to deal with negative elements (does not have to be PSD like Woodbury itself)
+        test_rrule(WoodburyPDMat, W.A, W.D, W.S;
+            output_tangent=Tangent{WoodburyPDMat}(;
+            A = rand(4,2), D = Diagonal(-1 * rand(2,)), S = Diagonal(-1 * rand(4,)))
+        )
+    end
+
+    # The rrules already in ChainRules are sufficient for these to work. We just test an example here.
+    @testset "*(Matrix-Woodbury)" begin
+        test_rrule(*, Dmat, W)
+        test_rrule(*, W, Dmat)
+        test_rrule(*, rand(4,4), W)
+    end
+
+    @testset "*(Woodbury-Real)" begin
+        test_rrule(*, W, R)
+        test_rrule(*, R, W)
+
+        # We can't test test_rrule(*, R, W; output_tangent = rand(size(W)...)) i.e. with a Matrix because
+        # FD requires the primal and tangent to be the same size. However, we can just call FD directly and overload
+        # the primal computation to return a Matrix:
+        @testset "Matrix CoTangent" begin
+            res, pb = ChainRulesCore.rrule(*, R, W)
+            output_tangent = rand(size(W)...)
+            f_jvp = j′vp(ChainRulesTestUtils._fdm, x -> Matrix(*(x...)), output_tangent, (R, W))[1]
+            @test unthunk(pb(output_tangent)[3]).A ≈ f_jvp[2].A
+            @test unthunk(pb(output_tangent)[3]).D ≈ f_jvp[2].D
+            @test unthunk(pb(output_tangent)[3]).S ≈ f_jvp[2].S
+            @test unthunk(pb(output_tangent)[2]) ≈ f_jvp[1]
+        end
+    end
+
+    @testset "Zero cotangents" begin
+        _, pb_ctor = ChainRulesCore.rrule(WoodburyPDMat, W.A, W.D, W.S)
+        @test pb_ctor(ZeroTangent()) == (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent())
+        _, pb_times = ChainRulesCore.rrule(*, W, R)
+        @test pb_times(ZeroTangent()) == (NoTangent(), ZeroTangent(), ZeroTangent())
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 04fc6cc..f18a209 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,7 @@
 using PDMatsExtras
+using ChainRules
 using ChainRulesCore
+using ChainRulesTestUtils
 using Distributions
 using FiniteDifferences
 using LinearAlgebra
@@ -33,6 +35,7 @@ const TEST_MATRICES = Dict(
 
     include("test_ad.jl")
     include("psd_mat.jl")
+    include("chainrules.jl")
     include("woodbury_pd_mat.jl")
     include("utils.jl")
 end