Skip to content

Commit d256ffb

Browse files
committed
rework AD rules
1 parent 7c5419b commit d256ffb

File tree

7 files changed

+789
-430
lines changed

7 files changed

+789
-430
lines changed

ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ using TensorOperations: promote_contract, tensoralloc_add, tensoralloc_contract
1313
using VectorInterface: promote_scale, promote_add
1414

1515
using MatrixAlgebraKit
16-
using MatrixAlgebraKit: TruncationStrategy,
16+
using MatrixAlgebraKit: TruncationStrategy, TruncatedAlgorithm,
1717
svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback!,
18-
qr_compact_pullback!, lq_compact_pullback!
18+
qr_compact_pullback!, lq_compact_pullback!, left_polar_pullback!,
19+
right_polar_pullback!
1920

2021
include("utility.jl")
2122
include("constructors.jl")
Lines changed: 254 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,280 @@
11
# Factorizations rules
22
# --------------------
3-
for f in (:tsvd, :eig, :eigh)
4-
f! = Symbol(f, :!)
5-
f_trunc! = f == :tsvd ? :svd_trunc! : Symbol(f, :_trunc!)
6-
f_pullback = Symbol(f, :_pullback)
7-
f_pullback! = f == :tsvd ? :svd_compact_pullback! : Symbol(f, :_full_pullback!)
8-
@eval function ChainRulesCore.rrule(::typeof(TensorKit.$f!), t::AbstractTensorMap;
9-
trunc::TruncationStrategy=TensorKit.notrunc(),
10-
kwargs...)
11-
# TODO: I think we can use f! here without issues because we don't actually require
12-
# the data of `t` anymore.
13-
F = $f(t; trunc=TensorKit.notrunc(), kwargs...)
14-
15-
if trunc != TensorKit.notrunc() && !isempty(blocksectors(t))
16-
F′ = MatrixAlgebraKit.truncate!($f_trunc!, F, trunc)
17-
else
18-
F′ = F
19-
end
3+
function ChainRulesCore.rrule(::typeof(MatrixAlgebraKit.copy_input), f,
4+
t::AbstractTensorMap)
5+
project = ProjectTo(t)
6+
copy_input_pullback(Δt) = (NoTangent(), NoTangent(), project(unthunk(Δt)))
7+
return MatrixAlgebraKit.copy_input(f, t), copy_input_pullback
8+
end
9+
10+
@non_differentiable MatrixAlgebraKit.initialize_output(f, t::AbstractTensorMap, args...)
11+
@non_differentiable MatrixAlgebraKit.check_input(f, t::AbstractTensorMap, args...)
2012

21-
function $f_pullback(ΔF′)
22-
ΔF = unthunk.(ΔF′)
13+
for qr_f in (:qr_compact, :qr_full)
14+
qr_f! = Symbol(qr_f, '!')
15+
@eval function ChainRulesCore.rrule(::typeof($qr_f!), t::AbstractTensorMap, QR, alg)
16+
tc = MatrixAlgebraKit.copy_input($qr_f, t)
17+
QR = $(qr_f!)(tc, QR, alg)
18+
function qr_pullback(ΔQR′)
19+
ΔQR = unthunk.(ΔQR′)
2320
Δt = zerovector(t)
24-
foreachblock(Δt) do c, (b,)
25-
Fc = block.(F, Ref(c))
26-
ΔFc = block.(ΔF, Ref(c))
27-
$f_pullback!(b, Fc, ΔFc)
28-
return nothing
29-
end
30-
return NoTangent(), Δt
21+
MatrixAlgebraKit.qr_compact_pullback!(Δt, QR, ΔQR)
22+
return NoTangent(), Δt, ZeroTangent(), NoTangent()
23+
end
24+
function qr_pullback(::Tuple{ZeroTangent,ZeroTangent})
25+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
3126
end
32-
$f_pullback(::Tuple{ZeroTangent,Vararg{ZeroTangent}}) = NoTangent(), ZeroTangent()
27+
return QR, qr_pullback
28+
end
29+
end
30+
function ChainRulesCore.rrule(::typeof(qr_null!), t::AbstractTensorMap, N, alg)
31+
Q, R = qr_full(t, alg)
32+
for (c, b) in blocks(t)
33+
m, n = size(b)
34+
copy!(block(N, c), view(block(Q, c), 1:m, (n + 1):m))
35+
end
3336

34-
return F′, $f_pullback
37+
function qr_null_pullback(ΔN′)
38+
ΔN = unthunk(ΔN′)
39+
Δt = zerovector(t)
40+
ΔQ = zerovector!(similar(Q, codomain(Q) fuse(codomain(Q))))
41+
foreachblock(ΔN) do c, (b,)
42+
n = size(b, 2)
43+
ΔQc = block(ΔQ, c)
44+
return copy!(@view(ΔQc[:, (end - n + 1):end]), b)
45+
end
46+
ΔR = ZeroTangent()
47+
MatrixAlgebraKit.qr_compact_pullback!(Δt, (Q, R), (ΔQ, ΔR))
48+
return NoTangent(), Δt, ZeroTangent(), NoTangent()
3549
end
50+
qr_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
51+
52+
return N, qr_null_pullback
3653
end
3754

38-
function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
39-
U, S, V⁺ = tsvd(t)
40-
s = diag(S)
41-
project_t = ProjectTo(t)
55+
for lq_f in (:lq_compact, :lq_full)
56+
lq_f! = Symbol(lq_f, '!')
57+
@eval function ChainRulesCore.rrule(::typeof($lq_f!), t::AbstractTensorMap, LQ, alg)
58+
tc = MatrixAlgebraKit.copy_input($lq_f, t)
59+
LQ = $(lq_f!)(tc, LQ, alg)
60+
function lq_pullback(ΔLQ′)
61+
ΔLQ = unthunk.(ΔLQ′)
62+
Δt = zerovector(t)
63+
MatrixAlgebraKit.lq_compact_pullback!(Δt, LQ, ΔLQ)
64+
return NoTangent(), Δt, ZeroTangent(), NoTangent()
65+
end
66+
function lq_pullback(::Tuple{ZeroTangent,ZeroTangent})
67+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
68+
end
69+
return LQ, lq_pullback
70+
end
71+
end
72+
function ChainRulesCore.rrule(::typeof(lq_null!), t::AbstractTensorMap, Nᴴ, alg)
73+
L, Q = lq_full(t, alg)
74+
for (c, b) in blocks(t)
75+
m, n = size(b)
76+
copy!(block(Nᴴ, c), view(block(Q, c), (m + 1):n, 1:n))
77+
end
4278

43-
function svdvals_pullback(Δs′)
44-
Δs = unthunk(Δs′)
45-
ΔS = diagm(codomain(S), domain(S), Δs)
46-
return NoTangent(), project_t(U * ΔS * V⁺)
79+
function lq_null_pullback(ΔNᴴ′)
80+
ΔNᴴ = unthunk(ΔNᴴ′)
81+
Δt = zerovector(t)
82+
ΔQ = zerovector!(similar(Q, codomain(Q) fuse(codomain(Q))))
83+
foreachblock(ΔNᴴ) do c, (b,)
84+
m = size(b, 1)
85+
ΔQc = block(ΔQ, c)
86+
return copy!(@view(ΔQc[(end - m + 1):end, :]), b)
87+
end
88+
ΔL = ZeroTangent()
89+
MatrixAlgebraKit.lq_compact_pullback!(Δt, (L, Q), (ΔL, ΔQ))
90+
return NoTangent(), Δt, ZeroTangent(), NoTangent()
4791
end
92+
lq_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
4893

49-
return s, svdvals_pullback
94+
return Nᴴ, lq_null_pullback
5095
end
5196

52-
function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap;
53-
sortby=nothing, kwargs...)
54-
@assert sortby === nothing "only `sortby=nothing` is supported"
55-
(D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...)
56-
d = diag(D)
57-
project_t = ProjectTo(t)
58-
function eigvals_pullback(Δd′)
59-
Δd = unthunk(Δd′)
60-
ΔD = diagm(codomain(D), domain(D), Δd)
61-
return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2])
97+
for eig in (:eig, :eigh)
98+
eig_f = Symbol(eig, "_full")
99+
eig_f! = Symbol(eig_f, "!")
100+
eig_f_pb! = Symbol(eig, "_full_pullback!")
101+
eig_pb = Symbol(eig, "_pullback")
102+
@eval function ChainRulesCore.rrule(::typeof($eig_f!), t::AbstractTensorMap, DV, alg)
103+
tc = copy_input($eig_f, t)
104+
DV = $(eig_f!)(tc, DV, alg)
105+
function $eig_pb(ΔDV)
106+
Δt = zerovector(t)
107+
MatrixAlgebraKit.$eig_f_pb!(Δt, DV, unthunk.(ΔDV))
108+
return NoTangent(), Δt, ZeroTangent(), NoTangent()
109+
end
110+
function $eig_pb(::Tuple{ZeroTangent,ZeroTangent})
111+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
112+
end
113+
return DV, $eig_pb
62114
end
115+
end
63116

64-
return d, eigvals_pullback
117+
for svd_f in (:svd_compact, :svd_full)
118+
svd_f! = Symbol(svd_f, "!")
119+
@eval begin
120+
function ChainRulesCore.rrule(::typeof($svd_f!), t::AbstractTensorMap, USVᴴ, alg)
121+
tc = copy_input($svd_f, t)
122+
USVᴴ = $(svd_f!)(tc, USVᴴ, alg)
123+
function svd_pullback(ΔUSVᴴ)
124+
Δt = zerovector(t)
125+
MatrixAlgebraKit.svd_compact_pullback!(Δt, USVᴴ, unthunk.(ΔUSVᴴ))
126+
return NoTangent(), Δt, ZeroTangent(), NoTangent()
127+
end
128+
function svd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
129+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
130+
end
131+
return USVᴴ, svd_pullback
132+
end
133+
end
65134
end
66135

67-
function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
68-
alg isa MatrixAlgebraKit.LAPACK_HouseholderQR ||
69-
error("only `alg=QR()` and `alg=QRpos()` are supported")
70-
QR = leftorth(t; alg)
71-
function leftorth!_pullback(ΔQR′)
72-
ΔQR = unthunk.(ΔQR′)
136+
function ChainRulesCore.rrule(::typeof(svd_trunc!), t::AbstractTensorMap, USVᴴ,
137+
alg::TruncatedAlgorithm)
138+
tc = MatrixAlgebraKit.copy_input(svd_compact, t)
139+
USVᴴ = svd_compact!(tc, USVᴴ, alg.alg)
140+
function svd_trunc_pullback(ΔUSVᴴ)
73141
Δt = zerovector(t)
74-
foreachblock(Δt) do c, (b,)
75-
QRc = block.(QR, Ref(c))
76-
ΔQRc = block.(ΔQR, Ref(c))
77-
qr_compact_pullback!(b, QRc, ΔQRc)
78-
return nothing
79-
end
80-
return NoTangent(), Δt
142+
MatrixAlgebraKit.svd_compact_pullback!(Δt, USVᴴ, unthunk.(ΔUSVᴴ))
143+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
144+
end
145+
function svd_trunc_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
146+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
81147
end
82-
leftorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
148+
return MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, alg.trunc), svd_trunc_pullback
149+
end
83150

84-
return QR, leftorth!_pullback
151+
function ChainRulesCore.rrule(::typeof(left_polar!), t::AbstractTensorMap, WP, alg)
152+
tc = copy_input(left_polar, t)
153+
WP = left_polar!(tc, WP, alg)
154+
function left_polar_pullback(ΔWP)
155+
Δt = zerovector(t)
156+
MatrixAlgebraKit.left_polar_pullback!(Δt, WP, unthunk.(ΔWP))
157+
return NoTangent(), Δt, ZeroTangent(), NoTangent()
158+
end
159+
function left_polar_pullback(::Tuple{ZeroTangent,ZeroTangent})
160+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
161+
end
162+
return WP, left_polar_pullback
85163
end
86164

87-
function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
88-
alg isa MatrixAlgebraKit.LAPACK_HouseholderLQ ||
89-
error("only `alg=LQ()` and `alg=LQpos()` are supported")
90-
LQ = rightorth(t; alg)
91-
function rightorth!_pullback(ΔLQ′)
92-
ΔLQ = unthunk(ΔLQ′)
165+
function ChainRulesCore.rrule(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, alg)
166+
tc = copy_input(left_polar, t)
167+
PWᴴ = right_polar!(Ac, PWᴴ, alg)
168+
function right_polar_pullback(ΔPWᴴ)
93169
Δt = zerovector(t)
94-
foreachblock(Δt) do c, (b,)
95-
LQc = block.(LQ, Ref(c))
96-
ΔLQc = block.(ΔLQ, Ref(c))
97-
lq_compact_pullback!(b, LQc, ΔLQc)
98-
return nothing
99-
end
100-
return NoTangent(), Δt
170+
MatrixAlgebraKit.right_polar_pullback!(Δt, PWᴴ, unthunk.(ΔPWᴴ))
171+
return NoTangent(), Δt, ZeroTangent(), NoTangent()
172+
end
173+
function right_polar_pullback(::Tuple{ZeroTangent,ZeroTangent})
174+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
101175
end
102-
rightorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
103-
return LQ, rightorth!_pullback
176+
return PWᴴ, right_polar_pullback
104177
end
178+
179+
# for f in (:tsvd, :eig, :eigh)
180+
# f! = Symbol(f, :!)
181+
# f_trunc! = f == :tsvd ? :svd_trunc! : Symbol(f, :_trunc!)
182+
# f_pullback = Symbol(f, :_pullback)
183+
# f_pullback! = f == :tsvd ? :svd_compact_pullback! : Symbol(f, :_full_pullback!)
184+
# @eval function ChainRulesCore.rrule(::typeof(TensorKit.$f!), t::AbstractTensorMap;
185+
# trunc::TruncationStrategy=TensorKit.notrunc(),
186+
# kwargs...)
187+
# # TODO: I think we can use f! here without issues because we don't actually require
188+
# # the data of `t` anymore.
189+
# F = $f(t; trunc=TensorKit.notrunc(), kwargs...)
190+
191+
# if trunc != TensorKit.notrunc() && !isempty(blocksectors(t))
192+
# F′ = MatrixAlgebraKit.truncate!($f_trunc!, F, trunc)
193+
# else
194+
# F′ = F
195+
# end
196+
197+
# function $f_pullback(ΔF′)
198+
# ΔF = unthunk.(ΔF′)
199+
# Δt = zerovector(t)
200+
# foreachblock(Δt) do c, (b,)
201+
# Fc = block.(F, Ref(c))
202+
# ΔFc = block.(ΔF, Ref(c))
203+
# $f_pullback!(b, Fc, ΔFc)
204+
# return nothing
205+
# end
206+
# return NoTangent(), Δt
207+
# end
208+
# $f_pullback(::Tuple{ZeroTangent,Vararg{ZeroTangent}}) = NoTangent(), ZeroTangent()
209+
210+
# return F′, $f_pullback
211+
# end
212+
# end
213+
214+
# function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
215+
# U, S, V⁺ = tsvd(t)
216+
# s = diag(S)
217+
# project_t = ProjectTo(t)
218+
219+
# function svdvals_pullback(Δs′)
220+
# Δs = unthunk(Δs′)
221+
# ΔS = diagm(codomain(S), domain(S), Δs)
222+
# return NoTangent(), project_t(U * ΔS * V⁺)
223+
# end
224+
225+
# return s, svdvals_pullback
226+
# end
227+
228+
# function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap;
229+
# sortby=nothing, kwargs...)
230+
# @assert sortby === nothing "only `sortby=nothing` is supported"
231+
# (D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...)
232+
# d = diag(D)
233+
# project_t = ProjectTo(t)
234+
# function eigvals_pullback(Δd′)
235+
# Δd = unthunk(Δd′)
236+
# ΔD = diagm(codomain(D), domain(D), Δd)
237+
# return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2])
238+
# end
239+
240+
# return d, eigvals_pullback
241+
# end
242+
243+
# function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
244+
# alg isa MatrixAlgebraKit.LAPACK_HouseholderQR ||
245+
# error("only `alg=QR()` and `alg=QRpos()` are supported")
246+
# QR = leftorth(t; alg)
247+
# function leftorth!_pullback(ΔQR′)
248+
# ΔQR = unthunk.(ΔQR′)
249+
# Δt = zerovector(t)
250+
# foreachblock(Δt) do c, (b,)
251+
# QRc = block.(QR, Ref(c))
252+
# ΔQRc = block.(ΔQR, Ref(c))
253+
# qr_compact_pullback!(b, QRc, ΔQRc)
254+
# return nothing
255+
# end
256+
# return NoTangent(), Δt
257+
# end
258+
# leftorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
259+
260+
# return QR, leftorth!_pullback
261+
# end
262+
263+
# function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
264+
# alg isa MatrixAlgebraKit.LAPACK_HouseholderLQ ||
265+
# error("only `alg=LQ()` and `alg=LQpos()` are supported")
266+
# LQ = rightorth(t; alg)
267+
# function rightorth!_pullback(ΔLQ′)
268+
# ΔLQ = unthunk(ΔLQ′)
269+
# Δt = zerovector(t)
270+
# foreachblock(Δt) do c, (b,)
271+
# LQc = block.(LQ, Ref(c))
272+
# ΔLQc = block.(ΔLQ, Ref(c))
273+
# lq_compact_pullback!(b, LQc, ΔLQc)
274+
# return nothing
275+
# end
276+
# return NoTangent(), Δt
277+
# end
278+
# rightorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
279+
# return LQ, rightorth!_pullback
280+
# end

src/tensors/factorizations/adjoint.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
# ----------------
33
# map algorithms to their adjoint counterpart
44
# TODO: this probably belongs in MatrixAlgebraKit
5-
_adjoint(alg::LAPACK_HouseholderQR) = LAPACK_HouseholderLQ(; alg.positive, alg.blocksize)
6-
_adjoint(alg::LAPACK_HouseholderLQ) = LAPACK_HouseholderQR(; alg.positive, alg.blocksize)
7-
_adjoint(alg::LAPACK_HouseholderQL) = LAPACK_HouseholderRQ(; alg.positive, alg.blocksize)
8-
_adjoint(alg::LAPACK_HouseholderRQ) = LAPACK_HouseholderQL(; alg.positive, alg.blocksize)
5+
_adjoint(alg::LAPACK_HouseholderQR) = LAPACK_HouseholderLQ(; alg.kwargs...)
6+
_adjoint(alg::LAPACK_HouseholderLQ) = LAPACK_HouseholderQR(; alg.kwargs...)
7+
_adjoint(alg::LAPACK_HouseholderQL) = LAPACK_HouseholderRQ(; alg.kwargs...)
8+
_adjoint(alg::LAPACK_HouseholderRQ) = LAPACK_HouseholderQL(; alg.kwargs...)
99
_adjoint(alg::PolarViaSVD) = PolarViaSVD(_adjoint(alg.svdalg))
1010
_adjoint(alg::AbstractAlgorithm) = alg
1111

0 commit comments

Comments
 (0)