@@ -3,40 +3,32 @@ using MatrixAlgebraKit: svd_compact_pullback!
33# Factorizations rules
44# --------------------
55function ChainRulesCore. rrule (:: typeof (TensorKit. tsvd!), t:: AbstractTensorMap ;
6- trunc:: TensorKit.TruncationScheme = TensorKit. notrunc (),
7- alg:: Union{TensorKit.SVD,TensorKit.SDD} = TensorKit. SDD ())
8- U, Σ, V⁺, truncerr = tsvd (t; trunc= TensorKit. notrunc (), alg)
9-
10- if ! (trunc == TensorKit. notrunc ()) && ! isempty (blocksectors (t))
11- Σdata = TensorKit. SectorDict (c => diag (b) for (c, b) in blocks (Σ))
12-
13- truncdim = TensorKit. _compute_truncdim (Σdata, trunc; p= 2 )
14- truncerr = TensorKit. _compute_truncerr (Σdata, truncdim; p= 2 )
15-
16- SVDdata = TensorKit. SectorDict (c => (block (U, c), Σc, block (V⁺, c))
17- for (c, Σc) in Σdata)
18-
19- Ũ, Σ̃, Ṽ⁺ = TensorKit. _create_svdtensors (t, SVDdata, truncdim)
6+ trunc:: TruncationStrategy = TensorKit. notrunc (),
7+ kwargs... )
8+ # TODO : I think we can use tsvd! here without issues because we don't actually require
9+ # the data of `t` anymore.
10+ USVᴴ = tsvd (t; trunc= TensorKit. notrunc (), alg)
11+
12+ if trunc != TensorKit. notrunc () && ! isempty (blocksectors (t))
13+ USVᴴ′ = MatrixAlgebraKit. truncate! (svd_trunc!, USVᴴ, trunc)
2014 else
21- Ũ, Σ̃, Ṽ⁺ = U, Σ, V⁺
15+ USVᴴ′ = USVᴴ
2216 end
2317
24- function tsvd!_pullback (ΔUSVϵ )
25- ΔU, ΔΣ, ΔV⁺, = unthunk .(ΔUSVϵ )
18+ function tsvd!_pullback (ΔUSVᴴ′ )
19+ ΔUSVᴴ = unthunk .(ΔUSVᴴ′ )
2620 Δt = similar (t)
2721 foreachblock (Δt) do (c, b)
28- USVᴴc = ( block (U, c), block (Σ, c), block (V⁺, c))
29- ΔUSVᴴc = ( block (ΔU, c), block (ΔΣ, c), block (ΔV⁺, c))
22+ USVᴴc = block .(USVᴴ, Ref ( c))
23+ ΔUSVᴴc = block .(ΔUSVᴴ, Ref ( c))
3024 svd_compact_pullback! (b, USVᴴc, ΔUSVᴴc)
3125 return nothing
3226 end
3327 return NoTangent (), Δt
3428 end
35- function tsvd!_pullback (:: Tuple{ZeroTangent,ZeroTangent,ZeroTangent} )
36- return NoTangent (), ZeroTangent ()
37- end
29+ tsvd!_pullback (:: NTuple{3,ZeroTangent} ) = NoTangent (), ZeroTangent ()
3830
39- return (Ũ, Σ̃, Ṽ⁺, truncerr) , tsvd!_pullback
31+ return USVᴴ′ , tsvd!_pullback
4032end
4133
4234function ChainRulesCore. rrule (:: typeof (LinearAlgebra. svdvals!), t:: AbstractTensorMap )
@@ -173,137 +165,6 @@ function uppertriangularind(A::AbstractMatrix)
173165 return I
174166end
175167
176- # SVD_pullback: pullback implementation for general (possibly truncated) SVD
177- #
178- # Arguments are U, S and Vd of full (non-truncated, but still thin) SVD, as well as
179- # cotangent ΔU, ΔS, ΔVd variables of truncated SVD
180- #
181- # Checks whether the cotangent variables are such that they would couple to gauge-dependent
182- # degrees of freedom (phases of singular vectors), and prints a warning if this is the case
183- #
184- # An implementation that only uses U, S, and Vd from truncated SVD is also possible, but
185- # requires solving a Sylvester equation, which does not seem to be supported on GPUs.
186- #
187- # Other implementation considerations for GPU compatibility:
188- # no scalar indexing, lots of broadcasting and views
189- #
190- # function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector,
191- # Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
192- # tol::Real=default_pullback_gaugetol(S))
193-
194- # # Basic size checks and determination
195- # m, n = size(U, 1), size(Vd, 2)
196- # size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch())
197- # p = -1
198- # if !(ΔU isa AbstractZero)
199- # m == size(ΔU, 1) || throw(DimensionMismatch())
200- # p = size(ΔU, 2)
201- # end
202- # if !(ΔVd isa AbstractZero)
203- # n == size(ΔVd, 2) || throw(DimensionMismatch())
204- # if p == -1
205- # p = size(ΔVd, 1)
206- # else
207- # p == size(ΔVd, 1) || throw(DimensionMismatch())
208- # end
209- # end
210- # if !(ΔS isa AbstractZero)
211- # if p == -1
212- # p = length(ΔS)
213- # else
214- # p == length(ΔS) || throw(DimensionMismatch())
215- # end
216- # end
217- # Up = view(U, :, 1:p)
218- # Vp = view(Vd, 1:p, :)'
219- # Sp = view(S, 1:p)
220-
221- # # rank
222- # r = searchsortedlast(S, tol; rev=true)
223-
224- # # compute antihermitian part of projection of ΔU and ΔV onto U and V
225- # # also already subtract this projection from ΔU and ΔV
226- # if !(ΔU isa AbstractZero)
227- # UΔU = Up' * ΔU
228- # aUΔU = rmul!(UΔU - UΔU', 1 / 2)
229- # if m > p
230- # ΔU -= Up * UΔU
231- # end
232- # else
233- # aUΔU = fill!(similar(U, (p, p)), 0)
234- # end
235- # if !(ΔVd isa AbstractZero)
236- # VΔV = Vp' * ΔVd'
237- # aVΔV = rmul!(VΔV - VΔV', 1 / 2)
238- # if n > p
239- # ΔVd -= VΔV' * Vp'
240- # end
241- # else
242- # aVΔV = fill!(similar(Vd, (p, p)), 0)
243- # end
244-
245- # # check whether cotangents arise from gauge-invariance objective function
246- # mask = abs.(Sp' .- Sp) .< tol
247- # Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
248- # if p > r
249- # rprange = (r + 1):p
250- # Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
251- # Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
252- # end
253- # Δgauge < tol ||
254- # @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
255-
256- # UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
257- # (aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol)
258- # if !(ΔS isa ZeroTangent)
259- # UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
260- # # in principle, ΔS is real, but maybe not if coming from an anyonic tensor
261- # end
262- # mul!(ΔA, Up, UdΔAV * Vp')
263-
264- # if r > p # contribution from truncation
265- # Ur = view(U, :, (p + 1):r)
266- # Vr = view(Vd, (p + 1):r, :)'
267- # Sr = view(S, (p + 1):r)
268-
269- # if !(ΔU isa AbstractZero)
270- # UrΔU = Ur' * ΔU
271- # if m > r
272- # ΔU -= Ur * UrΔU # subtract this part from ΔU
273- # end
274- # else
275- # UrΔU = fill!(similar(U, (r - p, p)), 0)
276- # end
277- # if !(ΔVd isa AbstractZero)
278- # VrΔV = Vr' * ΔVd'
279- # if n > r
280- # ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
281- # end
282- # else
283- # VrΔV = fill!(similar(Vd, (r - p, p)), 0)
284- # end
285-
286- # X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+
287- # (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
288- # Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .-
289- # (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
290-
291- # # ΔA += Ur * X * Vp' + Up * Y' * Vr'
292- # mul!(ΔA, Ur, X * Vp', 1, 1)
293- # mul!(ΔA, Up * Y', Vr', 1, 1)
294- # end
295-
296- # if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)]
297- # # ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp'
298- # mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1)
299- # end
300- # if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)]
301- # # ΔA += U * (safe_inv.(Sp, tol) .* ΔVd)
302- # mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1)
303- # end
304- # return ΔA
305- # end
306-
307168function eig_pullback! (ΔA:: AbstractMatrix , D:: AbstractVector , V:: AbstractMatrix , ΔD, ΔV;
308169 tol:: Real = default_pullback_gaugetol (D))
309170
0 commit comments