Skip to content

Commit 39f20b1

Browse files
committed
remove factorization chainrules
1 parent 3c20424 commit 39f20b1

File tree

2 files changed

+0
-216
lines changed

2 files changed

+0
-216
lines changed

ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,9 @@ import TensorOperations as TO
1212
using TensorOperations: promote_contract, tensoralloc_add, tensoralloc_contract
1313
using VectorInterface: promote_scale, promote_add
1414

15-
using MatrixAlgebraKit
16-
using MatrixAlgebraKit: TruncationStrategy, TruncatedAlgorithm,
17-
svd_pullback!, eig_pullback!, eigh_pullback!,
18-
qr_compact_pullback!, lq_compact_pullback!,
19-
left_polar_pullback!, right_polar_pullback!
20-
2115
include("utility.jl")
2216
include("constructors.jl")
2317
include("linalg.jl")
2418
include("tensoroperations.jl")
25-
include("factorizations.jl")
2619

2720
end
Lines changed: 0 additions & 209 deletions
Original file line numberDiff line numberDiff line change
@@ -1,211 +1,2 @@
11
# Factorizations rules
22
# --------------------
3-
function ChainRulesCore.rrule(::typeof(MatrixAlgebraKit.copy_input), f,
4-
t::AbstractTensorMap)
5-
project = ProjectTo(t)
6-
copy_input_pullback(Δt) = (NoTangent(), NoTangent(), project(unthunk(Δt)))
7-
return MatrixAlgebraKit.copy_input(f, t), copy_input_pullback
8-
end
9-
10-
@non_differentiable MatrixAlgebraKit.initialize_output(f, t::AbstractTensorMap, args...)
11-
@non_differentiable MatrixAlgebraKit.check_input(f, t::AbstractTensorMap, args...)
12-
13-
for qr_f in (:qr_compact, :qr_full)
14-
qr_f! = Symbol(qr_f, '!')
15-
@eval function ChainRulesCore.rrule(::typeof($qr_f!), t::AbstractTensorMap, QR, alg)
16-
tc = MatrixAlgebraKit.copy_input($qr_f, t)
17-
QR = $(qr_f!)(tc, QR, alg)
18-
function qr_pullback(ΔQR′)
19-
ΔQR = unthunk.(ΔQR′)
20-
Δt = zerovector(t)
21-
MatrixAlgebraKit.qr_compact_pullback!(Δt, t, QR, ΔQR)
22-
return NoTangent(), Δt, ZeroTangent(), NoTangent()
23-
end
24-
function qr_pullback(::Tuple{ZeroTangent,ZeroTangent})
25-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
26-
end
27-
return QR, qr_pullback
28-
end
29-
end
30-
function ChainRulesCore.rrule(::typeof(qr_null!), t::AbstractTensorMap, N, alg)
31-
Q, R = qr_full(t, alg)
32-
for (c, b) in blocks(t)
33-
m, n = size(b)
34-
copy!(block(N, c), view(block(Q, c), 1:m, (n + 1):m))
35-
end
36-
37-
function qr_null_pullback(ΔN′)
38-
ΔN = unthunk(ΔN′)
39-
Δt = zerovector(t)
40-
ΔQ = zerovector!(similar(Q, codomain(Q) fuse(codomain(Q))))
41-
foreachblock(ΔN) do c, (b,)
42-
n = size(b, 2)
43-
ΔQc = block(ΔQ, c)
44-
return copy!(@view(ΔQc[:, (end - n + 1):end]), b)
45-
end
46-
ΔR = ZeroTangent()
47-
MatrixAlgebraKit.qr_compact_pullback!(Δt, t, (Q, R), (ΔQ, ΔR))
48-
return NoTangent(), Δt, ZeroTangent(), NoTangent()
49-
end
50-
qr_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
51-
52-
return N, qr_null_pullback
53-
end
54-
55-
for lq_f in (:lq_compact, :lq_full)
56-
lq_f! = Symbol(lq_f, '!')
57-
@eval function ChainRulesCore.rrule(::typeof($lq_f!), t::AbstractTensorMap, LQ, alg)
58-
tc = MatrixAlgebraKit.copy_input($lq_f, t)
59-
LQ = $(lq_f!)(tc, LQ, alg)
60-
function lq_pullback(ΔLQ′)
61-
ΔLQ = unthunk.(ΔLQ′)
62-
Δt = zerovector(t)
63-
MatrixAlgebraKit.lq_compact_pullback!(Δt, t, LQ, ΔLQ)
64-
return NoTangent(), Δt, ZeroTangent(), NoTangent()
65-
end
66-
function lq_pullback(::Tuple{ZeroTangent,ZeroTangent})
67-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
68-
end
69-
return LQ, lq_pullback
70-
end
71-
end
72-
function ChainRulesCore.rrule(::typeof(lq_null!), t::AbstractTensorMap, Nᴴ, alg)
73-
L, Q = lq_full(t, alg)
74-
for (c, b) in blocks(t)
75-
m, n = size(b)
76-
copy!(block(Nᴴ, c), view(block(Q, c), (m + 1):n, 1:n))
77-
end
78-
79-
function lq_null_pullback(ΔNᴴ′)
80-
ΔNᴴ = unthunk(ΔNᴴ′)
81-
Δt = zerovector(t)
82-
ΔQ = zerovector!(similar(Q, codomain(Q) fuse(codomain(Q))))
83-
foreachblock(ΔNᴴ) do c, (b,)
84-
m = size(b, 1)
85-
ΔQc = block(ΔQ, c)
86-
return copy!(@view(ΔQc[(end - m + 1):end, :]), b)
87-
end
88-
ΔL = ZeroTangent()
89-
MatrixAlgebraKit.lq_compact_pullback!(Δt, t, (L, Q), (ΔL, ΔQ))
90-
return NoTangent(), Δt, ZeroTangent(), NoTangent()
91-
end
92-
lq_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
93-
94-
return Nᴴ, lq_null_pullback
95-
end
96-
97-
for eig in (:eig, :eigh)
98-
eig_f = Symbol(eig, "_full")
99-
eig_f! = Symbol(eig_f, "!")
100-
eig_f_pb! = Symbol(eig, "_pullback!")
101-
eig_pb = Symbol(eig, "_pullback")
102-
@eval function ChainRulesCore.rrule(::typeof($eig_f!), t::AbstractTensorMap, DV, alg)
103-
tc = MatrixAlgebraKit.copy_input($eig_f, t)
104-
DV = $(eig_f!)(tc, DV, alg)
105-
function $eig_pb(ΔDV)
106-
Δt = zerovector(t)
107-
MatrixAlgebraKit.$eig_f_pb!(Δt, t, DV, unthunk.(ΔDV))
108-
return NoTangent(), Δt, ZeroTangent(), NoTangent()
109-
end
110-
function $eig_pb(::Tuple{ZeroTangent,ZeroTangent})
111-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
112-
end
113-
return DV, $eig_pb
114-
end
115-
end
116-
117-
for svd_f in (:svd_compact, :svd_full)
118-
svd_f! = Symbol(svd_f, "!")
119-
@eval begin
120-
function ChainRulesCore.rrule(::typeof($svd_f!), t::AbstractTensorMap, USVᴴ, alg)
121-
tc = MatrixAlgebraKit.copy_input($svd_f, t)
122-
USVᴴ = $(svd_f!)(tc, USVᴴ, alg)
123-
function svd_pullback(ΔUSVᴴ)
124-
Δt = zerovector(t)
125-
MatrixAlgebraKit.svd_pullback!(Δt, t, USVᴴ, unthunk.(ΔUSVᴴ))
126-
return NoTangent(), Δt, ZeroTangent(), NoTangent()
127-
end
128-
function svd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
129-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
130-
end
131-
return USVᴴ, svd_pullback
132-
end
133-
end
134-
end
135-
136-
function ChainRulesCore.rrule(::typeof(svd_trunc!), t::AbstractTensorMap, USVᴴ,
137-
alg::TruncatedAlgorithm)
138-
tc = MatrixAlgebraKit.copy_input(svd_compact, t)
139-
USVᴴ = svd_compact!(tc, USVᴴ, alg.alg)
140-
USVᴴ_trunc, ind = TensorKit.Factorizations.truncate(svd_trunc!, USVᴴ, alg.trunc)
141-
svd_trunc_pullback = _make_svd_trunc_pullback(t, USVᴴ, ind)
142-
return USVᴴ_trunc, svd_trunc_pullback
143-
end
144-
function _make_svd_trunc_pullback(t::AbstractTensorMap, USVᴴ, ind)
145-
function svd_trunc_pullback(ΔUSVᴴ)
146-
Δt = zerovector(t)
147-
MatrixAlgebraKit.svd_pullback!(Δt, t, USVᴴ, unthunk.(ΔUSVᴴ), ind)
148-
return NoTangent(), Δt, ZeroTangent(), NoTangent()
149-
end
150-
function svd_trunc_pullback(::NTuple{3,ZeroTangent})
151-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
152-
end
153-
return svd_trunc_pullback
154-
end
155-
156-
function ChainRulesCore.rrule(::typeof(left_polar!), t::AbstractTensorMap, WP, alg)
157-
tc = MatrixAlgebraKit.copy_input(left_polar, t)
158-
WP = left_polar!(tc, WP, alg)
159-
function left_polar_pullback(ΔWP)
160-
Δt = zerovector(t)
161-
MatrixAlgebraKit.left_polar_pullback!(Δt, t, WP, unthunk.(ΔWP))
162-
return NoTangent(), Δt, ZeroTangent(), NoTangent()
163-
end
164-
function left_polar_pullback(::Tuple{ZeroTangent,ZeroTangent})
165-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
166-
end
167-
return WP, left_polar_pullback
168-
end
169-
170-
function ChainRulesCore.rrule(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, alg)
171-
tc = MatrixAlgebraKit.copy_input(left_polar, t)
172-
PWᴴ = right_polar!(tc, PWᴴ, alg)
173-
function right_polar_pullback(ΔPWᴴ)
174-
Δt = zerovector(t)
175-
MatrixAlgebraKit.right_polar_pullback!(Δt, t, PWᴴ, unthunk.(ΔPWᴴ))
176-
return NoTangent(), Δt, ZeroTangent(), NoTangent()
177-
end
178-
function right_polar_pullback(::Tuple{ZeroTangent,ZeroTangent})
179-
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
180-
end
181-
return PWᴴ, right_polar_pullback
182-
end
183-
184-
# function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
185-
# U, S, V⁺ = tsvd(t)
186-
# s = diag(S)
187-
# project_t = ProjectTo(t)
188-
189-
# function svdvals_pullback(Δs′)
190-
# Δs = unthunk(Δs′)
191-
# ΔS = diagm(codomain(S), domain(S), Δs)
192-
# return NoTangent(), project_t(U * ΔS * V⁺)
193-
# end
194-
195-
# return s, svdvals_pullback
196-
# end
197-
198-
# function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap;
199-
# sortby=nothing, kwargs...)
200-
# @assert sortby === nothing "only `sortby=nothing` is supported"
201-
# (D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...)
202-
# d = diag(D)
203-
# project_t = ProjectTo(t)
204-
# function eigvals_pullback(Δd′)
205-
# Δd = unthunk(Δd′)
206-
# ΔD = diagm(codomain(D), domain(D), Δd)
207-
# return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2])
208-
# end
209-
210-
# return d, eigvals_pullback
211-
# end

0 commit comments

Comments
 (0)