Skip to content

Commit 07fb93c

Browse files
Jutholkdvos
andauthored
Pullback rules for truncation methods (#53)
* prepare pullback rules for truncation * some updates * Fix small typos * fix svd_trunc_pullback and add tests * fix and test eig pullback, including trunc alternative * fix rank edge case * default value of `ind` is untruncated * update docstrings * remove some allocations * mark pullbacks as public * Bump v0.4.1 [skip ci] --------- Co-authored-by: Lukas Devos <[email protected]>
1 parent 6c0b297 commit 07fb93c

File tree

10 files changed

+616
-133
lines changed

10 files changed

+616
-133
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MatrixAlgebraKit"
22
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
33
authors = ["Jutho <[email protected]> and contributors"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module MatrixAlgebraKitChainRulesCoreExt
22

33
using MatrixAlgebraKit
4-
using MatrixAlgebraKit: copy_input, TruncatedAlgorithm, zero!
4+
using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview,
5+
TruncatedAlgorithm, findtruncated, findtruncated_svd
56
using ChainRulesCore
67
using LinearAlgebra
78

@@ -24,7 +25,7 @@ for qr_f in (:qr_compact, :qr_full)
2425
QR = $(qr_f!)(Ac, QR, alg)
2526
function qr_pullback(ΔQR)
2627
ΔA = zero(A)
27-
MatrixAlgebraKit.qr_compact_pullback!(ΔA, QR, unthunk.(ΔQR))
28+
MatrixAlgebraKit.qr_compact_pullback!(ΔA, A, QR, unthunk.(ΔQR))
2829
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
2930
end
3031
function qr_pullback(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
@@ -36,7 +37,7 @@ for qr_f in (:qr_compact, :qr_full)
3637
end
3738
function ChainRulesCore.rrule(::typeof(qr_null!), A::AbstractMatrix, N, alg)
3839
Ac = copy_input(qr_full, A)
39-
QR = MatrixAlgebraKit.initialize_output(qr_full!, A, alg)
40+
QR = initialize_output(qr_full!, A, alg)
4041
Q, R = qr_full!(Ac, QR, alg)
4142
N = copy!(N, view(Q, 1:size(A, 1), (size(A, 2) + 1):size(A, 1)))
4243
function qr_null_pullback(ΔN)
@@ -45,7 +46,7 @@ function ChainRulesCore.rrule(::typeof(qr_null!), A::AbstractMatrix, N, alg)
4546
minmn = min(m, n)
4647
ΔQ = zero!(similar(A, (m, m)))
4748
view(ΔQ, 1:m, (minmn + 1):m) .= unthunk.(ΔN)
48-
MatrixAlgebraKit.qr_compact_pullback!(ΔA, (Q, R), (ΔQ, ZeroTangent()))
49+
MatrixAlgebraKit.qr_compact_pullback!(ΔA, A, (Q, R), (ΔQ, ZeroTangent()))
4950
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
5051
end
5152
function qr_null_pullback(::ZeroTangent) # is this extra definition useful?
@@ -62,7 +63,7 @@ for lq_f in (:lq_compact, :lq_full)
6263
LQ = $(lq_f!)(Ac, LQ, alg)
6364
function lq_pullback(ΔLQ)
6465
ΔA = zero(A)
65-
MatrixAlgebraKit.lq_compact_pullback!(ΔA, LQ, unthunk.(ΔLQ))
66+
MatrixAlgebraKit.lq_compact_pullback!(ΔA, A, LQ, unthunk.(ΔLQ))
6667
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
6768
end
6869
function lq_pullback(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
@@ -74,7 +75,7 @@ for lq_f in (:lq_compact, :lq_full)
7475
end
7576
function ChainRulesCore.rrule(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, alg)
7677
Ac = copy_input(lq_full, A)
77-
LQ = MatrixAlgebraKit.initialize_output(lq_full!, A, alg)
78+
LQ = initialize_output(lq_full!, A, alg)
7879
L, Q = lq_full!(Ac, LQ, alg)
7980
Nᴴ = copy!(Nᴴ, view(Q, (size(A, 1) + 1):size(A, 2), 1:size(A, 2)))
8081
function lq_null_pullback(ΔNᴴ)
@@ -83,7 +84,7 @@ function ChainRulesCore.rrule(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, alg)
8384
minmn = min(m, n)
8485
ΔQ = zero!(similar(A, (n, n)))
8586
view(ΔQ, (minmn + 1):n, 1:n) .= unthunk.(ΔNᴴ)
86-
MatrixAlgebraKit.lq_compact_pullback!(ΔA, (L, Q), (ZeroTangent(), ΔQ))
87+
MatrixAlgebraKit.lq_compact_pullback!(ΔA, A, (L, Q), (ZeroTangent(), ΔQ))
8788
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
8889
end
8990
function lq_null_pullback(::ZeroTangent) # is this extra definition useful?
@@ -95,22 +96,46 @@ end
9596
for eig in (:eig, :eigh)
9697
eig_f = Symbol(eig, "_full")
9798
eig_f! = Symbol(eig_f, "!")
98-
eig_f_pb! = Symbol(eig, "_full_pullback!")
99+
eig_pb! = Symbol(eig, "_pullback!")
99100
eig_pb = Symbol(eig, "_pullback")
101+
eig_t! = Symbol(eig, "_trunc!")
102+
eig_t_pb = Symbol(eig, "_trunc_pullback")
103+
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
100104
@eval begin
101105
function ChainRulesCore.rrule(::typeof($eig_f!), A::AbstractMatrix, DV, alg)
102106
Ac = copy_input($eig_f, A)
103107
DV = $(eig_f!)(Ac, DV, alg)
104108
function $eig_pb(ΔDV)
105109
ΔA = zero(A)
106-
MatrixAlgebraKit.$eig_f_pb!(ΔA, DV, unthunk.(ΔDV))
110+
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.(ΔDV))
107111
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
108112
end
109113
function $eig_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
110114
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
111115
end
112116
return DV, $eig_pb
113117
end
118+
function ChainRulesCore.rrule(
119+
::typeof($eig_t!), A::AbstractMatrix, DV,
120+
alg::TruncatedAlgorithm
121+
)
122+
Ac = copy_input($eig_f, A)
123+
D, V = $(eig_f!)(Ac, DV, alg.alg)
124+
ind = findtruncated(diagview(D), alg.trunc)
125+
return (Diagonal(diagview(D)[ind]), V[:, ind]),
126+
$(_make_eig_t_pb)(A, (D, V), ind)
127+
end
128+
function $(_make_eig_t_pb)(A, DV, ind)
129+
function $eig_t_pb(ΔDV)
130+
ΔA = zero(A)
131+
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.(ΔDV), ind)
132+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
133+
end
134+
function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
135+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
136+
end
137+
return $eig_t_pb
138+
end
114139
end
115140
end
116141

@@ -122,7 +147,7 @@ for svd_f in (:svd_compact, :svd_full)
122147
USVᴴ = $(svd_f!)(Ac, USVᴴ, alg)
123148
function svd_pullback(ΔUSVᴴ)
124149
ΔA = zero(A)
125-
MatrixAlgebraKit.svd_compact_pullback!(ΔA, USVᴴ, unthunk.(ΔUSVᴴ))
150+
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ))
126151
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
127152
end
128153
function svd_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
@@ -134,27 +159,33 @@ for svd_f in (:svd_compact, :svd_full)
134159
end
135160

136161
function ChainRulesCore.rrule(
137-
::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm
162+
::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ,
163+
alg::TruncatedAlgorithm
138164
)
139-
Ac = MatrixAlgebraKit.copy_input(svd_compact, A)
140-
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
165+
Ac = copy_input(svd_compact, A)
166+
U, S, Vᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
167+
ind = findtruncated_svd(diagview(S), alg.trunc)
168+
return (U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]),
169+
_make_svd_trunc_pullback(A, (U, S, Vᴴ), ind)
170+
end
171+
function _make_svd_trunc_pullback(A, USVᴴ, ind)
141172
function svd_trunc_pullback(ΔUSVᴴ)
142173
ΔA = zero(A)
143-
MatrixAlgebraKit.svd_compact_pullback!(ΔA, USVᴴ, unthunk.(ΔUSVᴴ))
174+
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ), ind)
144175
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
145176
end
146177
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
147178
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
148179
end
149-
return MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, alg.trunc), svd_trunc_pullback
180+
return svd_trunc_pullback
150181
end
151182

152183
function ChainRulesCore.rrule(::typeof(left_polar!), A::AbstractMatrix, WP, alg)
153184
Ac = copy_input(left_polar, A)
154185
WP = left_polar!(Ac, WP, alg)
155186
function left_polar_pullback(ΔWP)
156187
ΔA = zero(A)
157-
MatrixAlgebraKit.left_polar_pullback!(ΔA, WP, unthunk.(ΔWP))
188+
MatrixAlgebraKit.left_polar_pullback!(ΔA, A, WP, unthunk.(ΔWP))
158189
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
159190
end
160191
function left_polar_pullback(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
@@ -168,7 +199,7 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ,
168199
PWᴴ = right_polar!(Ac, PWᴴ, alg)
169200
function right_polar_pullback(ΔPWᴴ)
170201
ΔA = zero(A)
171-
MatrixAlgebraKit.right_polar_pullback!(ΔA, PWᴴ, unthunk.(ΔPWᴴ))
202+
MatrixAlgebraKit.right_polar_pullback!(ΔA, A, PWᴴ, unthunk.(ΔPWᴴ))
172203
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
173204
end
174205
function right_polar_pullback(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?

src/MatrixAlgebraKit.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
5353
:TruncationByError, :TruncationIntersection
5454
)
5555
)
56+
eval(
57+
Expr(
58+
:public, :qr_compact_pullback!, :lq_compact_pullback!, :left_polar_pullback!,
59+
:right_polar_pullback!, :eig_pullback!, :eig_trunc_pullback!, :eigh_pullback!,
60+
:eigh_trunc_pullback!, :svd_pullback!, :svd_trunc_pullback!
61+
)
62+
)
5663
end
5764

5865
include("common/defaults.jl")

src/pullbacks/eig.jl

Lines changed: 124 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,30 @@
1-
function eig_full_pullback!(
2-
ΔA::AbstractMatrix, DV, ΔDV;
1+
"""
2+
eig_pullback!(
3+
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
4+
tol = default_pullback_gaugetol(DV[1]),
5+
degeneracy_atol = tol,
6+
gauge_atol = tol
7+
)
8+
9+
Adds the pullback from the full eigenvalue decomposition of `A` to `ΔA`, given the output
10+
`DV` of `eig_full` and the cotangent `ΔDV` of `eig_full` or `eig_trunc`.
11+
12+
In particular, it is assumed that `A ≈ V * D * inv(V)` with thus
13+
`size(A) == size(V) == size(D)` and `D` diagonal. For the cotangents, an arbitrary number of
14+
eigenvectors or eigenvalues can be missing, i.e. for a matrix `A` of size `(n, n)`, `ΔV` can
15+
have size `(n, pV)` and `diagview(ΔD)` can have length `pD`. In those cases, additionally
16+
`ind` is required to specify which eigenvectors or eigenvalues are present in `ΔV` or `ΔD`.
17+
By default, it is assumed that all eigenvectors and eigenvalues are present.
18+
19+
A warning will be printed if the cotangents are not gauge-invariant, i.e. if the restriction
20+
of `V' * ΔV` to rows `i` and columns `j` for which `abs(D[i] - D[j]) < degeneracy_atol`, is
21+
not small compared to `gauge_atol`.
22+
"""
23+
function eig_pullback!(
24+
ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon();
325
tol::Real = default_pullback_gaugetol(DV[1]),
4-
degeneracy_atol::Real = tol, gauge_atol::Real = tol
26+
degeneracy_atol::Real = tol,
27+
gauge_atol::Real = tol
528
)
629

730
# Basic size checks and determination
@@ -10,35 +33,125 @@ function eig_full_pullback!(
1033
ΔDmat, ΔV = ΔDV
1134
n = LinearAlgebra.checksquare(V)
1235
n == length(D) || throw(DimensionMismatch())
36+
(n, n) == size(ΔA) || throw(DimensionMismatch())
1337

1438
if !iszerotangent(ΔV)
15-
VdΔV = V' * ΔV
39+
n == size(ΔV, 1) || throw(DimensionMismatch())
40+
pV = size(ΔV, 2)
41+
VᴴΔV = fill!(similar(V), 0)
42+
indV = axes(V, 2)[ind]
43+
length(indV) == pV || throw(DimensionMismatch())
44+
mul!(view(VᴴΔV, :, indV), V', ΔV)
1645
1746
mask = abs.(transpose(D) .- D) .< degeneracy_atol
18-
Δgauge = norm(view(VdΔV, mask), Inf)
47+
Δgauge = norm(view(VᴴΔV, mask), Inf)
1948
Δgauge < gauge_atol ||
2049
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
2150
22-
VdΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
51+
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
2352
2453
if !iszerotangent(ΔDmat)
25-
diagview(VdΔV) .+= diagview(ΔDmat)
54+
ΔDvec = diagview(ΔDmat)
55+
pD = length(ΔDvec)
56+
indD = axes(D, 1)[ind]
57+
length(indD) == pD || throw(DimensionMismatch())
58+
view(diagview(VᴴΔV), indD) .+= ΔDvec
2659
end
27-
PΔV = V' \ VdΔV
60+
PΔV = V' \ VᴴΔV
2861
if eltype(ΔA) <: Real
29-
ΔAc = mul!(VdΔV, PΔV, V') # recycle VdΔV memory
62+
ΔAc = mul!(VᴴΔV, PΔV, V') # recycle VdΔV memory
3063
ΔA .+= real.(ΔAc)
3164
else
3265
ΔA = mul!(ΔA, PΔV, V', 1, 1)
3366
end
3467
elseif !iszerotangent(ΔDmat)
35-
PΔV = V' \ Diagonal(diagview(ΔDmat))
68+
ΔDvec = diagview(ΔDmat)
69+
pD = length(ΔDvec)
70+
indD = axes(D, 1)[ind]
71+
length(indD) == pD || throw(DimensionMismatch())
72+
Vp = view(V, :, indD)
73+
PΔV = Vp' \ Diagonal(ΔDvec)
3674
if eltype(ΔA) <: Real
37-
ΔAc = PΔV * V'
75+
ΔAc = PΔV * Vp'
3876
ΔA .+= real.(ΔAc)
3977
else
4078
ΔA = mul!(ΔA, PΔV, V', 1, 1)
4179
end
4280
end
4381
return ΔA
4482
end
83+
84+
"""
85+
eig_trunc_pullback!(
86+
ΔA::AbstractMatrix, ΔDV, A, DV;
87+
tol = default_pullback_gaugetol(DV[1]),
88+
degeneracy_atol = tol,
89+
gauge_atol = tol
90+
)
91+
92+
Adds the pullback from the truncated eigenvalue decomposition of `A` to `ΔA`, given the
93+
output `DV` and the cotangent `ΔDV` of `eig_trunc`.
94+
95+
In particular, it is assumed that `A * V ≈ V * D` with `V` a rectangular matrix of
96+
eigenvectors and `D` diagonal. For the cotangents, it is assumed that if `ΔV` is not zero,
97+
then it has the same number of columns as `V`, and if `ΔD` is not zero, then it is a
98+
diagonal matrix of the same size as `D`.
99+
100+
For this method to work correctly, it is also assumed that the remaining eigenvalues
101+
(not included in `D`) are (sufficiently) separated from those in `D`.
102+
103+
A warning will be printed if the cotangents are not gauge-invariant, i.e. if the restriction
104+
of `V' * ΔV` to rows `i` and columns `j` for which `abs(D[i] - D[j]) < degeneracy_atol`, is
105+
not small compared to `gauge_atol`.
106+
"""
107+
function eig_trunc_pullback!(
108+
ΔA::AbstractMatrix, A, DV, ΔDV;
109+
tol::Real = default_pullback_gaugetol(DV[1]),
110+
degeneracy_atol::Real = tol,
111+
gauge_atol::Real = tol
112+
)
113+
114+
# Basic size checks and determination
115+
Dmat, V = DV
116+
D = diagview(Dmat)
117+
ΔDmat, ΔV = ΔDV
118+
(n, p) = size(V)
119+
p == length(D) || throw(DimensionMismatch())
120+
(n, n) == size(ΔA) || throw(DimensionMismatch())
121+
G = V' * V
122+
123+
if !iszerotangent(ΔV)
124+
(n, p) == size(ΔV) || throw(DimensionMismatch())
125+
VᴴΔV = V' * ΔV
126+
mask = abs.(transpose(D) .- D) .< degeneracy_atol
127+
Δgauge = norm(view(VᴴΔV, mask), Inf)
128+
Δgauge < gauge_atol ||
129+
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
130+
131+
ΔVperp = ΔV - V * inv(G) * VᴴΔV
132+
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
133+
else
134+
VᴴΔV = zero(G)
135+
end
136+
137+
if !iszerotangent(ΔDmat)
138+
ΔDvec = diagview(ΔDmat)
139+
p == length(ΔDvec) || throw(DimensionMismatch())
140+
diagview(VᴴΔV) .+= ΔDvec
141+
end
142+
Z = V' \ VᴴΔV
143+
144+
# add contribution from orthogonal complement
145+
PA = A - (A * V) / V
146+
Y = mul!(ΔVperp, PA', Z, 1, 1)
147+
X = sylvester(PA', -Dmat', Y)
148+
Z .+= X
149+
150+
if eltype(ΔA) <: Real
151+
ΔAc = Z * V'
152+
ΔA .+= real.(ΔAc)
153+
else
154+
ΔA = mul!(ΔA, Z, V', 1, 1)
155+
end
156+
return ΔA
157+
end

0 commit comments

Comments
 (0)