|
1 | 1 | # Factorizations rules |
2 | 2 | # -------------------- |
# Reverse rule for `MatrixAlgebraKit.copy_input(f, t)` with a tensor-map argument `t`.
#
# The primal call is forwarded unchanged. The pullback unthunks the incoming cotangent,
# passes it through `ProjectTo(t)`, and reports `NoTangent()` both for the function
# itself and for `f`.
function ChainRulesCore.rrule(
    ::typeof(MatrixAlgebraKit.copy_input), f, t::AbstractTensorMap
)
    to_tangent_of_t = ProjectTo(t)
    function copy_input_pullback(Δt)
        return NoTangent(), NoTangent(), to_tangent_of_t(unthunk(Δt))
    end
    return MatrixAlgebraKit.copy_input(f, t), copy_input_pullback
end
9 | | - |
# Mark the output-initialisation and input-checking helpers as non-differentiable when
# their second argument is a tensor map, so AD does not trace through them.
@non_differentiable MatrixAlgebraKit.initialize_output(op, t::AbstractTensorMap, rest...)
@non_differentiable MatrixAlgebraKit.check_input(op, t::AbstractTensorMap, rest...)
12 | | - |
# Reverse rules for the in-place QR factorizations `qr_compact!` and `qr_full!`.
#
# Arguments of each generated rule: the tensor `t`, the output buffer `QR` and the
# algorithm `alg`. The factorization runs on `MatrixAlgebraKit.copy_input(qr_f, t)`
# rather than on `t` itself; the pullback still reads `t` afterwards.
#
# Returns the factorization and a pullback yielding four cotangents: `NoTangent()` for
# the function, `Δt` for `t`, `ZeroTangent()` for the output buffer and `NoTangent()`
# for `alg`.
for qr_f in (:qr_compact, :qr_full)
    qr_f! = Symbol(qr_f, '!')
    @eval function ChainRulesCore.rrule(::typeof($qr_f!), t::AbstractTensorMap, QR, alg)
        tc = MatrixAlgebraKit.copy_input($qr_f, t)
        # Bind the result to a fresh local: reassigning the argument `QR` while the
        # closure below captures it would force the variable into a `Core.Box`.
        QR_out = $(qr_f!)(tc, QR, alg)
        function qr_pullback(ΔQR′)
            ΔQR = unthunk.(ΔQR′)
            Δt = zerovector(t)
            MatrixAlgebraKit.qr_compact_pullback!(Δt, t, QR_out, ΔQR)
            return NoTangent(), Δt, ZeroTangent(), NoTangent()
        end
        # Both output cotangents are zero: nothing to propagate.
        function qr_pullback(::Tuple{ZeroTangent,ZeroTangent})
            return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
        end
        return QR_out, qr_pullback
    end
end
# Reverse rule for `qr_null!`.
#
# The primal computes `qr_full(t, alg)` and, for every block of `t` with size `(m, n)`,
# copies columns `n+1:m` of the matching block of `Q` into the matching block of `N`.
# The pullback embeds `ΔN` into the trailing columns of an otherwise zero `ΔQ` and
# delegates to `MatrixAlgebraKit.qr_compact_pullback!` with a zero cotangent for `R`.
#
# Returns `N` and a pullback yielding `(NoTangent(), Δt, ZeroTangent(), NoTangent())`.
function ChainRulesCore.rrule(::typeof(qr_null!), t::AbstractTensorMap, N, alg)
    Q, R = qr_full(t, alg)
    for (sector, tblock) in blocks(t)
        nrows, ncols = size(tblock)
        copy!(
            block(N, sector), @view(block(Q, sector)[1:nrows, (ncols + 1):nrows])
        )
    end

    function qr_null_pullback(ΔN′)
        ΔN = unthunk(ΔN′)
        Δt = zerovector(t)
        ΔQ = zerovector!(similar(Q, codomain(Q) ← fuse(codomain(Q))))
        # Fill the last `size(ΔNblock, 2)` columns of each ΔQ block with ΔN's block.
        foreachblock(ΔN) do sector, (ΔNblock,)
            width = size(ΔNblock, 2)
            ΔQblock = block(ΔQ, sector)
            last_col = lastindex(ΔQblock, 2)
            return copy!(view(ΔQblock, :, (last_col - width + 1):last_col), ΔNblock)
        end
        MatrixAlgebraKit.qr_compact_pullback!(Δt, t, (Q, R), (ΔQ, ZeroTangent()))
        return NoTangent(), Δt, ZeroTangent(), NoTangent()
    end
    # Zero cotangent: nothing to propagate.
    function qr_null_pullback(::ZeroTangent)
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end

    return N, qr_null_pullback
end
54 | | - |
# Reverse rules for the in-place LQ factorizations `lq_compact!` and `lq_full!`.
#
# Arguments of each generated rule: the tensor `t`, the output buffer `LQ` and the
# algorithm `alg`. The factorization runs on `MatrixAlgebraKit.copy_input(lq_f, t)`
# rather than on `t` itself; the pullback still reads `t` afterwards.
#
# Returns the factorization and a pullback yielding four cotangents: `NoTangent()` for
# the function, `Δt` for `t`, `ZeroTangent()` for the output buffer and `NoTangent()`
# for `alg`.
for lq_f in (:lq_compact, :lq_full)
    lq_f! = Symbol(lq_f, '!')
    @eval function ChainRulesCore.rrule(::typeof($lq_f!), t::AbstractTensorMap, LQ, alg)
        tc = MatrixAlgebraKit.copy_input($lq_f, t)
        # Bind the result to a fresh local: reassigning the argument `LQ` while the
        # closure below captures it would force the variable into a `Core.Box`.
        LQ_out = $(lq_f!)(tc, LQ, alg)
        function lq_pullback(ΔLQ′)
            ΔLQ = unthunk.(ΔLQ′)
            Δt = zerovector(t)
            MatrixAlgebraKit.lq_compact_pullback!(Δt, t, LQ_out, ΔLQ)
            return NoTangent(), Δt, ZeroTangent(), NoTangent()
        end
        # Both output cotangents are zero: nothing to propagate.
        function lq_pullback(::Tuple{ZeroTangent,ZeroTangent})
            return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
        end
        return LQ_out, lq_pullback
    end
end
# Reverse rule for `lq_null!`.
#
# The primal computes `lq_full(t, alg)` and, for every block of `t` with size `(m, n)`,
# copies rows `m+1:n` of the matching block of `Q` into the matching block of `Nᴴ`.
# The pullback embeds `ΔNᴴ` into the trailing rows of an otherwise zero `ΔQ` and
# delegates to `MatrixAlgebraKit.lq_compact_pullback!` with a zero cotangent for `L`.
#
# Returns `Nᴴ` and a pullback yielding `(NoTangent(), Δt, ZeroTangent(), NoTangent())`.
function ChainRulesCore.rrule(::typeof(lq_null!), t::AbstractTensorMap, Nᴴ, alg)
    L, Q = lq_full(t, alg)
    for (sector, tblock) in blocks(t)
        nrows, ncols = size(tblock)
        copy!(
            block(Nᴴ, sector), @view(block(Q, sector)[(nrows + 1):ncols, 1:ncols])
        )
    end

    function lq_null_pullback(ΔNᴴ′)
        ΔNᴴ = unthunk(ΔNᴴ′)
        Δt = zerovector(t)
        ΔQ = zerovector!(similar(Q, codomain(Q) ← fuse(codomain(Q))))
        # Fill the last `size(ΔNᴴblock, 1)` rows of each ΔQ block with ΔNᴴ's block.
        foreachblock(ΔNᴴ) do sector, (ΔNᴴblock,)
            height = size(ΔNᴴblock, 1)
            ΔQblock = block(ΔQ, sector)
            last_row = lastindex(ΔQblock, 1)
            return copy!(view(ΔQblock, (last_row - height + 1):last_row, :), ΔNᴴblock)
        end
        MatrixAlgebraKit.lq_compact_pullback!(Δt, t, (L, Q), (ZeroTangent(), ΔQ))
        return NoTangent(), Δt, ZeroTangent(), NoTangent()
    end
    # Zero cotangent: nothing to propagate.
    function lq_null_pullback(::ZeroTangent)
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end

    return Nᴴ, lq_null_pullback
end
96 | | - |
# Reverse rules for the in-place eigenvalue decompositions `eig_full!` and `eigh_full!`.
#
# Arguments of each generated rule: the tensor `t`, the output buffer `DV` and the
# algorithm `alg`. The decomposition runs on `MatrixAlgebraKit.copy_input(eig_f, t)`
# rather than on `t` itself; the pullback still reads `t` afterwards and delegates to
# `MatrixAlgebraKit.eig_pullback!` or `MatrixAlgebraKit.eigh_pullback!` respectively.
#
# Returns the decomposition and a pullback yielding four cotangents: `NoTangent()` for
# the function, `Δt` for `t`, `ZeroTangent()` for the output buffer and `NoTangent()`
# for `alg`.
for eig in (:eig, :eigh)
    eig_f = Symbol(eig, "_full")
    eig_f! = Symbol(eig_f, "!")
    eig_f_pb! = Symbol(eig, "_pullback!")
    eig_pb = Symbol(eig, "_pullback")
    @eval function ChainRulesCore.rrule(::typeof($eig_f!), t::AbstractTensorMap, DV, alg)
        tc = MatrixAlgebraKit.copy_input($eig_f, t)
        # Bind the result to a fresh local: reassigning the argument `DV` while the
        # closure below captures it would force the variable into a `Core.Box`.
        DV_out = $(eig_f!)(tc, DV, alg)
        function $eig_pb(ΔDV)
            Δt = zerovector(t)
            MatrixAlgebraKit.$eig_f_pb!(Δt, t, DV_out, unthunk.(ΔDV))
            return NoTangent(), Δt, ZeroTangent(), NoTangent()
        end
        # Both output cotangents are zero: nothing to propagate.
        function $eig_pb(::Tuple{ZeroTangent,ZeroTangent})
            return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
        end
        return DV_out, $eig_pb
    end
end
116 | | - |
# Reverse rules for the in-place singular value decompositions `svd_compact!` and
# `svd_full!`.
#
# Arguments of each generated rule: the tensor `t`, the output buffer `USVᴴ` and the
# algorithm `alg`. The decomposition runs on `MatrixAlgebraKit.copy_input(svd_f, t)`
# rather than on `t` itself; the pullback still reads `t` afterwards.
#
# Returns the decomposition and a pullback yielding four cotangents: `NoTangent()` for
# the function, `Δt` for `t`, `ZeroTangent()` for the output buffer and `NoTangent()`
# for `alg`.
for svd_f in (:svd_compact, :svd_full)
    svd_f! = Symbol(svd_f, "!")
    @eval function ChainRulesCore.rrule(
        ::typeof($svd_f!), t::AbstractTensorMap, USVᴴ, alg
    )
        tc = MatrixAlgebraKit.copy_input($svd_f, t)
        # Bind the result to a fresh local: reassigning the argument `USVᴴ` while the
        # closure below captures it would force the variable into a `Core.Box`.
        USVᴴ_out = $(svd_f!)(tc, USVᴴ, alg)
        function svd_pullback(ΔUSVᴴ)
            Δt = zerovector(t)
            MatrixAlgebraKit.svd_pullback!(Δt, t, USVᴴ_out, unthunk.(ΔUSVᴴ))
            return NoTangent(), Δt, ZeroTangent(), NoTangent()
        end
        # All three output cotangents are zero: nothing to propagate.
        function svd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
            return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
        end
        return USVᴴ_out, svd_pullback
    end
end
135 | | - |
# Reverse rule for the truncated SVD `svd_trunc!` with a `TruncatedAlgorithm`.
#
# The compact SVD is computed with the wrapped algorithm `alg.alg` on
# `MatrixAlgebraKit.copy_input(svd_compact, t)`, then truncated according to
# `alg.trunc`. The pullback is built by `_make_svd_trunc_pullback` from the untruncated
# factors and the `ind` value returned by the truncation.
function ChainRulesCore.rrule(
    ::typeof(svd_trunc!), t::AbstractTensorMap, USVᴴ, alg::TruncatedAlgorithm
)
    tc = MatrixAlgebraKit.copy_input(svd_compact, t)
    USVᴴ_compact = svd_compact!(tc, USVᴴ, alg.alg)
    USVᴴ_trunc, ind = TensorKit.Factorizations.truncate(
        svd_trunc!, USVᴴ_compact, alg.trunc
    )
    return USVᴴ_trunc, _make_svd_trunc_pullback(t, USVᴴ_compact, ind)
end
# Build the pullback closure used by the `svd_trunc!` reverse rule.
#
# `t` is the input tensor, `USVᴴ` the factors handed to
# `MatrixAlgebraKit.svd_pullback!`, and `ind` the extra index argument forwarded to it.
# The returned closure yields `(NoTangent(), Δt, ZeroTangent(), NoTangent())`.
function _make_svd_trunc_pullback(t::AbstractTensorMap, USVᴴ, ind)
    function svd_trunc_pullback(ΔUSVᴴ)
        Δt = zerovector(t)
        ΔUSVᴴ_plain = unthunk.(ΔUSVᴴ)
        MatrixAlgebraKit.svd_pullback!(Δt, t, USVᴴ, ΔUSVᴴ_plain, ind)
        return NoTangent(), Δt, ZeroTangent(), NoTangent()
    end
    # All three output cotangents are zero: nothing to propagate.
    function svd_trunc_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end
    return svd_trunc_pullback
end
155 | | - |
# Reverse rule for the in-place left polar decomposition `left_polar!`.
#
# Arguments: the tensor `t`, the output buffer `WP` and the algorithm `alg`. The
# decomposition runs on `MatrixAlgebraKit.copy_input(left_polar, t)` rather than on `t`
# itself; the pullback still reads `t` afterwards.
#
# Returns the decomposition and a pullback yielding four cotangents: `NoTangent()` for
# the function, `Δt` for `t`, `ZeroTangent()` for the output buffer and `NoTangent()`
# for `alg`.
function ChainRulesCore.rrule(::typeof(left_polar!), t::AbstractTensorMap, WP, alg)
    tc = MatrixAlgebraKit.copy_input(left_polar, t)
    # Bind the result to a fresh local: reassigning the argument `WP` while the closure
    # below captures it would force the variable into a `Core.Box`.
    WP_out = left_polar!(tc, WP, alg)
    function left_polar_pullback(ΔWP)
        Δt = zerovector(t)
        MatrixAlgebraKit.left_polar_pullback!(Δt, t, WP_out, unthunk.(ΔWP))
        return NoTangent(), Δt, ZeroTangent(), NoTangent()
    end
    # Both output cotangents are zero: nothing to propagate.
    function left_polar_pullback(::Tuple{ZeroTangent,ZeroTangent})
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end
    return WP_out, left_polar_pullback
end
169 | | - |
# Reverse rule for the in-place right polar decomposition `right_polar!`.
#
# Arguments: the tensor `t`, the output buffer `PWᴴ` and the algorithm `alg`. The
# decomposition runs on `MatrixAlgebraKit.copy_input(right_polar, t)` rather than on
# `t` itself; the pullback still reads `t` afterwards.
#
# Returns the decomposition and a pullback yielding four cotangents: `NoTangent()` for
# the function, `Δt` for `t`, `ZeroTangent()` for the output buffer and `NoTangent()`
# for `alg`.
function ChainRulesCore.rrule(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, alg)
    # Request the input copy for `right_polar`, matching the factorization that is
    # actually performed (this previously asked for the `left_polar` copy).
    tc = MatrixAlgebraKit.copy_input(right_polar, t)
    # Bind the result to a fresh local: reassigning the argument `PWᴴ` while the
    # closure below captures it would force the variable into a `Core.Box`.
    PWᴴ_out = right_polar!(tc, PWᴴ, alg)
    function right_polar_pullback(ΔPWᴴ)
        Δt = zerovector(t)
        MatrixAlgebraKit.right_polar_pullback!(Δt, t, PWᴴ_out, unthunk.(ΔPWᴴ))
        return NoTangent(), Δt, ZeroTangent(), NoTangent()
    end
    # Both output cotangents are zero: nothing to propagate.
    function right_polar_pullback(::Tuple{ZeroTangent,ZeroTangent})
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end
    return PWᴴ_out, right_polar_pullback
end
183 | | - |
184 | | -# function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap) |
185 | | -# U, S, V⁺ = tsvd(t) |
186 | | -# s = diag(S) |
187 | | -# project_t = ProjectTo(t) |
188 | | - |
189 | | -# function svdvals_pullback(Δs′) |
190 | | -# Δs = unthunk(Δs′) |
191 | | -# ΔS = diagm(codomain(S), domain(S), Δs) |
192 | | -# return NoTangent(), project_t(U * ΔS * V⁺) |
193 | | -# end |
194 | | - |
195 | | -# return s, svdvals_pullback |
196 | | -# end |
197 | | - |
198 | | -# function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap; |
199 | | -# sortby=nothing, kwargs...) |
200 | | -# @assert sortby === nothing "only `sortby=nothing` is supported" |
201 | | -# (D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...) |
202 | | -# d = diag(D) |
203 | | -# project_t = ProjectTo(t) |
204 | | -# function eigvals_pullback(Δd′) |
205 | | -# Δd = unthunk(Δd′) |
206 | | -# ΔD = diagm(codomain(D), domain(D), Δd) |
207 | | -# return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2]) |
208 | | -# end |
209 | | - |
210 | | -# return d, eigvals_pullback |
211 | | -# end |
0 commit comments