|
99 | 99 | function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos()) |
100 | 100 | alg isa TensorKit.QR || alg isa TensorKit.QRpos || |
101 | 101 | error("only `alg=QR()` and `alg=QRpos()` are supported") |
102 | | - Q, R = leftorth(t; alg) |
103 | | - function leftorth!_pullback((_ΔQ, _ΔR)) |
104 | | - ΔQ, ΔR = unthunk(_ΔQ), unthunk(_ΔR) |
105 | | - Δt = similar(t) |
106 | | - for (c, b) in blocks(Δt) |
107 | | - qr_pullback!(b, block(Q, c), block(R, c), block(ΔQ, c), block(ΔR, c)) |
| 102 | + QR = leftorth(t; alg) |
| 103 | + function leftorth!_pullback(ΔQR′) |
| 104 | + ΔQR = unthunk.(ΔQR′) |
| 105 | + Δt = zerovector(t) |
| 106 | + foreachblock(Δt) do c, (b,) |
| 107 | + QRc = block.(QR, Ref(c)) |
| 108 | + ΔQRc = block.(ΔQR, Ref(c)) |
| 109 | + qr_compact_pullback!(b, QRc, ΔQRc) |
| 110 | + return nothing |
108 | 111 | end |
109 | 112 | return NoTangent(), Δt |
110 | 113 | end |
111 | | - leftorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent() |
112 | | - return (Q, R), leftorth!_pullback |
| 114 | + leftorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() |
| 115 | + |
| 116 | + return QR, leftorth!_pullback |
113 | 117 | end |
114 | 118 |
|
115 | 119 | function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos()) |
116 | 120 | alg isa TensorKit.LQ || alg isa TensorKit.LQpos || |
117 | 121 | error("only `alg=LQ()` and `alg=LQpos()` are supported") |
118 | | - L, Q = rightorth(t; alg) |
119 | | - function rightorth!_pullback((_ΔL, _ΔQ)) |
120 | | - ΔL, ΔQ = unthunk(_ΔL), unthunk(_ΔQ) |
121 | | - Δt = similar(t) |
122 | | - for (c, b) in blocks(Δt) |
123 | | - lq_pullback!(b, block(L, c), block(Q, c), block(ΔL, c), block(ΔQ, c)) |
| 122 | + LQ = rightorth(t; alg) |
| 123 | + function rightorth!_pullback(ΔLQ′) |
| 124 | + ΔLQ = unthunk.(ΔLQ′) |
| 125 | + Δt = zerovector(t) |
| 126 | + foreachblock(Δt) do c, (b,) |
| 127 | + LQc = block.(LQ, Ref(c)) |
| 128 | + ΔLQc = block.(ΔLQ, Ref(c)) |
| 129 | + lq_compact_pullback!(b, LQc, ΔLQc) |
| 130 | + return nothing |
124 | 131 | end |
125 | 132 | return NoTangent(), Δt |
126 | 133 | end |
127 | | - rightorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent() |
128 | | - return (L, Q), rightorth!_pullback |
129 | | -end |
130 | | - |
131 | | -# Corresponding matrix factorisations: implemented as mutating methods |
132 | | -# --------------------------------------------------------------------- |
133 | | -# helper routines |
134 | | -safe_inv(a, tol) = abs(a) < tol ? zero(a) : inv(a) |
135 | | - |
136 | | -function lowertriangularind(A::AbstractMatrix) |
137 | | - m, n = size(A) |
138 | | - I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m)) |
139 | | - offset = 0 |
140 | | - for j in 1:n |
141 | | - r = (j + 1):m |
142 | | - I[offset .- j .+ r] = (j - 1) * m .+ r |
143 | | - offset += length(r) |
144 | | - end |
145 | | - return I |
146 | | -end |
147 | | - |
148 | | -function uppertriangularind(A::AbstractMatrix) |
149 | | - m, n = size(A) |
150 | | - I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m)) |
151 | | - offset = 0 |
152 | | - for i in 1:m |
153 | | - r = (i + 1):n |
154 | | - I[offset .- i .+ r] = i .+ m .* (r .- 1) |
155 | | - offset += length(r) |
156 | | - end |
157 | | - return I |
158 | | -end |
159 | | - |
160 | | -function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR; |
161 | | - tol::Real=default_pullback_gaugetol(R)) |
162 | | - Rd = view(R, diagind(R)) |
163 | | - p = something(findlast(≥(tol) ∘ abs, Rd), 0) |
164 | | - m, n = size(R) |
165 | | - |
166 | | - Q1 = view(Q, :, 1:p) |
167 | | - R1 = view(R, 1:p, :) |
168 | | - R11 = view(R, 1:p, 1:p) |
169 | | - |
170 | | - ΔA1 = view(ΔA, :, 1:p) |
171 | | - ΔQ1 = view(ΔQ, :, 1:p) |
172 | | - ΔR1 = view(ΔR, 1:p, :) |
173 | | - |
174 | | - M = similar(R, (p, p)) |
175 | | - ΔR isa AbstractZero || mul!(M, ΔR1, R1') |
176 | | - ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, !(ΔR isa AbstractZero)) |
177 | | - view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) |
178 | | - if eltype(M) <: Complex |
179 | | - Md = view(M, diagind(M)) |
180 | | - Md .= real.(Md) |
181 | | - end |
182 | | - |
183 | | - ΔA1 .= ΔQ1 |
184 | | - mul!(ΔA1, Q1, M, +1, 1) |
185 | | - |
186 | | - if n > p |
187 | | - R12 = view(R, 1:p, (p + 1):n) |
188 | | - ΔA2 = view(ΔA, :, (p + 1):n) |
189 | | - ΔR12 = view(ΔR, 1:p, (p + 1):n) |
190 | | - |
191 | | - if ΔR isa AbstractZero |
192 | | - ΔA2 .= zero(eltype(ΔA)) |
193 | | - else |
194 | | - mul!(ΔA2, Q1, ΔR12) |
195 | | - mul!(ΔA1, ΔA2, R12', -1, 1) |
196 | | - end |
197 | | - end |
198 | | - if m > p && !(ΔQ isa AbstractZero) # case where R is not full rank |
199 | | - Q2 = view(Q, :, (p + 1):m) |
200 | | - ΔQ2 = view(ΔQ, :, (p + 1):m) |
201 | | - Q1dΔQ2 = Q1' * ΔQ2 |
202 | | - Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) |
203 | | - Δgauge < tol || |
204 | | - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" |
205 | | - mul!(ΔA1, Q2, Q1dΔQ2', -1, 1) |
206 | | - end |
207 | | - rdiv!(ΔA1, UpperTriangular(R11)') |
208 | | - return ΔA |
209 | | -end |
210 | | - |
211 | | -function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL, ΔQ; |
212 | | - tol::Real=default_pullback_gaugetol(L)) |
213 | | - Ld = view(L, diagind(L)) |
214 | | - p = something(findlast(≥(tol) ∘ abs, Ld), 0) |
215 | | - m, n = size(L) |
216 | | - |
217 | | - L1 = view(L, :, 1:p) |
218 | | - L11 = view(L, 1:p, 1:p) |
219 | | - Q1 = view(Q, 1:p, :) |
220 | | - |
221 | | - ΔA1 = view(ΔA, 1:p, :) |
222 | | - ΔQ1 = view(ΔQ, 1:p, :) |
223 | | - ΔL1 = view(ΔL, :, 1:p) |
224 | | - |
225 | | - M = similar(L, (p, p)) |
226 | | - ΔL isa AbstractZero || mul!(M, L1', ΔL1) |
227 | | - ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, !(ΔL isa AbstractZero)) |
228 | | - view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M))) |
229 | | - if eltype(M) <: Complex |
230 | | - Md = view(M, diagind(M)) |
231 | | - Md .= real.(Md) |
232 | | - end |
233 | | - |
234 | | - ΔA1 .= ΔQ1 |
235 | | - mul!(ΔA1, M, Q1, +1, 1) |
236 | | - |
237 | | - if m > p |
238 | | - L21 = view(L, (p + 1):m, 1:p) |
239 | | - ΔA2 = view(ΔA, (p + 1):m, :) |
240 | | - ΔL21 = view(ΔL, (p + 1):m, 1:p) |
241 | | - |
242 | | - if ΔL isa AbstractZero |
243 | | - ΔA2 .= zero(eltype(ΔA)) |
244 | | - else |
245 | | - mul!(ΔA2, ΔL21, Q1) |
246 | | - mul!(ΔA1, L21', ΔA2, -1, 1) |
247 | | - end |
248 | | - end |
249 | | - if n > p && !(ΔQ isa AbstractZero) # case where R is not full rank |
250 | | - Q2 = view(Q, (p + 1):n, :) |
251 | | - ΔQ2 = view(ΔQ, (p + 1):n, :) |
252 | | - ΔQ2Q1d = ΔQ2 * Q1' |
253 | | - Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1)) |
254 | | - Δgauge < tol || |
255 | | - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" |
256 | | - mul!(ΔA1, ΔQ2Q1d', Q2, -1, 1) |
257 | | - end |
258 | | - ldiv!(LowerTriangular(L11)', ΔA1) |
259 | | - return ΔA |
260 | | -end |
261 | | - |
262 | | -function default_pullback_gaugetol(a) |
263 | | - n = norm(a, Inf) |
264 | | - return eps(eltype(n))^(3 / 4) * max(n, one(n)) |
| 134 | + rightorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() |
| 135 | + return LQ, rightorth!_pullback |
265 | 136 | end |
0 commit comments