Skip to content

Commit e0b165a

Browse files
sethaxen and oxinabox authored
Rules for LU decomposition of StridedMatrixes (#354)
* Add frule for lu decomposition * Eliminate allocation * Add rrule for lu * Test combinations of Zero and non-Zero * Test check=false is passed * Test getproperty LU * Add rules for inverse of LU * Increment version number * Avoid ops new to 1.6 * Efficiency improvements * Project tangents before use * Add link to blog post * Apply suggestions from code review Co-authored-by: Lyndon White <[email protected]> * Use check_equal throughout * Avoid reusing variable name * Refactor to use cotangent of `factor` * Add to_vec for LU * Use frule_test and rrule_test * Add additional comment * Don't declare variable name again * Correctly standardize factor cotangent * Don't reuse type name * Obviate explanation * Don't re-allocate ∂A Co-authored-by: Lyndon White <[email protected]>
1 parent de9e0a2 commit e0b165a

File tree

3 files changed

+288
-1
lines changed

3 files changed

+288
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.49"
3+
version = "0.7.50"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/factorization.jl

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,204 @@
11
using LinearAlgebra: checksquare
22
using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!
33

4+
#####
5+
##### `lu`
6+
#####
7+
8+
# These rules are necessary because the primals call LAPACK functions
9+
10+
# frule for square matrix was introduced in Eq. 3.6 of
11+
# de Hoog, F.R., Anderssen, R.S. and Lukas, M.A. (2011)
12+
# Differentiation of matrix functionals using triangular factorization.
13+
# Mathematics of Computation, 80 (275). p. 1585.
14+
# doi: http://doi.org/10.1090/S0025-5718-2011-02451-8
15+
# for derivations for wide and tall matrices, see
16+
# https://sethaxen.com/blog/2021/02/differentiating-the-lu-decomposition/
17+
18+
# Forward-mode rule for the in-place LU factorization `lu!`.
#
# Runs the primal `lu!(A, pivot; kwargs...)` (keyword arguments are forwarded
# unchanged) and pushes the input tangent `ΔA` forward to a tangent for the `factors`
# field of the resulting factorization `F`.
#
# Returns `(F, ∂F)`, where `∂F` is a `Composite` that carries only `factors`.
#
# NOTE: with `pivot === Val(false)`, `∂factors` is `ΔA` itself, so the in-place
# operations below overwrite the caller's `ΔA`. With `pivot === Val(true)`,
# `ΔA[F.p, :]` is a row-permuted copy and `ΔA` is left untouched.
function frule(
    (_, ΔA), ::typeof(lu!), A::StridedMatrix, pivot::Union{Val{false},Val{true}}; kwargs...
)
    F = lu!(A, pivot; kwargs...)
    # reorder the rows of the tangent with the permutation of the factorization (P * ΔA)
    ∂factors = pivot === Val(true) ? ΔA[F.p, :] : ΔA
    m, n = size(∂factors)
    q = min(m, n)
    if m == n # square A
        # minimal allocation computation of
        # ∂L = L * tril(L \ (P * ΔA) / U, -1)
        # ∂U = triu(L \ (P * ΔA) / U) * U
        # ∂factors = ∂L + ∂U
        L = UnitLowerTriangular(F.factors)
        U = UpperTriangular(F.factors)
        # ∂factors ← L \ (∂factors / U), both solves in place
        rdiv!(∂factors, U)
        ldiv!(L, ∂factors)
        # `tril`/`triu` (no `!`) copy, so ∂factors stays intact until the final write
        ∂L = lmul!(L, tril(∂factors, -1))
        ∂U = rmul!(triu(∂factors), U)
        ∂factors .= ∂L .+ ∂U
    elseif m < n # wide A, system is [P*A1 P*A2] = [L*U1 L*U2]
        L = UnitLowerTriangular(F.L)
        U = F.U
        # ∂factors ← L \ ∂factors
        ldiv!(L, ∂factors)
        # split into the leading q columns and the remaining columns (views, no copies)
        @views begin
            ∂factors1 = ∂factors[:, 1:q]
            ∂factors2 = ∂factors[:, (q + 1):end]
            U1 = UpperTriangular(U[:, 1:q])
            U2 = U[:, (q + 1):end]
        end
        # ∂factors1 ← ∂factors1 / U1
        rdiv!(∂factors1, U1)
        # strictly lower part, copied so ∂factors1 can be overwritten below
        ∂L = tril(∂factors1, -1)
        # ∂factors2 ← ∂factors2 - ∂L * U2
        mul!(∂factors2, ∂L, U2, -1, 1)
        # ∂L ← L * ∂L
        lmul!(L, ∂L)
        # ∂factors1 ← triu(∂factors1) * U1, then add the strictly-lower contribution
        rmul!(triu!(∂factors1), U1)
        ∂factors1 .+= ∂L
    else # tall A, system is [P1*A; P2*A] = [L1*U; L2*U]
        L = F.L
        U = UpperTriangular(F.U)
        # ∂factors ← ∂factors / U
        rdiv!(∂factors, U)
        # split into the leading q rows and the remaining rows (views, no copies)
        @views begin
            ∂factors1 = ∂factors[1:q, :]
            ∂factors2 = ∂factors[(q + 1):end, :]
            L1 = UnitLowerTriangular(L[1:q, :])
            L2 = L[(q + 1):end, :]
        end
        # ∂factors1 ← L1 \ ∂factors1
        ldiv!(L1, ∂factors1)
        # upper part (with diagonal), copied so ∂factors1 can be overwritten below
        ∂U = triu(∂factors1)
        # ∂factors2 ← ∂factors2 - L2 * ∂U
        mul!(∂factors2, L2, ∂U, -1, 1)
        # ∂U ← ∂U * U
        rmul!(∂U, U)
        # ∂factors1 ← L1 * tril(∂factors1, -1), then add the upper contribution
        lmul!(L1, tril!(∂factors1, -1))
        ∂factors1 .+= ∂U
    end
    ∂F = Composite{typeof(F)}(; factors=∂factors)
    return F, ∂F
end
73+
74+
# Reverse-mode rule for the LU factorization `lu`.
#
# Computes the primal `F = lu(A, pivot; kwargs...)` (keyword arguments are forwarded
# unchanged) and returns `(F, lu_pullback)`.
#
# `lu_pullback` accepts a `Composite` cotangent of `F` and reads only its `factors`
# field. It returns `(NO_FIELDS, ∂A, DoesNotExist())`: the cotangent of `A` and no
# cotangent for `pivot`.
function rrule(
    ::typeof(lu), A::StridedMatrix, pivot::Union{Val{false},Val{true}}; kwargs...
)
    F = lu(A, pivot; kwargs...)
    function lu_pullback(ΔF::Composite)
        Δfactors = ΔF.factors
        # a zero cotangent for `factors` is passed straight through as the cotangent of A
        Δfactors isa AbstractZero && return (NO_FIELDS, Δfactors, DoesNotExist())
        factors = F.factors
        # keep only the real part of the cotangent when A has real elements
        ∂factors = eltype(A) <: Real ? real(Δfactors) : Δfactors
        # output buffer; every branch below fills it in place
        ∂A = similar(factors)
        m, n = size(A)
        q = min(m, n)
        if m == n # square A
            # ∂A = P' * (L' \ (tril(L' * ∂L, -1) + triu(∂U * U')) / U')
            L = UnitLowerTriangular(factors)
            U = UpperTriangular(factors)
            ∂U = UpperTriangular(∂factors)
            # ∂A ← strictly lower part of ∂factors, then ∂A ← L' * ∂A
            tril!(copyto!(∂A, ∂factors), -1)
            lmul!(L', ∂A)
            # replace the upper triangle (with diagonal) of ∂A by that of ∂U * U'
            copyto!(UpperTriangular(∂A), UpperTriangular(∂U * U'))
            # ∂A ← L' \ (∂A / U')
            rdiv!(∂A, U')
            ldiv!(L', ∂A)
        elseif m < n # wide A, system is [P*A1 P*A2] = [L*U1 L*U2]
            # ∂A ← upper part (with diagonal) of ∂factors
            triu!(copyto!(∂A, ∂factors))
            # leading q columns vs. remaining columns; all slices here are views
            @views begin
                factors1 = factors[:, 1:q]
                U2 = factors[:, (q + 1):end]
                ∂A1 = ∂A[:, 1:q]
                ∂A2 = ∂A[:, (q + 1):end]
                # `tril` (no `!`) copies, so ∂factors itself is not modified
                ∂L = tril(∂factors[:, 1:q], -1)
            end
            L = UnitLowerTriangular(factors1)
            U1 = UpperTriangular(factors1)
            # ∂A1 ← triu(∂A1 * U1')
            triu!(rmul!(∂A1, U1'))
            # ∂L ← tril(L' * ∂L - ∂A2 * U2', -1), accumulated into ∂A1
            ∂A1 .+= tril!(mul!(lmul!(L', ∂L), ∂A2, U2', -1, 1), -1)
            # ∂A1 ← ∂A1 / U1', then the whole buffer ∂A ← L' \ ∂A
            rdiv!(∂A1, U1')
            ldiv!(L', ∂A)
        else # tall A, system is [P1*A; P2*A] = [L1*U; L2*U]
            # ∂A ← strictly lower part of ∂factors
            tril!(copyto!(∂A, ∂factors), -1)
            # leading q rows vs. remaining rows; all slices here are views
            @views begin
                factors1 = factors[1:q, :]
                L2 = factors[(q + 1):end, :]
                ∂A1 = ∂A[1:q, :]
                ∂A2 = ∂A[(q + 1):end, :]
                # `triu` (no `!`) copies, so ∂factors itself is not modified
                ∂U = triu(∂factors[1:q, :])
            end
            U = UpperTriangular(factors1)
            L1 = UnitLowerTriangular(factors1)
            # ∂A1 ← tril(L1' * ∂A1, -1)
            tril!(lmul!(L1', ∂A1), -1)
            # ∂U ← triu(∂U * U' - L2' * ∂A2), accumulated into ∂A1
            ∂A1 .+= triu!(mul!(rmul!(∂U, U'), L2', ∂A2, -1, 1))
            # ∂A1 ← L1' \ ∂A1, then the whole buffer ∂A ← ∂A / U'
            ldiv!(L1', ∂A1)
            rdiv!(∂A, U')
        end
        if pivot === Val(true)
            # undo the row permutation of the factorization (P' * ∂A)
            ∂A = ∂A[invperm(F.p), :]
        end
        return NO_FIELDS, ∂A, DoesNotExist()
    end
    return F, lu_pullback
end
134+
135+
#####
136+
##### functions of `LU`
137+
#####
138+
139+
# this rrule is necessary because the primal mutates
140+
141+
# Reverse-mode rule for `getproperty` on an `LU` with strided factors.
#
# The pullback maps a cotangent of `F.L`, `F.U` or `F.factors` to a `Composite`
# cotangent of `F` whose only field is `factors`. Any other property name yields
# `DoesNotExist()` for `F`. The property name itself never receives a cotangent.
function rrule(::typeof(getproperty), F::TF, x::Symbol) where {T,TF<:LU{T,<:StridedMatrix{T}}}
    function getproperty_LU_pullback(ΔY)
        # properties other than :L, :U and :factors contribute nothing to `factors`
        if !(x === :L || x === :U || x === :factors)
            return (NO_FIELDS, DoesNotExist(), DoesNotExist())
        end
        nrows, ncols = size(F.factors)
        ∂factors = if x === :L
            # widen ΔY with zero columns up to the width of `factors`, then keep the
            # strictly lower triangle only
            pad = zeros(eltype(ΔY), nrows, max(0, ncols - nrows))
            tril!(hcat(ΔY, pad), -1)
        elseif x === :U
            # lengthen ΔY with zero rows up to the height of `factors`, then keep the
            # upper triangle (with diagonal) only
            pad = zeros(eltype(ΔY), max(0, nrows - ncols), ncols)
            triu!(vcat(ΔY, pad))
        else
            Matrix(ΔY)
        end
        return NO_FIELDS, Composite{TF}(; factors=∂factors), DoesNotExist()
    end
    primal = getproperty(F, x)
    return primal, getproperty_LU_pullback
end
161+
162+
# these rules are needed because the primal calls a LAPACK function
163+
164+
# Forward-mode rule for `LinearAlgebra.inv!(F::LU)`.
#
# `ΔF` is the tangent of `F`; only its `factors` field is read, and it is not
# modified (`tril`/`triu` below return new arrays).
#
# Returns `(LinearAlgebra.inv!(F), ∂Y)`, where `∂Y` is the tangent of the inverse.
function frule((_, ΔF), ::typeof(LinearAlgebra.inv!), F::LU{<:Any,<:StridedMatrix})
    # factors must be square if the primal did not error
    L = UnitLowerTriangular(F.factors)
    U = UpperTriangular(F.factors)
    # compute ∂Y = -(U \ (L \ ∂L + ∂U / U) / L) * P while minimizing allocations
    m, n = size(F.factors)
    q = min(m, n)
    # The L part spans the first q columns of `factors` and the U part its first q
    # rows, so the full tangent is used for ∂L when m >= n and for ∂U when m <= n.
    ∂L = tril(m >= n ? ΔF.factors : view(ΔF.factors, :, 1:q), -1)
    ∂U = triu(m <= n ? ΔF.factors : view(ΔF.factors, 1:q, :))
    # ∂Y aliases ∂L; everything from here on updates that buffer in place
    ∂Y = ldiv!(L, ∂L)
    ∂Y .+= rdiv!(∂U, U)
    ldiv!(U, ∂Y)
    rdiv!(∂Y, L)
    rmul!(∂Y, -1)
    # The primal is evaluated last, after every use of `L` and `U` (which wrap
    # `F.factors`) — presumably because `inv!` overwrites `F.factors`; TODO confirm.
    # Indexing with `invperm(F.p)` applies the column permutation (`* P`).
    return LinearAlgebra.inv!(F), ∂Y[:, invperm(F.p)]
end
180+
181+
# Reverse-mode rule for `inv(F::LU)`.
#
# The pullback maps a cotangent `ΔY` of the inverse to a `Composite` cotangent of `F`
# whose only field is `factors`:
#   ∂U = - triu((U' \ ∂Y * P' / L') / U')
#   ∂L = - tril(L' \ (U' \ ∂Y * P' / L'), -1)
#   ∂factors = ∂L + ∂U
function rrule(::typeof(inv), F::LU{<:Any,<:StridedMatrix})
    function inv_LU_pullback(ΔY)
        # factors must be square if the primal did not error
        Ladj = UnitLowerTriangular(F.factors)'
        Uadj = UpperTriangular(F.factors)'
        # work buffer: a column-permuted copy of ΔY, updated in place below
        G = ΔY[:, F.p]
        ldiv!(Uadj, G)
        rdiv!(G, Ladj)
        rmul!(G, -1)
        # `\` allocates, so the strictly-lower contribution does not alias G
        strict_lower = tril!(Ladj \ G, -1)
        rdiv!(G, Uadj)
        triu!(G)
        G .+= strict_lower
        return NO_FIELDS, Composite{typeof(F)}(; factors=G)
    end
    return inv(F), inv_LU_pullback
end
201+
4202
#####
5203
##### `svd`
6204
#####

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
# TODO: move this to FiniteDifferences
2+
# Flatten an `LU` to a vector through its `factors` matrix. The returned closure
# rebuilds an `LU` from such a vector, reusing `X.ipiv` and `X.info` unchanged.
function FiniteDifferences.to_vec(X::LU)
    factors_vec, factors_from_vec = to_vec(Matrix(X.factors))
    LU_from_vec(v) = LU(factors_from_vec(v), X.ipiv, X.info)
    return factors_vec, LU_from_vec
end
9+
110
function FiniteDifferences.to_vec(C::Cholesky)
211
C_vec, factors_from_vec = to_vec(C.factors)
312
function cholesky_from_vec(v)
@@ -12,6 +21,86 @@ function FiniteDifferences.to_vec(x::Val)
1221
end
1322

1423
@testset "Factorizations" begin
24+
@testset "lu decomposition" begin
25+
n = 10
26+
@testset "lu! frule" begin
    # real and complex inputs, with and without pivoting, for three row counts
    @testset "lu!(A::Matrix{$T}, $pivot) for size(A)=($m, $n)" for
        T in (Float64, ComplexF64),
        pivot in (Val(true), Val(false)),
        m in (7, 10, 13)

        primal = randn(T, m, n)
        frule_test(lu!, (primal, rand_tangent(primal)), (pivot, nothing))
    end
    @testset "check=false passed to primal function" begin
        singular = zeros(n, n)
        tangent = rand_tangent(singular)
        # without `check=false` the all-zero matrix must raise
        @test_throws SingularException frule(
            (Zero(), copy(tangent)), lu!, copy(singular), Val(true)
        )
        # with `check=false` the same call only has to complete
        frule((Zero(), tangent), lu!, singular, Val(true); check=false)
    end
end
45+
@testset "lu rrule" begin
    # real and complex inputs, with and without pivoting, for three row counts
    @testset "lu(A::Matrix{$T}, $pivot) for size(A)=($m, $n)" for
        T in (Float64, ComplexF64),
        pivot in (Val(true), Val(false)),
        m in (7, 10, 13)

        primal = randn(T, m, n)
        primal_tangent = rand_tangent(primal)
        fact = lu(primal, pivot)
        # the output cotangent carries only `factors`
        fact_cotangent = Composite{typeof(fact)}(; factors=rand_tangent(fact.factors))
        rrule_test(lu, fact_cotangent, (primal, primal_tangent), (pivot, nothing))
    end
    @testset "check=false passed to primal function" begin
        singular = zeros(n, n)
        fact = lu(singular, Val(true); check=false)
        fact_cotangent = Composite{typeof(fact)}(;
            U=rand_tangent(fact.U), L=rand_tangent(fact.L)
        )
        # without `check=false` the all-zero matrix must raise
        @test_throws SingularException rrule(lu, singular, Val(true))
        # with `check=false` both the rule and its pullback only have to complete
        _, pullback = rrule(lu, singular, Val(true); check=false)
        pullback(fact_cotangent)
    end
end
67+
@testset "LU" begin
    @testset "getproperty(::LU, k) rrule" begin
        # The tangent supplied for the factorization carries only `factors`.
        # `check_inferred=false` turns off the inference check for this rule.
        @testset "getproperty(lu(A::Matrix), :$k) for size(A)=($m, $n)" for
            k in (:U, :L, :factors),
            m in (7, 10, 13)

            fact = lu(randn(m, n))
            prop = getproperty(fact, k)
            fact_tangent = Composite{typeof(fact)}(; factors=rand_tangent(fact.factors))
            prop_cotangent = rand_tangent(prop)
            rrule_test(
                getproperty,
                prop_cotangent,
                (fact, fact_tangent),
                (k, nothing);
                check_inferred=false,
            )
        end
    end
    @testset "matrix inverse using LU" begin
        @testset "LinearAlgebra.inv!(::LU) frule" begin
            @testset "inv!(lu(::LU{$T,<:StridedMatrix}))" for T in (Float64,ComplexF64)
                fact = lu(randn(T, n, n), Val(true))
                fact_tangent = Composite{typeof(fact)}(;
                    factors=rand_tangent(fact.factors)
                )
                frule_test(LinearAlgebra.inv!, (fact, fact_tangent))
            end
        end
        @testset "inv(::LU) rrule" begin
            @testset "inv(::LU{$T,<:StridedMatrix})" for T in (Float64,ComplexF64)
                mat = randn(T, n, n)
                fact = lu(mat, Val(true))
                # the dense inverse is only used to shape the output cotangent
                inverse = inv(mat)
                fact_tangent = Composite{typeof(fact)}(;
                    factors=rand_tangent(fact.factors)
                )
                rrule_test(inv, rand_tangent(inverse), (fact, fact_tangent))
            end
        end
    end
end
103+
end
15104
@testset "svd" begin
16105
for n in [4, 6, 10], m in [3, 5, 10]
17106
X = randn(n, m)

0 commit comments

Comments
 (0)