Skip to content

Commit e433589

Browse files
committed
Fix inplace vals rules
1 parent c86ec1d commit e433589

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,14 @@ for (f!, f, f_full, pb, adj) in (
139139
D, dD = arrayify(D_, dD_)
140140
# update primal
141141
DV = $f_full(A, Mooncake.primal(alg_dalg))
142-
output = copy(diagview(DV[1]))
142+
copy!(D, diagview(DV[1]))
143143
V = DV[2]
144144
function $adj(::Mooncake.NoRData)
145145
$pb(dA, A, (D, V), (dD, nothing))
146146
MatrixAlgebraKit.zero!(dD)
147147
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
148148
end
149-
return Mooncake.CoDual(output, dD_), $adj
149+
return D_dD, $adj
150150
end
151151
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
152152
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
@@ -288,13 +288,13 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Co
288288
A, dA = arrayify(A_, dA_)
289289
S, dS = arrayify(S_, dS_)
290290
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
291-
output = copy(diagview(nS))
291+
copy!(S, diagview(nS))
292292
function dsvd_vals_adjoint(::Mooncake.NoRData)
293293
svd_pullback!(dA, A, (U, S, Vᴴ), (nothing, dS, nothing))
294294
MatrixAlgebraKit.zero!(dS)
295295
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
296296
end
297-
return Mooncake.CoDual(output, dS), dsvd_vals_adjoint
297+
return S_dS, dsvd_vals_adjoint
298298
end
299299

300300
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}

0 commit comments

Comments
 (0)