11module MatrixAlgebraKitChainRulesCoreExt
22
33using MatrixAlgebraKit
4- using MatrixAlgebraKit: copy_input, TruncatedAlgorithm, zero!
4+ using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview,
5+ TruncatedAlgorithm, findtruncated, findtruncated_svd
56using ChainRulesCore
67using LinearAlgebra
78
@@ -24,7 +25,7 @@ for qr_f in (:qr_compact, :qr_full)
2425 QR = $ (qr_f!)(Ac, QR, alg)
2526 function qr_pullback(ΔQR)
2627 ΔA = zero(A)
27- MatrixAlgebraKit. qr_compact_pullback!(ΔA, QR, unthunk.(ΔQR))
28+ MatrixAlgebraKit. qr_compact_pullback!(ΔA, A, QR, unthunk.(ΔQR))
2829 return NoTangent(), ΔA, ZeroTangent(), NoTangent()
2930 end
3031 function qr_pullback(:: Tuple{ZeroTangent, ZeroTangent} ) # is this extra definition useful?
@@ -36,7 +37,7 @@ for qr_f in (:qr_compact, :qr_full)
3637end
3738function ChainRulesCore. rrule(:: typeof (qr_null!), A:: AbstractMatrix , N, alg)
3839 Ac = copy_input(qr_full, A)
39- QR = MatrixAlgebraKit . initialize_output(qr_full!, A, alg)
40+ QR = initialize_output(qr_full!, A, alg)
4041 Q, R = qr_full!(Ac, QR, alg)
4142 N = copy!(N, view(Q, 1 : size(A, 1 ), (size(A, 2 ) + 1 ): size(A, 1 )))
4243 function qr_null_pullback(ΔN)
@@ -45,7 +46,7 @@ function ChainRulesCore.rrule(::typeof(qr_null!), A::AbstractMatrix, N, alg)
4546 minmn = min(m, n)
4647 ΔQ = zero!(similar(A, (m, m)))
4748 view(ΔQ, 1 : m, (minmn + 1 ): m) .= unthunk.(ΔN)
48- MatrixAlgebraKit. qr_compact_pullback!(ΔA, (Q, R), (ΔQ, ZeroTangent()))
49+ MatrixAlgebraKit. qr_compact_pullback!(ΔA, A, (Q, R), (ΔQ, ZeroTangent()))
4950 return NoTangent(), ΔA, ZeroTangent(), NoTangent()
5051 end
5152 function qr_null_pullback(:: ZeroTangent ) # is this extra definition useful?
@@ -62,7 +63,7 @@ for lq_f in (:lq_compact, :lq_full)
6263 LQ = $ (lq_f!)(Ac, LQ, alg)
6364 function lq_pullback(ΔLQ)
6465 ΔA = zero(A)
65- MatrixAlgebraKit. lq_compact_pullback!(ΔA, LQ, unthunk.(ΔLQ))
66+ MatrixAlgebraKit. lq_compact_pullback!(ΔA, A, LQ, unthunk.(ΔLQ))
6667 return NoTangent(), ΔA, ZeroTangent(), NoTangent()
6768 end
6869 function lq_pullback(:: Tuple{ZeroTangent, ZeroTangent} ) # is this extra definition useful?
@@ -74,7 +75,7 @@ for lq_f in (:lq_compact, :lq_full)
7475end
7576function ChainRulesCore. rrule(:: typeof (lq_null!), A:: AbstractMatrix , Nᴴ, alg)
7677 Ac = copy_input(lq_full, A)
77- LQ = MatrixAlgebraKit . initialize_output(lq_full!, A, alg)
78+ LQ = initialize_output(lq_full!, A, alg)
7879 L, Q = lq_full!(Ac, LQ, alg)
7980 Nᴴ = copy!(Nᴴ, view(Q, (size(A, 1 ) + 1 ): size(A, 2 ), 1 : size(A, 2 )))
8081 function lq_null_pullback(ΔNᴴ)
@@ -83,7 +84,7 @@ function ChainRulesCore.rrule(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, alg)
8384 minmn = min(m, n)
8485 ΔQ = zero!(similar(A, (n, n)))
8586 view(ΔQ, (minmn + 1 ): n, 1 : n) .= unthunk.(ΔNᴴ)
86- MatrixAlgebraKit. lq_compact_pullback!(ΔA, (L, Q), (ZeroTangent(), ΔQ))
87+ MatrixAlgebraKit. lq_compact_pullback!(ΔA, A, (L, Q), (ZeroTangent(), ΔQ))
8788 return NoTangent(), ΔA, ZeroTangent(), NoTangent()
8889 end
8990 function lq_null_pullback(:: ZeroTangent ) # is this extra definition useful?
# Generate the reverse-mode rules for the eigenvalue decompositions. For each prefix
# (`eig`, `eigh`) this defines:
#   * an `rrule` for the full decomposition `<prefix>_full!`,
#   * an `rrule` for the truncated decomposition `<prefix>_trunc!`,
#   * a helper `_make_<prefix>_trunc_pullback` that builds the truncated pullback.
for eig in (:eig, :eigh)
    # Names of the MatrixAlgebraKit functions that the generated code refers to.
    eig_f = Symbol(eig, "_full")
    eig_f! = Symbol(eig_f, "!")
    eig_pb! = Symbol(eig, "_pullback!")
    eig_pb = Symbol(eig, "_pullback")
    eig_t! = Symbol(eig, "_trunc!")
    eig_t_pb = Symbol(eig, "_trunc_pullback")
    _make_eig_t_pb = Symbol("_make_", eig_t_pb)
    @eval begin
        # Rule for the full decomposition. Returns the decomposition together with a
        # pullback whose outputs line up with the arguments `(f, A, DV, alg)`.
        function ChainRulesCore.rrule(::typeof($eig_f!), A::AbstractMatrix, DV, alg)
            # Decompose a copy so that `A` itself stays available to the pullback.
            Ac = copy_input($eig_f, A)
            # Bind the result to a fresh local instead of reassigning the argument `DV`:
            # a variable that is reassigned and then captured by a closure gets boxed.
            DVout = $(eig_f!)(Ac, DV, alg)
            function $eig_pb(ΔDV)
                ΔA = zero(A)
                MatrixAlgebraKit.$eig_pb!(ΔA, A, DVout, unthunk.(ΔDV))
                return NoTangent(), ΔA, ZeroTangent(), NoTangent()
            end
            function $eig_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
                return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
            end
            return DVout, $eig_pb
        end
        # Rule for the truncated decomposition. The full decomposition is computed with
        # the wrapped algorithm `alg.alg`; `findtruncated` then selects the indices `ind`
        # that are kept, and those indices are handed to the pullback.
        function ChainRulesCore.rrule(
                ::typeof($eig_t!), A::AbstractMatrix, DV,
                alg::TruncatedAlgorithm
            )
            Ac = copy_input($eig_f, A)
            D, V = $(eig_f!)(Ac, DV, alg.alg)
            ind = findtruncated(diagview(D), alg.trunc)
            return (Diagonal(diagview(D)[ind]), V[:, ind]),
                $(_make_eig_t_pb)(A, (D, V), ind)
        end
        # Build the pullback of the truncated decomposition from the input `A`, the
        # untruncated factors `DV` and the retained indices `ind`.
        function $(_make_eig_t_pb)(A, DV, ind)
            function $eig_t_pb(ΔDV)
                ΔA = zero(A)
                MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.(ΔDV), ind)
                return NoTangent(), ΔA, ZeroTangent(), NoTangent()
            end
            function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
                return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
            end
            return $eig_t_pb
        end
    end
end
116141
@@ -122,7 +147,7 @@ for svd_f in (:svd_compact, :svd_full)
122147 USVᴴ = $ (svd_f!)(Ac, USVᴴ, alg)
123148 function svd_pullback(ΔUSVᴴ)
124149 ΔA = zero(A)
125- MatrixAlgebraKit. svd_compact_pullback !(ΔA, USVᴴ, unthunk.(ΔUSVᴴ))
150+ MatrixAlgebraKit. svd_pullback !(ΔA, A , USVᴴ, unthunk.(ΔUSVᴴ))
126151 return NoTangent(), ΔA, ZeroTangent(), NoTangent()
127152 end
128153 function svd_pullback(:: Tuple{ZeroTangent, ZeroTangent, ZeroTangent} ) # is this extra definition useful?
@@ -134,27 +159,33 @@ for svd_f in (:svd_compact, :svd_full)
134159end
135160
# Rule for the truncated singular value decomposition. The compact SVD is computed with
# the wrapped algorithm `alg.alg` on a copy of `A`; `findtruncated_svd` then selects the
# indices `ind` that are kept. The primal result holds the selected columns of `U`, the
# selected singular values as a `Diagonal`, and the selected rows of `Vᴴ`. The pullback
# receives the untruncated factors together with `ind`.
function ChainRulesCore.rrule(
        ::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ,
        alg::TruncatedAlgorithm
    )
    Ac = copy_input(svd_compact, A)
    U, S, Vᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
    ind = findtruncated_svd(diagview(S), alg.trunc)
    return (U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]),
        _make_svd_trunc_pullback(A, (U, S, Vᴴ), ind)
end
# Build the pullback of the truncated SVD from the input `A`, the untruncated factors
# `USVᴴ` and the retained indices `ind`. The returned closure has two methods: one that
# forwards the unthunked cotangents to `MatrixAlgebraKit.svd_pullback!`, and one that
# short-circuits when all three cotangents are `ZeroTangent`. The four returned entries
# line up with the `rrule` arguments `(f, A, USVᴴ, alg)`.
function _make_svd_trunc_pullback(A, USVᴴ, ind)
    function svd_trunc_pullback(ΔUSVᴴ)
        ΔA = zero(A)
        # presumably accumulates the cotangent of `A` into `ΔA` in place — confirm
        # against the definition of `svd_pullback!`
        MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ), ind)
        return NoTangent(), ΔA, ZeroTangent(), NoTangent()
    end
    function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end
    return svd_trunc_pullback
end
151182
152183function ChainRulesCore. rrule(:: typeof (left_polar!), A:: AbstractMatrix , WP, alg)
153184 Ac = copy_input(left_polar, A)
154185 WP = left_polar!(Ac, WP, alg)
155186 function left_polar_pullback(ΔWP)
156187 ΔA = zero(A)
157- MatrixAlgebraKit. left_polar_pullback!(ΔA, WP, unthunk.(ΔWP))
188+ MatrixAlgebraKit. left_polar_pullback!(ΔA, A, WP, unthunk.(ΔWP))
158189 return NoTangent(), ΔA, ZeroTangent(), NoTangent()
159190 end
160191 function left_polar_pullback(:: Tuple{ZeroTangent, ZeroTangent} ) # is this extra definition useful?
@@ -168,7 +199,7 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ,
168199 PWᴴ = right_polar!(Ac, PWᴴ, alg)
169200 function right_polar_pullback(ΔPWᴴ)
170201 ΔA = zero(A)
171- MatrixAlgebraKit. right_polar_pullback!(ΔA, PWᴴ, unthunk.(ΔPWᴴ))
202+ MatrixAlgebraKit. right_polar_pullback!(ΔA, A, PWᴴ, unthunk.(ΔPWᴴ))
172203 return NoTangent(), ΔA, ZeroTangent(), NoTangent()
173204 end
174205 function right_polar_pullback(:: Tuple{ZeroTangent, ZeroTangent} ) # is this extra definition useful?
0 commit comments