Skip to content

Commit ff9e391

Browse files
committed
Adapt to MatrixAlgebraKit v0.4.1
1 parent d256ffb commit ff9e391

File tree

8 files changed

+215
-272
lines changed

8 files changed

+215
-272
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Combinatorics = "1"
3333
FiniteDifferences = "0.12"
3434
LRUCache = "1.0.2"
3535
LinearAlgebra = "1"
36-
MatrixAlgebraKit = "0.4.0"
36+
MatrixAlgebraKit = "0.4.1"
3737
OhMyThreads = "0.8.0"
3838
PackageExtensionCompat = "1"
3939
Random = "1"

ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ using VectorInterface: promote_scale, promote_add
1414

1515
using MatrixAlgebraKit
1616
using MatrixAlgebraKit: TruncationStrategy, TruncatedAlgorithm,
17-
svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback!,
18-
qr_compact_pullback!, lq_compact_pullback!, left_polar_pullback!,
19-
right_polar_pullback!
17+
svd_pullback!, eig_pullback!, eigh_pullback!,
18+
qr_compact_pullback!, lq_compact_pullback!,
19+
left_polar_pullback!, right_polar_pullback!
2020

2121
include("utility.jl")
2222
include("constructors.jl")

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 23 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ for qr_f in (:qr_compact, :qr_full)
1818
function qr_pullback(ΔQR′)
1919
ΔQR = unthunk.(ΔQR′)
2020
Δt = zerovector(t)
21-
MatrixAlgebraKit.qr_compact_pullback!(Δt, QR, ΔQR)
21+
MatrixAlgebraKit.qr_compact_pullback!(Δt, t, QR, ΔQR)
2222
return NoTangent(), Δt, ZeroTangent(), NoTangent()
2323
end
2424
function qr_pullback(::Tuple{ZeroTangent,ZeroTangent})
@@ -44,7 +44,7 @@ function ChainRulesCore.rrule(::typeof(qr_null!), t::AbstractTensorMap, N, alg)
4444
return copy!(@view(ΔQc[:, (end - n + 1):end]), b)
4545
end
4646
ΔR = ZeroTangent()
47-
MatrixAlgebraKit.qr_compact_pullback!(Δt, (Q, R), (ΔQ, ΔR))
47+
MatrixAlgebraKit.qr_compact_pullback!(Δt, t, (Q, R), (ΔQ, ΔR))
4848
return NoTangent(), Δt, ZeroTangent(), NoTangent()
4949
end
5050
qr_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
@@ -60,7 +60,7 @@ for lq_f in (:lq_compact, :lq_full)
6060
function lq_pullback(ΔLQ′)
6161
ΔLQ = unthunk.(ΔLQ′)
6262
Δt = zerovector(t)
63-
MatrixAlgebraKit.lq_compact_pullback!(Δt, LQ, ΔLQ)
63+
MatrixAlgebraKit.lq_compact_pullback!(Δt, t, LQ, ΔLQ)
6464
return NoTangent(), Δt, ZeroTangent(), NoTangent()
6565
end
6666
function lq_pullback(::Tuple{ZeroTangent,ZeroTangent})
@@ -86,7 +86,7 @@ function ChainRulesCore.rrule(::typeof(lq_null!), t::AbstractTensorMap, Nᴴ, al
8686
return copy!(@view(ΔQc[(end - m + 1):end, :]), b)
8787
end
8888
ΔL = ZeroTangent()
89-
MatrixAlgebraKit.lq_compact_pullback!(Δt, (L, Q), (ΔL, ΔQ))
89+
MatrixAlgebraKit.lq_compact_pullback!(Δt, t, (L, Q), (ΔL, ΔQ))
9090
return NoTangent(), Δt, ZeroTangent(), NoTangent()
9191
end
9292
lq_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
@@ -97,14 +97,14 @@ end
9797
for eig in (:eig, :eigh)
9898
eig_f = Symbol(eig, "_full")
9999
eig_f! = Symbol(eig_f, "!")
100-
eig_f_pb! = Symbol(eig, "_full_pullback!")
100+
eig_f_pb! = Symbol(eig, "_pullback!")
101101
eig_pb = Symbol(eig, "_pullback")
102102
@eval function ChainRulesCore.rrule(::typeof($eig_f!), t::AbstractTensorMap, DV, alg)
103-
tc = copy_input($eig_f, t)
103+
tc = MatrixAlgebraKit.copy_input($eig_f, t)
104104
DV = $(eig_f!)(tc, DV, alg)
105105
function $eig_pb(ΔDV)
106106
Δt = zerovector(t)
107-
MatrixAlgebraKit.$eig_f_pb!(Δt, DV, unthunk.(ΔDV))
107+
MatrixAlgebraKit.$eig_f_pb!(Δt, t, DV, unthunk.(ΔDV))
108108
return NoTangent(), Δt, ZeroTangent(), NoTangent()
109109
end
110110
function $eig_pb(::Tuple{ZeroTangent,ZeroTangent})
@@ -118,11 +118,11 @@ for svd_f in (:svd_compact, :svd_full)
118118
svd_f! = Symbol(svd_f, "!")
119119
@eval begin
120120
function ChainRulesCore.rrule(::typeof($svd_f!), t::AbstractTensorMap, USVᴴ, alg)
121-
tc = copy_input($svd_f, t)
121+
tc = MatrixAlgebraKit.copy_input($svd_f, t)
122122
USVᴴ = $(svd_f!)(tc, USVᴴ, alg)
123123
function svd_pullback(ΔUSVᴴ)
124124
Δt = zerovector(t)
125-
MatrixAlgebraKit.svd_compact_pullback!(Δt, USVᴴ, unthunk.(ΔUSVᴴ))
125+
MatrixAlgebraKit.svd_pullback!(Δt, t, USVᴴ, unthunk.(ΔUSVᴴ))
126126
return NoTangent(), Δt, ZeroTangent(), NoTangent()
127127
end
128128
function svd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
@@ -137,23 +137,28 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), t::AbstractTensorMap, USVᴴ
137137
alg::TruncatedAlgorithm)
138138
tc = MatrixAlgebraKit.copy_input(svd_compact, t)
139139
USVᴴ = svd_compact!(tc, USVᴴ, alg.alg)
140+
USVᴴ_trunc, ind = TensorKit.Factorizations.truncate(svd_trunc!, USVᴴ, alg.trunc)
141+
svd_trunc_pullback = _make_svd_trunc_pullback(t, USVᴴ, ind)
142+
return USVᴴ_trunc, svd_trunc_pullback
143+
end
144+
function _make_svd_trunc_pullback(t::AbstractTensorMap, USVᴴ, ind)
140145
function svd_trunc_pullback(ΔUSVᴴ)
141146
Δt = zerovector(t)
142-
MatrixAlgebraKit.svd_compact_pullback!(Δt, USVᴴ, unthunk.(ΔUSVᴴ))
143-
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
147+
MatrixAlgebraKit.svd_pullback!(Δt, t, USVᴴ, unthunk.(ΔUSVᴴ), ind)
148+
return NoTangent(), Δt, ZeroTangent(), NoTangent()
144149
end
145-
function svd_trunc_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
150+
function svd_trunc_pullback(::NTuple{3,ZeroTangent})
146151
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
147152
end
148-
return MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, alg.trunc), svd_trunc_pullback
153+
return svd_trunc_pullback
149154
end
150155

151156
function ChainRulesCore.rrule(::typeof(left_polar!), t::AbstractTensorMap, WP, alg)
152-
tc = copy_input(left_polar, t)
157+
tc = MatrixAlgebraKit.copy_input(left_polar, t)
153158
WP = left_polar!(tc, WP, alg)
154159
function left_polar_pullback(ΔWP)
155160
Δt = zerovector(t)
156-
MatrixAlgebraKit.left_polar_pullback!(Δt, WP, unthunk.(ΔWP))
161+
MatrixAlgebraKit.left_polar_pullback!(Δt, t, WP, unthunk.(ΔWP))
157162
return NoTangent(), Δt, ZeroTangent(), NoTangent()
158163
end
159164
function left_polar_pullback(::Tuple{ZeroTangent,ZeroTangent})
@@ -163,11 +168,11 @@ function ChainRulesCore.rrule(::typeof(left_polar!), t::AbstractTensorMap, WP, a
163168
end
164169

165170
function ChainRulesCore.rrule(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, alg)
166-
tc = copy_input(left_polar, t)
167-
PWᴴ = right_polar!(Ac, PWᴴ, alg)
171+
tc = MatrixAlgebraKit.copy_input(left_polar, t)
172+
PWᴴ = right_polar!(tc, PWᴴ, alg)
168173
function right_polar_pullback(ΔPWᴴ)
169174
Δt = zerovector(t)
170-
MatrixAlgebraKit.right_polar_pullback!(Δt, PWᴴ, unthunk.(ΔPWᴴ))
175+
MatrixAlgebraKit.right_polar_pullback!(Δt, t, PWᴴ, unthunk.(ΔPWᴴ))
171176
return NoTangent(), Δt, ZeroTangent(), NoTangent()
172177
end
173178
function right_polar_pullback(::Tuple{ZeroTangent,ZeroTangent})
@@ -176,41 +181,6 @@ function ChainRulesCore.rrule(::typeof(right_polar!), t::AbstractTensorMap, PW
176181
return PWᴴ, right_polar_pullback
177182
end
178183

179-
# for f in (:tsvd, :eig, :eigh)
180-
# f! = Symbol(f, :!)
181-
# f_trunc! = f == :tsvd ? :svd_trunc! : Symbol(f, :_trunc!)
182-
# f_pullback = Symbol(f, :_pullback)
183-
# f_pullback! = f == :tsvd ? :svd_compact_pullback! : Symbol(f, :_full_pullback!)
184-
# @eval function ChainRulesCore.rrule(::typeof(TensorKit.$f!), t::AbstractTensorMap;
185-
# trunc::TruncationStrategy=TensorKit.notrunc(),
186-
# kwargs...)
187-
# # TODO: I think we can use f! here without issues because we don't actually require
188-
# # the data of `t` anymore.
189-
# F = $f(t; trunc=TensorKit.notrunc(), kwargs...)
190-
191-
# if trunc != TensorKit.notrunc() && !isempty(blocksectors(t))
192-
# F′ = MatrixAlgebraKit.truncate!($f_trunc!, F, trunc)
193-
# else
194-
# F′ = F
195-
# end
196-
197-
# function $f_pullback(ΔF′)
198-
# ΔF = unthunk.(ΔF′)
199-
# Δt = zerovector(t)
200-
# foreachblock(Δt) do c, (b,)
201-
# Fc = block.(F, Ref(c))
202-
# ΔFc = block.(ΔF, Ref(c))
203-
# $f_pullback!(b, Fc, ΔFc)
204-
# return nothing
205-
# end
206-
# return NoTangent(), Δt
207-
# end
208-
# $f_pullback(::Tuple{ZeroTangent,Vararg{ZeroTangent}}) = NoTangent(), ZeroTangent()
209-
210-
# return F′, $f_pullback
211-
# end
212-
# end
213-
214184
# function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
215185
# U, S, V⁺ = tsvd(t)
216186
# s = diag(S)
@@ -239,42 +209,3 @@ end
239209

240210
# return d, eigvals_pullback
241211
# end
242-
243-
# function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
244-
# alg isa MatrixAlgebraKit.LAPACK_HouseholderQR ||
245-
# error("only `alg=QR()` and `alg=QRpos()` are supported")
246-
# QR = leftorth(t; alg)
247-
# function leftorth!_pullback(ΔQR′)
248-
# ΔQR = unthunk.(ΔQR′)
249-
# Δt = zerovector(t)
250-
# foreachblock(Δt) do c, (b,)
251-
# QRc = block.(QR, Ref(c))
252-
# ΔQRc = block.(ΔQR, Ref(c))
253-
# qr_compact_pullback!(b, QRc, ΔQRc)
254-
# return nothing
255-
# end
256-
# return NoTangent(), Δt
257-
# end
258-
# leftorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
259-
260-
# return QR, leftorth!_pullback
261-
# end
262-
263-
# function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
264-
# alg isa MatrixAlgebraKit.LAPACK_HouseholderLQ ||
265-
# error("only `alg=LQ()` and `alg=LQpos()` are supported")
266-
# LQ = rightorth(t; alg)
267-
# function rightorth!_pullback(ΔLQ′)
268-
# ΔLQ = unthunk(ΔLQ′)
269-
# Δt = zerovector(t)
270-
# foreachblock(Δt) do c, (b,)
271-
# LQc = block.(LQ, Ref(c))
272-
# ΔLQc = block.(ΔLQ, Ref(c))
273-
# lq_compact_pullback!(b, LQc, ΔLQc)
274-
# return nothing
275-
# end
276-
# return NoTangent(), Δt
277-
# end
278-
# rightorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
279-
# return LQ, rightorth!_pullback
280-
# end

src/tensors/factorizations/factorizations.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ import MatrixAlgebraKit: default_algorithm,
4242
left_orth!, right_orth!, left_null!, right_null!,
4343
truncate!, findtruncated, findtruncated_svd,
4444
diagview, isisometry
45-
import MatrixAlgebraKit: qr_compact_pullback!, lq_compact_pullback!, svd_compact_pullback!,
46-
left_polar_pullback!, right_polar_pullback!, eig_full_pullback!,
47-
eigh_full_pullback!
45+
using MatrixAlgebraKit: qr_compact_pullback!, lq_compact_pullback!,
46+
svd_pullback!, svd_trunc_pullback!,
47+
eig_pullback!, eig_trunc_pullback!,
48+
eigh_pullback!, eigh_trunc_pullback!,
49+
left_polar_pullback!, right_polar_pullback!
4850

4951
include("utility.jl")
5052
include("matrixalgebrakit.jl")

src/tensors/factorizations/pullbacks.jl

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,41 @@
11
for pullback! in (:qr_compact_pullback!, :lq_compact_pullback!,
2-
:svd_compact_pullback!,
3-
:left_polar_pullback!, :right_polar_pullback!,
4-
:eig_full_pullback!, :eigh_full_pullback!)
5-
@eval function $pullback!(Δt::AbstractTensorMap, F, ΔF; kwargs...)
6-
foreachblock(Δt) do c, (b,)
2+
:left_polar_pullback!, :right_polar_pullback!)
3+
@eval function MatrixAlgebraKit.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap,
4+
F, ΔF; kwargs...)
5+
foreachblock(Δt, t) do c, (Δb, b)
76
Fc = block.(F, Ref(c))
87
ΔFc = block.(ΔF, Ref(c))
9-
return $pullback!(b, Fc, ΔFc; kwargs...)
8+
return $pullback!(Δb, b, Fc, ΔFc; kwargs...)
9+
end
10+
return Δt
11+
end
12+
end
13+
14+
_notrunc_ind(t) = SectorDict(c => Colon() for c in blocksectors(t))
15+
16+
for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!)
17+
@eval function MatrixAlgebraKit.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap,
18+
F, ΔF, inds=_notrunc_ind(t);
19+
kwargs...)
20+
for (c, ind) in inds
21+
Δb = block(Δt, c)
22+
b = block(t, c)
23+
Fc = block.(F, Ref(c))
24+
ΔFc = block.(ΔF, Ref(c))
25+
$pullback!(Δb, b, Fc, ΔFc, ind; kwargs...)
26+
end
27+
return Δt
28+
end
29+
end
30+
31+
for pullback_trunc! in (:svd_trunc_pullback!, :eig_trunc_pullback!, :eigh_trunc_pullback!)
32+
@eval function MatrixAlgebraKit.$pullback_trunc!(Δt::AbstractTensorMap,
33+
t::AbstractTensorMap,
34+
F, ΔF; kwargs...)
35+
foreachblock(Δt, t) do c, (Δb, b)
36+
Fc = block.(F, Ref(c))
37+
ΔFc = block.(ΔF, Ref(c))
38+
return $pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...)
1039
end
1140
return Δt
1241
end

0 commit comments

Comments
 (0)