@@ -139,14 +139,14 @@ for (f!, f, f_full, pb, adj) in (
139139 D, dD = arrayify(D_, dD_)
140140 # update primal
141141 DV = $ f_full(A, Mooncake. primal(alg_dalg))
142- output = copy( diagview(DV[1 ]))
142+ copy!(D, diagview(DV[1 ]))
143143 V = DV[2 ]
144144 function $ adj(:: Mooncake.NoRData )
145145 $ pb(dA, A, (D, V), (dD, nothing ))
146146 MatrixAlgebraKit. zero!(dD)
147147 return Mooncake. NoRData(), Mooncake. NoRData(), Mooncake. NoRData(), Mooncake. NoRData()
148148 end
149- return Mooncake . CoDual(output, dD_) , $ adj
149+ return D_dD , $ adj
150150 end
151151 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof($ f), Any, MatrixAlgebraKit. AbstractAlgorithm}
152152 function Mooncake. rrule!!(:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual )
@@ -288,13 +288,13 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Co
288288 A, dA = arrayify(A_, dA_)
289289 S, dS = arrayify(S_, dS_)
290290 U, nS, Vᴴ = svd_compact(A, Mooncake. primal(alg_dalg))
291- output = copy( diagview(nS))
291+ copy!(S, diagview(nS))
292292 function dsvd_vals_adjoint(:: Mooncake.NoRData )
293293 svd_pullback!(dA, A, (U, S, Vᴴ), (nothing , dS, nothing ))
294294 MatrixAlgebraKit. zero!(dS)
295295 return Mooncake. NoRData(), Mooncake. NoRData(), Mooncake. NoRData(), Mooncake. NoRData()
296296 end
297- return Mooncake . CoDual(output, dS) , dsvd_vals_adjoint
297+ return S_dS , dsvd_vals_adjoint
298298end
299299
300300@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof(MatrixAlgebraKit. svd_vals), Any, MatrixAlgebraKit. AbstractAlgorithm}
0 commit comments