@@ -6,9 +6,10 @@ using MatrixAlgebraKit
66using MatrixAlgebraKit: inv_safe, diagview, copy_input
77using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
9- using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback!
9+ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
10+ using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1011using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
11- using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!
12+ using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
1213using LinearAlgebra
1314
1415
@@ -122,8 +123,8 @@ for (f!, f, pb, adj) in (
122123end
123124
124125for (f!, f, f_full, pb, adj) in (
125- (:eig_vals!, :eig_vals, :eig_full, :eig_pullback !, :eig_vals_adjoint),
126- (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_pullback !, :eigh_vals_adjoint),
126+ (:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback !, :eig_vals_adjoint),
127+ (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback !, :eigh_vals_adjoint),
127128 )
128129 @eval begin
129130 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof($ f!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
@@ -136,7 +137,7 @@ for (f!, f, f_full, pb, adj) in (
136137 copy!(D, diagview(DV[1 ]))
137138 V = DV[2 ]
138139 function $ adj(:: NoRData )
139- $ pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing ) )
140+ $ pb(dA, A, DV, dD )
140141 MatrixAlgebraKit. zero!(dD)
141142 return NoRData(), NoRData(), NoRData(), NoRData()
142143 end
@@ -153,7 +154,7 @@ for (f!, f, f_full, pb, adj) in (
153154 output_codual = CoDual(output, Mooncake. zero_tangent(output))
154155 function $ adj(:: NoRData )
155156 D, dD = arrayify(output_codual)
156- $ pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing ) )
157+ $ pb(dA, A, DV, dD )
157158 MatrixAlgebraKit. zero!(dD)
158159 return NoRData(), NoRData(), NoRData()
159160 end
@@ -275,7 +276,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
275276 U, nS, Vᴴ = svd_compact(A, Mooncake. primal(alg_dalg))
276277 copy!(S, diagview(nS))
277278 function svd_vals_adjoint(:: NoRData )
278- svd_pullback !(dA, A, (U, Diagonal(S) , Vᴴ), ( nothing , Diagonal(dS), nothing ) )
279+ svd_vals_pullback !(dA, A, (U, nS , Vᴴ), dS )
279280 MatrixAlgebraKit. zero!(dS)
280281 return NoRData(), NoRData(), NoRData(), NoRData()
281282 end
@@ -294,7 +295,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
294295 S_codual = CoDual(diagview(S), Mooncake. fdata(Mooncake. zero_tangent(diagview(S))))
295296 function svd_vals_adjoint(:: NoRData )
296297 S, dS = arrayify(S_codual)
297- svd_pullback !(dA, A, (U, Diagonal(S) , Vᴴ), ( nothing , Diagonal(dS), nothing ) )
298+ svd_vals_pullback !(dA, A, (U, S , Vᴴ), dS )
298299 MatrixAlgebraKit. zero!(dS)
299300 return NoRData(), NoRData(), NoRData()
300301 end
0 commit comments