Skip to content

Commit 648a34a

Browse files
committed
Implement remaining factorization rrules
1 parent 05299b7 commit 648a34a

File tree

2 files changed

+25
-153
lines changed

2 files changed

+25
-153
lines changed

ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ using VectorInterface: promote_scale, promote_add
1414

1515
using MatrixAlgebraKit
1616
using MatrixAlgebraKit: TruncationStrategy,
17-
svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback!
17+
svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback!,
18+
qr_compact_pullback!, lq_compact_pullback!
1819

1920
include("utility.jl")
2021
include("constructors.jl")

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 23 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -99,167 +99,38 @@ end
9999
function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
100100
alg isa TensorKit.QR || alg isa TensorKit.QRpos ||
101101
error("only `alg=QR()` and `alg=QRpos()` are supported")
102-
Q, R = leftorth(t; alg)
103-
function leftorth!_pullback((_ΔQ, _ΔR))
104-
ΔQ, ΔR = unthunk(_ΔQ), unthunk(_ΔR)
105-
Δt = similar(t)
106-
for (c, b) in blocks(Δt)
107-
qr_pullback!(b, block(Q, c), block(R, c), block(ΔQ, c), block(ΔR, c))
102+
QR = leftorth(t; alg)
103+
function leftorth!_pullback(ΔQR′)
104+
ΔQR = unthunk.(ΔQR′)
105+
Δt = zerovector(t)
106+
foreachblock(Δt) do c, (b,)
107+
QRc = block.(QR, Ref(c))
108+
ΔQRc = block.(ΔQR, Ref(c))
109+
qr_compact_pullback!(b, QRc, ΔQRc)
110+
return nothing
108111
end
109112
return NoTangent(), Δt
110113
end
111-
leftorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent()
112-
return (Q, R), leftorth!_pullback
114+
leftorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
115+
116+
return QR, leftorth!_pullback
113117
end
114118

115119
function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
116120
alg isa TensorKit.LQ || alg isa TensorKit.LQpos ||
117121
error("only `alg=LQ()` and `alg=LQpos()` are supported")
118-
L, Q = rightorth(t; alg)
119-
function rightorth!_pullback((_ΔL, _ΔQ))
120-
ΔL, ΔQ = unthunk(_ΔL), unthunk(_ΔQ)
121-
Δt = similar(t)
122-
for (c, b) in blocks(Δt)
123-
lq_pullback!(b, block(L, c), block(Q, c), block(ΔL, c), block(ΔQ, c))
122+
LQ = rightorth(t; alg)
123+
function rightorth!_pullback(ΔLQ′)
124+
ΔLQ = unthunk(ΔLQ′)
125+
Δt = zerovector(t)
126+
foreachblock(Δt) do c, (b,)
127+
LQc = block.(LQ, Ref(c))
128+
ΔLQc = block.(ΔLQ, Ref(c))
129+
lq_compact_pullback!(b, LQc, ΔLQc)
130+
return nothing
124131
end
125132
return NoTangent(), Δt
126133
end
127-
rightorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent()
128-
return (L, Q), rightorth!_pullback
129-
end
130-
131-
# Corresponding matrix factorisations: implemented as mutating methods
132-
# ---------------------------------------------------------------------
133-
# helper routines
134-
safe_inv(a, tol) = abs(a) < tol ? zero(a) : inv(a)
135-
136-
function lowertriangularind(A::AbstractMatrix)
137-
m, n = size(A)
138-
I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m))
139-
offset = 0
140-
for j in 1:n
141-
r = (j + 1):m
142-
I[offset .- j .+ r] = (j - 1) * m .+ r
143-
offset += length(r)
144-
end
145-
return I
146-
end
147-
148-
function uppertriangularind(A::AbstractMatrix)
149-
m, n = size(A)
150-
I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m))
151-
offset = 0
152-
for i in 1:m
153-
r = (i + 1):n
154-
I[offset .- i .+ r] = i .+ m .* (r .- 1)
155-
offset += length(r)
156-
end
157-
return I
158-
end
159-
160-
function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR;
161-
tol::Real=default_pullback_gaugetol(R))
162-
Rd = view(R, diagind(R))
163-
p = something(findlast(≥(tol) ∘ abs, Rd), 0)
164-
m, n = size(R)
165-
166-
Q1 = view(Q, :, 1:p)
167-
R1 = view(R, 1:p, :)
168-
R11 = view(R, 1:p, 1:p)
169-
170-
ΔA1 = view(ΔA, :, 1:p)
171-
ΔQ1 = view(ΔQ, :, 1:p)
172-
ΔR1 = view(ΔR, 1:p, :)
173-
174-
M = similar(R, (p, p))
175-
ΔR isa AbstractZero || mul!(M, ΔR1, R1')
176-
ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, !(ΔR isa AbstractZero))
177-
view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M)))
178-
if eltype(M) <: Complex
179-
Md = view(M, diagind(M))
180-
Md .= real.(Md)
181-
end
182-
183-
ΔA1 .= ΔQ1
184-
mul!(ΔA1, Q1, M, +1, 1)
185-
186-
if n > p
187-
R12 = view(R, 1:p, (p + 1):n)
188-
ΔA2 = view(ΔA, :, (p + 1):n)
189-
ΔR12 = view(ΔR, 1:p, (p + 1):n)
190-
191-
if ΔR isa AbstractZero
192-
ΔA2 .= zero(eltype(ΔA))
193-
else
194-
mul!(ΔA2, Q1, ΔR12)
195-
mul!(ΔA1, ΔA2, R12', -1, 1)
196-
end
197-
end
198-
if m > p && !(ΔQ isa AbstractZero) # case where R is not full rank
199-
Q2 = view(Q, :, (p + 1):m)
200-
ΔQ2 = view(ΔQ, :, (p + 1):m)
201-
Q1dΔQ2 = Q1' * ΔQ2
202-
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
203-
Δgauge < tol ||
204-
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
205-
mul!(ΔA1, Q2, Q1dΔQ2', -1, 1)
206-
end
207-
rdiv!(ΔA1, UpperTriangular(R11)')
208-
return ΔA
209-
end
210-
211-
function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL, ΔQ;
212-
tol::Real=default_pullback_gaugetol(L))
213-
Ld = view(L, diagind(L))
214-
p = something(findlast(≥(tol) ∘ abs, Ld), 0)
215-
m, n = size(L)
216-
217-
L1 = view(L, :, 1:p)
218-
L11 = view(L, 1:p, 1:p)
219-
Q1 = view(Q, 1:p, :)
220-
221-
ΔA1 = view(ΔA, 1:p, :)
222-
ΔQ1 = view(ΔQ, 1:p, :)
223-
ΔL1 = view(ΔL, :, 1:p)
224-
225-
M = similar(L, (p, p))
226-
ΔL isa AbstractZero || mul!(M, L1', ΔL1)
227-
ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, !(ΔL isa AbstractZero))
228-
view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M)))
229-
if eltype(M) <: Complex
230-
Md = view(M, diagind(M))
231-
Md .= real.(Md)
232-
end
233-
234-
ΔA1 .= ΔQ1
235-
mul!(ΔA1, M, Q1, +1, 1)
236-
237-
if m > p
238-
L21 = view(L, (p + 1):m, 1:p)
239-
ΔA2 = view(ΔA, (p + 1):m, :)
240-
ΔL21 = view(ΔL, (p + 1):m, 1:p)
241-
242-
if ΔL isa AbstractZero
243-
ΔA2 .= zero(eltype(ΔA))
244-
else
245-
mul!(ΔA2, ΔL21, Q1)
246-
mul!(ΔA1, L21', ΔA2, -1, 1)
247-
end
248-
end
249-
if n > p && !(ΔQ isa AbstractZero) # case where R is not full rank
250-
Q2 = view(Q, (p + 1):n, :)
251-
ΔQ2 = view(ΔQ, (p + 1):n, :)
252-
ΔQ2Q1d = ΔQ2 * Q1'
253-
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1))
254-
Δgauge < tol ||
255-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
256-
mul!(ΔA1, ΔQ2Q1d', Q2, -1, 1)
257-
end
258-
ldiv!(LowerTriangular(L11)', ΔA1)
259-
return ΔA
260-
end
261-
262-
function default_pullback_gaugetol(a)
263-
n = norm(a, Inf)
264-
return eps(eltype(n))^(3 / 4) * max(n, one(n))
134+
rightorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
135+
return LQ, rightorth!_pullback
265136
end

0 commit comments

Comments
 (0)