@@ -18,7 +18,7 @@ for qr_f in (:qr_compact, :qr_full)
1818 function qr_pullback (ΔQR′)
1919 ΔQR = unthunk .(ΔQR′)
2020 Δt = zerovector (t)
21- MatrixAlgebraKit. qr_compact_pullback! (Δt, QR, ΔQR)
21+ MatrixAlgebraKit. qr_compact_pullback! (Δt, t, QR, ΔQR)
2222 return NoTangent (), Δt, ZeroTangent (), NoTangent ()
2323 end
2424 function qr_pullback (:: Tuple{ZeroTangent,ZeroTangent} )
@@ -44,7 +44,7 @@ function ChainRulesCore.rrule(::typeof(qr_null!), t::AbstractTensorMap, N, alg)
4444 return copy! (@view (ΔQc[:, (end - n + 1 ): end ]), b)
4545 end
4646 ΔR = ZeroTangent ()
47- MatrixAlgebraKit. qr_compact_pullback! (Δt, (Q, R), (ΔQ, ΔR))
47+ MatrixAlgebraKit. qr_compact_pullback! (Δt, t, (Q, R), (ΔQ, ΔR))
4848 return NoTangent (), Δt, ZeroTangent (), NoTangent ()
4949 end
5050 qr_null_pullback (:: ZeroTangent ) = NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
@@ -60,7 +60,7 @@ for lq_f in (:lq_compact, :lq_full)
6060 function lq_pullback (ΔLQ′)
6161 ΔLQ = unthunk .(ΔLQ′)
6262 Δt = zerovector (t)
63- MatrixAlgebraKit. lq_compact_pullback! (Δt, LQ, ΔLQ)
63+ MatrixAlgebraKit. lq_compact_pullback! (Δt, t, LQ, ΔLQ)
6464 return NoTangent (), Δt, ZeroTangent (), NoTangent ()
6565 end
6666 function lq_pullback (:: Tuple{ZeroTangent,ZeroTangent} )
@@ -86,7 +86,7 @@ function ChainRulesCore.rrule(::typeof(lq_null!), t::AbstractTensorMap, Nᴴ, al
8686 return copy! (@view (ΔQc[(end - m + 1 ): end , :]), b)
8787 end
8888 ΔL = ZeroTangent ()
89- MatrixAlgebraKit. lq_compact_pullback! (Δt, (L, Q), (ΔL, ΔQ))
89+ MatrixAlgebraKit. lq_compact_pullback! (Δt, t, (L, Q), (ΔL, ΔQ))
9090 return NoTangent (), Δt, ZeroTangent (), NoTangent ()
9191 end
9292 lq_null_pullback (:: ZeroTangent ) = NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
9797for eig in (:eig , :eigh )
9898 eig_f = Symbol (eig, " _full" )
9999 eig_f! = Symbol (eig_f, " !" )
100- eig_f_pb! = Symbol (eig, " _full_pullback !" )
100+ eig_f_pb! = Symbol (eig, " _pullback !" )
101101 eig_pb = Symbol (eig, " _pullback" )
102102 @eval function ChainRulesCore. rrule (:: typeof ($ eig_f!), t:: AbstractTensorMap , DV, alg)
103- tc = copy_input ($ eig_f, t)
103+ tc = MatrixAlgebraKit . copy_input ($ eig_f, t)
104104 DV = $ (eig_f!)(tc, DV, alg)
105105 function $eig_pb (ΔDV)
106106 Δt = zerovector (t)
107- MatrixAlgebraKit.$ eig_f_pb! (Δt, DV, unthunk .(ΔDV))
107+ MatrixAlgebraKit.$ eig_f_pb! (Δt, t, DV, unthunk .(ΔDV))
108108 return NoTangent (), Δt, ZeroTangent (), NoTangent ()
109109 end
110110 function $eig_pb (:: Tuple{ZeroTangent,ZeroTangent} )
@@ -118,11 +118,11 @@ for svd_f in (:svd_compact, :svd_full)
118118 svd_f! = Symbol (svd_f, " !" )
119119 @eval begin
120120 function ChainRulesCore. rrule (:: typeof ($ svd_f!), t:: AbstractTensorMap , USVᴴ, alg)
121- tc = copy_input ($ svd_f, t)
121+ tc = MatrixAlgebraKit . copy_input ($ svd_f, t)
122122 USVᴴ = $ (svd_f!)(tc, USVᴴ, alg)
123123 function svd_pullback (ΔUSVᴴ)
124124 Δt = zerovector (t)
125- MatrixAlgebraKit. svd_compact_pullback ! (Δt, USVᴴ, unthunk .(ΔUSVᴴ))
125+ MatrixAlgebraKit. svd_pullback ! (Δt, t , USVᴴ, unthunk .(ΔUSVᴴ))
126126 return NoTangent (), Δt, ZeroTangent (), NoTangent ()
127127 end
128128 function svd_pullback (:: Tuple{ZeroTangent,ZeroTangent,ZeroTangent} )
@@ -137,23 +137,28 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), t::AbstractTensorMap, USVᴴ
137137 alg:: TruncatedAlgorithm )
138138 tc = MatrixAlgebraKit. copy_input (svd_compact, t)
139139 USVᴴ = svd_compact! (tc, USVᴴ, alg. alg)
140+ USVᴴ_trunc, ind = TensorKit. Factorizations. truncate (svd_trunc!, USVᴴ, alg. trunc)
141+ svd_trunc_pullback = _make_svd_trunc_pullback (t, USVᴴ, ind)
142+ return USVᴴ_trunc, svd_trunc_pullback
143+ end
144+ function _make_svd_trunc_pullback (t:: AbstractTensorMap , USVᴴ, ind)
140145 function svd_trunc_pullback (ΔUSVᴴ)
141146 Δt = zerovector (t)
142- MatrixAlgebraKit. svd_compact_pullback ! (Δt, USVᴴ, unthunk .(ΔUSVᴴ))
143- return NoTangent (), ΔA , ZeroTangent (), NoTangent ()
147+ MatrixAlgebraKit. svd_pullback ! (Δt, t, USVᴴ, unthunk .(ΔUSVᴴ), ind )
148+ return NoTangent (), Δt , ZeroTangent (), NoTangent ()
144149 end
145- function svd_trunc_pullback (:: Tuple{ZeroTangent,ZeroTangent ,ZeroTangent} )
150+ function svd_trunc_pullback (:: NTuple{3 ,ZeroTangent} )
146151 return NoTangent (), ZeroTangent (), ZeroTangent (), NoTangent ()
147152 end
148- return MatrixAlgebraKit . truncate! (svd_trunc!, USVᴴ, alg . trunc), svd_trunc_pullback
153+ return svd_trunc_pullback
149154end
150155
151156function ChainRulesCore. rrule (:: typeof (left_polar!), t:: AbstractTensorMap , WP, alg)
152- tc = copy_input (left_polar, t)
157+ tc = MatrixAlgebraKit . copy_input (left_polar, t)
153158 WP = left_polar! (tc, WP, alg)
154159 function left_polar_pullback (ΔWP)
155160 Δt = zerovector (t)
156- MatrixAlgebraKit. left_polar_pullback! (Δt, WP, unthunk .(ΔWP))
161+ MatrixAlgebraKit. left_polar_pullback! (Δt, t, WP, unthunk .(ΔWP))
157162 return NoTangent (), Δt, ZeroTangent (), NoTangent ()
158163 end
159164 function left_polar_pullback (:: Tuple{ZeroTangent,ZeroTangent} )
@@ -163,11 +168,11 @@ function ChainRulesCore.rrule(::typeof(left_polar!), t::AbstractTensorMap, WP, a
163168end
164169
165170function ChainRulesCore. rrule (:: typeof (right_polar!), t:: AbstractTensorMap , PWᴴ, alg)
166- tc = copy_input (left_polar, t)
167- PWᴴ = right_polar! (Ac , PWᴴ, alg)
171+ tc = MatrixAlgebraKit . copy_input (left_polar, t)
172+ PWᴴ = right_polar! (tc , PWᴴ, alg)
168173 function right_polar_pullback (ΔPWᴴ)
169174 Δt = zerovector (t)
170- MatrixAlgebraKit. right_polar_pullback! (Δt, PWᴴ, unthunk .(ΔPWᴴ))
175+ MatrixAlgebraKit. right_polar_pullback! (Δt, t, PWᴴ, unthunk .(ΔPWᴴ))
171176 return NoTangent (), Δt, ZeroTangent (), NoTangent ()
172177 end
173178 function right_polar_pullback (:: Tuple{ZeroTangent,ZeroTangent} )
@@ -176,41 +181,6 @@ function ChainRulesCore.rrule(::typeof(right_polar!), t::AbstractTensorMap, PW
176181 return PWᴴ, right_polar_pullback
177182end
178183
179- # for f in (:tsvd, :eig, :eigh)
180- # f! = Symbol(f, :!)
181- # f_trunc! = f == :tsvd ? :svd_trunc! : Symbol(f, :_trunc!)
182- # f_pullback = Symbol(f, :_pullback)
183- # f_pullback! = f == :tsvd ? :svd_compact_pullback! : Symbol(f, :_full_pullback!)
184- # @eval function ChainRulesCore.rrule(::typeof(TensorKit.$f!), t::AbstractTensorMap;
185- # trunc::TruncationStrategy=TensorKit.notrunc(),
186- # kwargs...)
187- # # TODO : I think we can use f! here without issues because we don't actually require
188- # # the data of `t` anymore.
189- # F = $f(t; trunc=TensorKit.notrunc(), kwargs...)
190-
191- # if trunc != TensorKit.notrunc() && !isempty(blocksectors(t))
192- # F′ = MatrixAlgebraKit.truncate!($f_trunc!, F, trunc)
193- # else
194- # F′ = F
195- # end
196-
197- # function $f_pullback(ΔF′)
198- # ΔF = unthunk.(ΔF′)
199- # Δt = zerovector(t)
200- # foreachblock(Δt) do c, (b,)
201- # Fc = block.(F, Ref(c))
202- # ΔFc = block.(ΔF, Ref(c))
203- # $f_pullback!(b, Fc, ΔFc)
204- # return nothing
205- # end
206- # return NoTangent(), Δt
207- # end
208- # $f_pullback(::Tuple{ZeroTangent,Vararg{ZeroTangent}}) = NoTangent(), ZeroTangent()
209-
210- # return F′, $f_pullback
211- # end
212- # end
213-
214184# function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
215185# U, S, V⁺ = tsvd(t)
216186# s = diag(S)
239209
240210# return d, eigvals_pullback
241211# end
242-
243- # function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
244- # alg isa MatrixAlgebraKit.LAPACK_HouseholderQR ||
245- # error("only `alg=QR()` and `alg=QRpos()` are supported")
246- # QR = leftorth(t; alg)
247- # function leftorth!_pullback(ΔQR′)
248- # ΔQR = unthunk.(ΔQR′)
249- # Δt = zerovector(t)
250- # foreachblock(Δt) do c, (b,)
251- # QRc = block.(QR, Ref(c))
252- # ΔQRc = block.(ΔQR, Ref(c))
253- # qr_compact_pullback!(b, QRc, ΔQRc)
254- # return nothing
255- # end
256- # return NoTangent(), Δt
257- # end
258- # leftorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
259-
260- # return QR, leftorth!_pullback
261- # end
262-
263- # function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
264- # alg isa MatrixAlgebraKit.LAPACK_HouseholderLQ ||
265- # error("only `alg=LQ()` and `alg=LQpos()` are supported")
266- # LQ = rightorth(t; alg)
267- # function rightorth!_pullback(ΔLQ′)
268- # ΔLQ = unthunk(ΔLQ′)
269- # Δt = zerovector(t)
270- # foreachblock(Δt) do c, (b,)
271- # LQc = block.(LQ, Ref(c))
272- # ΔLQc = block.(ΔLQ, Ref(c))
273- # lq_compact_pullback!(b, LQc, ΔLQc)
274- # return nothing
275- # end
276- # return NoTangent(), Δt
277- # end
278- # rightorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
279- # return LQ, rightorth!_pullback
280- # end
0 commit comments