Skip to content

Commit 383d808

Browse files
committed
Update eig(h) rrule
1 parent 1ecf11e commit 383d808

File tree

1 file changed

+21
-96
lines changed

1 file changed

+21
-96
lines changed

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 21 additions & 96 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
using MatrixAlgebraKit: svd_compact_pullback!
1+
using MatrixAlgebraKit: svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback!
22

33
# Factorizations rules
44
# --------------------
@@ -46,47 +46,41 @@ function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTenso
4646
end
4747

4848
function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...)
49-
D, V = eig(t; kwargs...)
49+
DV = eig(t; kwargs...)
5050

51-
function eig!_pullback((_ΔD, _ΔV))
52-
ΔD, ΔV = unthunk(_ΔD), unthunk(_ΔV)
51+
function eig!_pullback(ΔDV′)
52+
ΔDV = unthunk.(ΔDV′)
5353
Δt = similar(t)
54-
for (c, b) in blocks(Δt)
55-
Dc, Vc = block(D, c), block(V, c)
56-
ΔDc, ΔVc = block(ΔD, c), block(ΔV, c)
57-
Ddc = view(Dc, diagind(Dc))
58-
ΔDdc = (ΔDc isa AbstractZero) ? ΔDc : view(ΔDc, diagind(ΔDc))
59-
eig_pullback!(b, Ddc, Vc, ΔDdc, ΔVc)
54+
foreachblock(Δt) do (c, b)
55+
DVc = block.(DV, Ref(c))
56+
ΔDVc = block.(ΔDV, Ref(c))
57+
eig_full_pullback!(b, DVc, ΔDVc)
58+
return nothing
6059
end
6160
return NoTangent(), Δt
6261
end
63-
function eig!_pullback(::Tuple{ZeroTangent,ZeroTangent})
64-
return NoTangent(), ZeroTangent()
65-
end
62+
eig!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
6663

67-
return (D, V), eig!_pullback
64+
return DV, eig!_pullback
6865
end
6966

7067
function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; kwargs...)
71-
D, V = eigh(t; kwargs...)
68+
DV = eigh(t; kwargs...)
7269

73-
function eigh!_pullback((_ΔD, _ΔV))
74-
ΔD, ΔV = unthunk(_ΔD), unthunk(_ΔV)
70+
function eigh!_pullback(ΔDV′)
71+
ΔDV = unthunk.(ΔDV′)
7572
Δt = similar(t)
76-
for (c, b) in blocks(Δt)
77-
Dc, Vc = block(D, c), block(V, c)
78-
ΔDc, ΔVc = block(ΔD, c), block(ΔV, c)
79-
Ddc = view(Dc, diagind(Dc))
80-
ΔDdc = (ΔDc isa AbstractZero) ? ΔDc : view(ΔDc, diagind(ΔDc))
81-
eigh_pullback!(b, Ddc, Vc, ΔDdc, ΔVc)
73+
foreachblock(Δt) do (c, b)
74+
DVc = block.(DV, Ref(c))
75+
ΔDVc = block.(ΔDV, Ref(c))
76+
eigh_full_pullback!(b, DVc, ΔDVc)
77+
return nothing
8278
end
8379
return NoTangent(), Δt
8480
end
85-
function eigh!_pullback(::Tuple{ZeroTangent,ZeroTangent})
86-
return NoTangent(), ZeroTangent()
87-
end
81+
eigh!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
8882

89-
return (D, V), eigh!_pullback
83+
return DV, eigh!_pullback
9084
end
9185

9286
function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap;
@@ -165,75 +159,6 @@ function uppertriangularind(A::AbstractMatrix)
165159
return I
166160
end
167161

168-
function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
169-
tol::Real=default_pullback_gaugetol(D))
170-
171-
# Basic size checks and determination
172-
n = LinearAlgebra.checksquare(V)
173-
n == length(D) || throw(DimensionMismatch())
174-
175-
if !(ΔV isa AbstractZero)
176-
VdΔV = V' * ΔV
177-
178-
mask = abs.(transpose(D) .- D) .< tol
179-
Δgauge = norm(view(VdΔV, mask), Inf)
180-
Δgauge < tol ||
181-
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
182-
183-
VdΔV .*= conj.(safe_inv.(transpose(D) .- D, tol))
184-
185-
if !(ΔD isa AbstractZero)
186-
view(VdΔV, diagind(VdΔV)) .+= ΔD
187-
end
188-
PΔV = V' \ VdΔV
189-
if eltype(ΔA) <: Real
190-
ΔAc = mul!(VdΔV, PΔV, V') # recycle VdΔV memory
191-
ΔA .= real.(ΔAc)
192-
else
193-
mul!(ΔA, PΔV, V')
194-
end
195-
else
196-
PΔV = V' \ Diagonal(ΔD)
197-
if eltype(ΔA) <: Real
198-
ΔAc = PΔV * V'
199-
ΔA .= real.(ΔAc)
200-
else
201-
mul!(ΔA, PΔV, V')
202-
end
203-
end
204-
return ΔA
205-
end
206-
207-
function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
208-
tol::Real=default_pullback_gaugetol(D))
209-
210-
# Basic size checks and determination
211-
n = LinearAlgebra.checksquare(V)
212-
n == length(D) || throw(DimensionMismatch())
213-
214-
if !(ΔV isa AbstractZero)
215-
VdΔV = V' * ΔV
216-
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)
217-
218-
mask = abs.(D' .- D) .< tol
219-
Δgauge = norm(view(aVdΔV, mask))
220-
Δgauge < tol ||
221-
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
222-
223-
aVdΔV .*= safe_inv.(D' .- D, tol)
224-
225-
if !(ΔD isa AbstractZero)
226-
view(aVdΔV, diagind(aVdΔV)) .+= real.(ΔD)
227-
# in principle, ΔD is real, but maybe not if coming from an anyonic tensor
228-
end
229-
# recycle VdΔV space
230-
mul!(ΔA, mul!(VdΔV, V, aVdΔV), V')
231-
else
232-
mul!(ΔA, V * Diagonal(ΔD), V')
233-
end
234-
return ΔA
235-
end
236-
237162
function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR;
238163
tol::Real=default_pullback_gaugetol(R))
239164
Rd = view(R, diagind(R))

0 commit comments

Comments
 (0)