Skip to content

Commit d2c7537

Browse files
author
Alex Robson
committed
Add and test rrule for Woodbury
1 parent 11955cd commit d2c7537

File tree

6 files changed

+130
-6
lines changed

6 files changed

+130
-6
lines changed

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
name = "PDMatsExtras"
22
uuid = "2c7acb1b-7338-470f-b38f-951d2bcb9193"
33
authors = ["Invenia Technical Computing"]
4-
version = "2.5.1"
4+
version = "2.5.2"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
9+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
810
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
911
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1012
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1113
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
14+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1215

1316
[compat]
14-
ChainRulesCore = "0.9.17, 0.10"
17+
ChainRulesCore = "1"
1518
Distributions = "0.23, 0.24"
1619
FiniteDifferences = "0.11, 0.12"
1720
PDMats = "0.9, 0.10, 0.11"
18-
Zygote = "0.5.5"
21+
Zygote = "0.4, 0.5, 0.6"
1922
julia = "1"
2023

2124
[extras]

src/PDMatsExtras.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ export submat
1212
include("psd_mat.jl")
1313
include("woodbury_pd_mat.jl")
1414
include("utils.jl")
15+
include("chainrules.jl")
1516

1617
end

src/chainrules.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
@non_differentiable validate_woodbury_arguments(A, D, S)
2+
3+
function ChainRulesCore.rrule(
4+
::typeof(*), A::Real, B::WoodburyPDMat{T, TA, TD, TS}
5+
) where {T, TA, TD, TS}
6+
project_A = ProjectTo(A)
7+
project_B = ProjectTo(B)
8+
9+
function times_pullback(ȳ::AbstractMatrix)
10+
Ȳ = unthunk(ȳ)
11+
Ā = dot(Ȳ, B)
12+
B̄ = A' * Ȳ
13+
return (
14+
NoTangent(),
15+
@thunk(project_A(Ā')),
16+
@thunk(project_B(B̄)),
17+
)
18+
end
19+
20+
function times_pullback(ȳ::Tangent{<:WoodburyPDMat})
21+
Ȳ = unthunk(ȳ)
22+
Ā = dot(Ȳ.A * Ȳ.D * Ȳ.A' + Ȳ.S, B)
23+
B̄ = Ȳ.A * (A' * Ȳ.D) * Ȳ.A' + A' * Ȳ.S
24+
return (
25+
NoTangent(),
26+
@thunk(project_A(Ā')),
27+
@thunk(project_B(B̄)),
28+
)
29+
end
30+
return A * B, times_pullback
31+
end
32+
33+
function ChainRulesCore.ProjectTo(W::WoodburyPDMat)
34+
function dW(W̄)
35+
Ā(W̄) = ProjectTo(W.A)(collect((W.D * W.A' * W̄' + W.D * W.A' * W̄)'))
36+
D̄(W̄) = ProjectTo(W.D)(W.A' * (W̄) * W.A)
37+
S̄(W̄) = ProjectTo(W.S)(W̄)
38+
return Tangent{typeof(W)}(; A = Ā(W̄), D = D̄(W̄), S = S̄(W̄))
39+
end
40+
return dW
41+
end
42+
43+
44+
45+
46+

src/woodbury_pd_mat.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ function validate_woodbury_arguments(A, D, S)
6262
end
6363
end
6464

65-
@non_differentiable validate_woodbury_arguments(A, D, S)
66-
6765
function LinearAlgebra.logdet(W::WoodburyPDMat)
6866
C_S = cholesky(W.S)
6967
B = C_S.U' \ (W.A * cholesky(W.D).U')
@@ -90,4 +88,6 @@ end
9088
# NOTE: the parameterisation to scale up the Woodbury matrix is not unique. Here we
9189
# implement one way to scale it.
9290
*(a::WoodburyPDMat, c::Real) = WoodburyPDMat(a.A, a.D * c, a.S * c)
93-
*(c::Real, a::WoodburyPDMat) = a * c
91+
*(c::Real, a::WoodburyPDMat) = a * c
92+
*(c::Diagonal{T}, a::WoodburyPDMat) where {T<:Real} = c * Matrix(a)
93+
*(c1::Diagonal{T}, a::WoodburyPDMat, c2::Diagonal{T}) where {T} = c1 * Matrix(a) * c2

test/chainrules.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Create a struct to hold the fields of the Woodbury.
2+
# We do this because the Woodbury cannot represent its own tangent.
3+
# I.e. for Y = f(W,...), W̄ = ∂Y / ∂W, is not necessarily a valid Woodbury.
4+
# Consider, e.g. the case of a Positive Diagonal matrix
5+
struct WoodburyLike
6+
A
7+
D
8+
S
9+
end
10+
11+
# Overwrite the generic to_vec and replace it with the almost identical Woodbury-specific version.
12+
# This means that in FiniteDifferences, the WoodburyLike matrix is created instead of the Woodbury.
13+
# Because the construction is forced there, this would bypass the validation checks on the constructor.
14+
function FiniteDifferences.to_vec(x::T) where {T<:WoodburyPDMat}
15+
val_vecs_and_backs = map(name -> to_vec(getfield(x, name)), fieldnames(T))
16+
vals = first.(val_vecs_and_backs)
17+
backs = last.(val_vecs_and_backs)
18+
19+
v, vals_from_vec = to_vec(vals)
20+
function structtype_from_vec(v::Vector{<:Real})
21+
val_vecs = vals_from_vec(v)
22+
values = map((b, v) -> b(v), backs, val_vecs)
23+
WoodburyLike(values...)
24+
end
25+
return v, structtype_from_vec
26+
end
27+
28+
# Assign some algebra for the WoodburyLike.
29+
WoodburyPDMat(S::WoodburyLike) = WoodburyPDMat(S.A, S.D, S.S)
30+
Base.:*(A::AbstractVecOrMat, B::WoodburyLike) = A * WoodburyPDMat(B)
31+
Base.:*(A::WoodburyLike, B::AbstractVecOrMat) = WoodburyPDMat(A) * B
32+
Base.:*(A::Real, B::WoodburyLike) = A * WoodburyPDMat(B)
33+
Base.:*(A::WoodburyLike, B::Real) = WoodburyPDMat(A) * B
34+
LinearAlgebra.dot(A, B::WoodburyLike) = dot(A, WoodburyPDMat(B))
35+
36+
@testset "ChainRules" begin
37+
38+
W = WoodburyPDMat(rand(4,2), Diagonal(rand(2,)), Diagonal(rand(4,)))
39+
R = 2.0
40+
D = Diagonal(rand(4,))
41+
42+
@testset "*(Matrix-Woodbury)" begin
43+
test_rrule(*, D, W)
44+
test_rrule(*, W, D)
45+
test_rrule(*, rand(4,4), W)
46+
end
47+
48+
@testset "*(Real-Woodbury" begin
49+
@testset "Matrix Tangent" begin
50+
###
51+
52+
primal = R * W
53+
54+
# Matrix Tangent
55+
T = rand_tangent(Matrix(primal))
56+
f_jvp = j′vp(ChainRulesTestUtils._fdm, x -> Matrix(*(x...)), T, (R, W))[1]
57+
R̄ = dot(T, W')
58+
W̄ = conj(R) * T
59+
60+
@test R̄ ≈ f_jvp[1]
61+
@test (W.D * W.A' * W̄' + W.D * W.A' * W̄) ≈ f_jvp[2].A' # A transpose.
62+
@test Diagonal(W.A' * (W̄) * W.A) ≈ f_jvp[2].D # D
63+
@test Diagonal(W̄) ≈ f_jvp[2].S # S
64+
65+
# Cannot get this to work. Here the T will be
66+
# T = rand_tangent(primal::WoodburyPDMat) which breaks.
67+
# test_rrule(*, 5.0, W)
68+
69+
end
70+
end
71+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
using Test
12
using PDMatsExtras
23
using ChainRulesCore
4+
using ChainRulesTestUtils
35
using Distributions
46
using FiniteDifferences
57
using LinearAlgebra
@@ -33,6 +35,7 @@ const TEST_MATRICES = Dict(
3335
include("test_ad.jl")
3436

3537
include("psd_mat.jl")
38+
include("chainrules.jl")
3639
include("woodbury_pd_mat.jl")
3740
include("utils.jl")
3841
end

0 commit comments

Comments
 (0)