|
1 | | -using MatrixAlgebraKit: svd_compact_pullback! |
| 1 | +using MatrixAlgebraKit: svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback! |
2 | 2 |
|
3 | 3 | # Factorizations rules |
4 | 4 | # -------------------- |
@@ -46,47 +46,41 @@ function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTenso |
46 | 46 | end |
47 | 47 |
|
48 | 48 | function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...) |
49 | | - D, V = eig(t; kwargs...) |
| 49 | + DV = eig(t; kwargs...) |
50 | 50 |
|
51 | | - function eig!_pullback((_ΔD, _ΔV)) |
52 | | - ΔD, ΔV = unthunk(_ΔD), unthunk(_ΔV) |
| 51 | + function eig!_pullback(ΔDV′) |
| 52 | + ΔDV = unthunk.(ΔDV′) |
53 | 53 | Δt = similar(t) |
54 | | - for (c, b) in blocks(Δt) |
55 | | - Dc, Vc = block(D, c), block(V, c) |
56 | | - ΔDc, ΔVc = block(ΔD, c), block(ΔV, c) |
57 | | - Ddc = view(Dc, diagind(Dc)) |
58 | | - ΔDdc = (ΔDc isa AbstractZero) ? ΔDc : view(ΔDc, diagind(ΔDc)) |
59 | | - eig_pullback!(b, Ddc, Vc, ΔDdc, ΔVc) |
| 54 | + foreachblock(Δt) do (c, b) |
| 55 | + DVc = block.(DV, Ref(c)) |
| 56 | + ΔDVc = block.(ΔDV, Ref(c)) |
| 57 | + eig_full_pullback!(b, DVc, ΔDVc) |
| 58 | + return nothing |
60 | 59 | end |
61 | 60 | return NoTangent(), Δt |
62 | 61 | end |
63 | | - function eig!_pullback(::Tuple{ZeroTangent,ZeroTangent}) |
64 | | - return NoTangent(), ZeroTangent() |
65 | | - end |
| 62 | + eig!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() |
66 | 63 |
|
67 | | - return (D, V), eig!_pullback |
| 64 | + return DV, eig!_pullback |
68 | 65 | end |
69 | 66 |
|
70 | 67 | function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; kwargs...) |
71 | | - D, V = eigh(t; kwargs...) |
| 68 | + DV = eigh(t; kwargs...) |
72 | 69 |
|
73 | | - function eigh!_pullback((_ΔD, _ΔV)) |
74 | | - ΔD, ΔV = unthunk(_ΔD), unthunk(_ΔV) |
| 70 | + function eigh!_pullback(ΔDV′) |
| 71 | + ΔDV = unthunk.(ΔDV′) |
75 | 72 | Δt = similar(t) |
76 | | - for (c, b) in blocks(Δt) |
77 | | - Dc, Vc = block(D, c), block(V, c) |
78 | | - ΔDc, ΔVc = block(ΔD, c), block(ΔV, c) |
79 | | - Ddc = view(Dc, diagind(Dc)) |
80 | | - ΔDdc = (ΔDc isa AbstractZero) ? ΔDc : view(ΔDc, diagind(ΔDc)) |
81 | | - eigh_pullback!(b, Ddc, Vc, ΔDdc, ΔVc) |
| 73 | + foreachblock(Δt) do (c, b) |
| 74 | + DVc = block.(DV, Ref(c)) |
| 75 | + ΔDVc = block.(ΔDV, Ref(c)) |
| 76 | + eigh_full_pullback!(b, DVc, ΔDVc) |
| 77 | + return nothing |
82 | 78 | end |
83 | 79 | return NoTangent(), Δt |
84 | 80 | end |
85 | | - function eigh!_pullback(::Tuple{ZeroTangent,ZeroTangent}) |
86 | | - return NoTangent(), ZeroTangent() |
87 | | - end |
| 81 | + eigh!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() |
88 | 82 |
|
89 | | - return (D, V), eigh!_pullback |
| 83 | + return DV, eigh!_pullback |
90 | 84 | end |
91 | 85 |
|
92 | 86 | function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap; |
@@ -165,75 +159,6 @@ function uppertriangularind(A::AbstractMatrix) |
165 | 159 | return I |
166 | 160 | end |
167 | 161 |
|
168 | | -function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV; |
169 | | - tol::Real=default_pullback_gaugetol(D)) |
170 | | - |
171 | | - # Basic size checks and determination |
172 | | - n = LinearAlgebra.checksquare(V) |
173 | | - n == length(D) || throw(DimensionMismatch()) |
174 | | - |
175 | | - if !(ΔV isa AbstractZero) |
176 | | - VdΔV = V' * ΔV |
177 | | - |
178 | | - mask = abs.(transpose(D) .- D) .< tol |
179 | | - Δgauge = norm(view(VdΔV, mask), Inf) |
180 | | - Δgauge < tol || |
181 | | - @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" |
182 | | - |
183 | | - VdΔV .*= conj.(safe_inv.(transpose(D) .- D, tol)) |
184 | | - |
185 | | - if !(ΔD isa AbstractZero) |
186 | | - view(VdΔV, diagind(VdΔV)) .+= ΔD |
187 | | - end |
188 | | - PΔV = V' \ VdΔV |
189 | | - if eltype(ΔA) <: Real |
190 | | - ΔAc = mul!(VdΔV, PΔV, V') # recycle VdΔV memory |
191 | | - ΔA .= real.(ΔAc) |
192 | | - else |
193 | | - mul!(ΔA, PΔV, V') |
194 | | - end |
195 | | - else |
196 | | - PΔV = V' \ Diagonal(ΔD) |
197 | | - if eltype(ΔA) <: Real |
198 | | - ΔAc = PΔV * V' |
199 | | - ΔA .= real.(ΔAc) |
200 | | - else |
201 | | - mul!(ΔA, PΔV, V') |
202 | | - end |
203 | | - end |
204 | | - return ΔA |
205 | | -end |
206 | | - |
207 | | -function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV; |
208 | | - tol::Real=default_pullback_gaugetol(D)) |
209 | | - |
210 | | - # Basic size checks and determination |
211 | | - n = LinearAlgebra.checksquare(V) |
212 | | - n == length(D) || throw(DimensionMismatch()) |
213 | | - |
214 | | - if !(ΔV isa AbstractZero) |
215 | | - VdΔV = V' * ΔV |
216 | | - aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2) |
217 | | - |
218 | | - mask = abs.(D' .- D) .< tol |
219 | | - Δgauge = norm(view(aVdΔV, mask)) |
220 | | - Δgauge < tol || |
221 | | - @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" |
222 | | - |
223 | | - aVdΔV .*= safe_inv.(D' .- D, tol) |
224 | | - |
225 | | - if !(ΔD isa AbstractZero) |
226 | | - view(aVdΔV, diagind(aVdΔV)) .+= real.(ΔD) |
227 | | - # in principle, ΔD is real, but maybe not if coming from an anyonic tensor |
228 | | - end |
229 | | - # recylce VdΔV space |
230 | | - mul!(ΔA, mul!(VdΔV, V, aVdΔV), V') |
231 | | - else |
232 | | - mul!(ΔA, V * Diagonal(ΔD), V') |
233 | | - end |
234 | | - return ΔA |
235 | | -end |
236 | | - |
237 | 162 | function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR; |
238 | 163 | tol::Real=default_pullback_gaugetol(R)) |
239 | 164 | Rd = view(R, diagind(R)) |
|
0 commit comments