Skip to content

Commit 1ecf11e

Browse files
committed
Update svd rrule
1 parent 942121b commit 1ecf11e

File tree

1 file changed

+15
-154
lines changed

1 file changed

+15
-154
lines changed

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 15 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,32 @@ using MatrixAlgebraKit: svd_compact_pullback!
33
# Factorizations rules
44
# --------------------
55
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
6-
trunc::TensorKit.TruncationScheme=TensorKit.notrunc(),
7-
alg::Union{TensorKit.SVD,TensorKit.SDD}=TensorKit.SDD())
8-
U, Σ, V⁺, truncerr = tsvd(t; trunc=TensorKit.notrunc(), alg)
9-
10-
if !(trunc == TensorKit.notrunc()) && !isempty(blocksectors(t))
11-
Σdata = TensorKit.SectorDict(c => diag(b) for (c, b) in blocks(Σ))
12-
13-
truncdim = TensorKit._compute_truncdim(Σdata, trunc; p=2)
14-
truncerr = TensorKit._compute_truncerr(Σdata, truncdim; p=2)
15-
16-
SVDdata = TensorKit.SectorDict(c => (block(U, c), Σc, block(V⁺, c))
17-
for (c, Σc) in Σdata)
18-
19-
Ũ, Σ̃, Ṽ⁺ = TensorKit._create_svdtensors(t, SVDdata, truncdim)
6+
trunc::TruncationStrategy=TensorKit.notrunc(),
7+
kwargs...)
8+
# TODO: I think we can use tsvd! here without issues because we don't actually require
9+
# the data of `t` anymore.
10+
USVᴴ = tsvd(t; trunc=TensorKit.notrunc(), kwargs...)
11+
12+
if trunc != TensorKit.notrunc() && !isempty(blocksectors(t))
13+
USVᴴ′ = MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, trunc)
2014
else
21-
Ũ, Σ̃, Ṽ⁺ = U, Σ, V⁺
15+
USVᴴ′ = USVᴴ
2216
end
2317

24-
function tsvd!_pullback(ΔUSVϵ)
25-
ΔU, ΔΣ, ΔV⁺, = unthunk.(ΔUSVϵ)
18+
function tsvd!_pullback(ΔUSVᴴ′)
19+
ΔUSVᴴ = unthunk.(ΔUSVᴴ′)
2620
Δt = similar(t)
2721
foreachblock(Δt) do (c, b)
28-
USVᴴc = (block(U, c), block(Σ, c), block(V⁺, c))
29-
ΔUSVᴴc = (block(ΔU, c), block(ΔΣ, c), block(ΔV⁺, c))
22+
USVᴴc = block.(USVᴴ, Ref(c))
23+
ΔUSVᴴc = block.(ΔUSVᴴ, Ref(c))
3024
svd_compact_pullback!(b, USVᴴc, ΔUSVᴴc)
3125
return nothing
3226
end
3327
return NoTangent(), Δt
3428
end
35-
function tsvd!_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
36-
return NoTangent(), ZeroTangent()
37-
end
29+
tsvd!_pullback(::NTuple{3,ZeroTangent}) = NoTangent(), ZeroTangent()
3830

39-
return (Ũ, Σ̃, Ṽ⁺, truncerr), tsvd!_pullback
31+
return USVᴴ′, tsvd!_pullback
4032
end
4133

4234
function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
@@ -173,137 +165,6 @@ function uppertriangularind(A::AbstractMatrix)
173165
return I
174166
end
175167

176-
# SVD_pullback: pullback implementation for general (possibly truncated) SVD
177-
#
178-
# Arguments are U, S and Vd of full (non-truncated, but still thin) SVD, as well as
179-
# cotangent ΔU, ΔS, ΔVd variables of truncated SVD
180-
#
181-
# Checks whether the cotangent variables are such that they would couple to gauge-dependent
182-
# degrees of freedom (phases of singular vectors), and prints a warning if this is the case
183-
#
184-
# An implementation that only uses U, S, and Vd from truncated SVD is also possible, but
185-
# requires solving a Sylvester equation, which does not seem to be supported on GPUs.
186-
#
187-
# Other implementation considerations for GPU compatibility:
188-
# no scalar indexing, lots of broadcasting and views
189-
#
190-
# function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector,
191-
# Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
192-
# tol::Real=default_pullback_gaugetol(S))
193-
194-
# # Basic size checks and determination
195-
# m, n = size(U, 1), size(Vd, 2)
196-
# size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch())
197-
# p = -1
198-
# if !(ΔU isa AbstractZero)
199-
# m == size(ΔU, 1) || throw(DimensionMismatch())
200-
# p = size(ΔU, 2)
201-
# end
202-
# if !(ΔVd isa AbstractZero)
203-
# n == size(ΔVd, 2) || throw(DimensionMismatch())
204-
# if p == -1
205-
# p = size(ΔVd, 1)
206-
# else
207-
# p == size(ΔVd, 1) || throw(DimensionMismatch())
208-
# end
209-
# end
210-
# if !(ΔS isa AbstractZero)
211-
# if p == -1
212-
# p = length(ΔS)
213-
# else
214-
# p == length(ΔS) || throw(DimensionMismatch())
215-
# end
216-
# end
217-
# Up = view(U, :, 1:p)
218-
# Vp = view(Vd, 1:p, :)'
219-
# Sp = view(S, 1:p)
220-
221-
# # rank
222-
# r = searchsortedlast(S, tol; rev=true)
223-
224-
# # compute antihermitian part of projection of ΔU and ΔV onto U and V
225-
# # also already subtract this projection from ΔU and ΔV
226-
# if !(ΔU isa AbstractZero)
227-
# UΔU = Up' * ΔU
228-
# aUΔU = rmul!(UΔU - UΔU', 1 / 2)
229-
# if m > p
230-
# ΔU -= Up * UΔU
231-
# end
232-
# else
233-
# aUΔU = fill!(similar(U, (p, p)), 0)
234-
# end
235-
# if !(ΔVd isa AbstractZero)
236-
# VΔV = Vp' * ΔVd'
237-
# aVΔV = rmul!(VΔV - VΔV', 1 / 2)
238-
# if n > p
239-
# ΔVd -= VΔV' * Vp'
240-
# end
241-
# else
242-
# aVΔV = fill!(similar(Vd, (p, p)), 0)
243-
# end
244-
245-
# # check whether cotangents arise from gauge-invariance objective function
246-
# mask = abs.(Sp' .- Sp) .< tol
247-
# Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
248-
# if p > r
249-
# rprange = (r + 1):p
250-
# Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
251-
# Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
252-
# end
253-
# Δgauge < tol ||
254-
# @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
255-
256-
# UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
257-
# (aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol)
258-
# if !(ΔS isa ZeroTangent)
259-
# UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
260-
# # in principle, ΔS is real, but maybe not if coming from an anyonic tensor
261-
# end
262-
# mul!(ΔA, Up, UdΔAV * Vp')
263-
264-
# if r > p # contribution from truncation
265-
# Ur = view(U, :, (p + 1):r)
266-
# Vr = view(Vd, (p + 1):r, :)'
267-
# Sr = view(S, (p + 1):r)
268-
269-
# if !(ΔU isa AbstractZero)
270-
# UrΔU = Ur' * ΔU
271-
# if m > r
272-
# ΔU -= Ur * UrΔU # subtract this part from ΔU
273-
# end
274-
# else
275-
# UrΔU = fill!(similar(U, (r - p, p)), 0)
276-
# end
277-
# if !(ΔVd isa AbstractZero)
278-
# VrΔV = Vr' * ΔVd'
279-
# if n > r
280-
# ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
281-
# end
282-
# else
283-
# VrΔV = fill!(similar(Vd, (r - p, p)), 0)
284-
# end
285-
286-
# X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+
287-
# (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
288-
# Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .-
289-
# (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
290-
291-
# # ΔA += Ur * X * Vp' + Up * Y' * Vr'
292-
# mul!(ΔA, Ur, X * Vp', 1, 1)
293-
# mul!(ΔA, Up * Y', Vr', 1, 1)
294-
# end
295-
296-
# if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)]
297-
# # ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp'
298-
# mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1)
299-
# end
300-
# if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)]
301-
# # ΔA += U * (safe_inv.(Sp, tol) .* ΔVd)
302-
# mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1)
303-
# end
304-
# return ΔA
305-
# end
306-
307168
function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
308169
tol::Real=default_pullback_gaugetol(D))
309170

0 commit comments

Comments
 (0)