Skip to content

Commit df7362c

Browse files
committed
also update mooncake rules
1 parent b0eec92 commit df7362c

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ using MatrixAlgebraKit
66
using MatrixAlgebraKit: inv_safe, diagview, copy_input
77
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
9-
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback!
9+
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
10+
using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1011
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
11-
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!
12+
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
1213
using LinearAlgebra
1314

1415

@@ -122,8 +123,8 @@ for (f!, f, pb, adj) in (
122123
end
123124

124125
for (f!, f, f_full, pb, adj) in (
125-
(:eig_vals!, :eig_vals, :eig_full, :eig_pullback!, :eig_vals_adjoint),
126-
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_pullback!, :eigh_vals_adjoint),
126+
(:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_adjoint),
127+
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint),
127128
)
128129
@eval begin
129130
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@@ -136,7 +137,7 @@ for (f!, f, f_full, pb, adj) in (
136137
copy!(D, diagview(DV[1]))
137138
V = DV[2]
138139
function $adj(::NoRData)
139-
$pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
140+
$pb(dA, A, DV, dD)
140141
MatrixAlgebraKit.zero!(dD)
141142
return NoRData(), NoRData(), NoRData(), NoRData()
142143
end
@@ -153,7 +154,7 @@ for (f!, f, f_full, pb, adj) in (
153154
output_codual = CoDual(output, Mooncake.zero_tangent(output))
154155
function $adj(::NoRData)
155156
D, dD = arrayify(output_codual)
156-
$pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
157+
$pb(dA, A, DV, dD)
157158
MatrixAlgebraKit.zero!(dD)
158159
return NoRData(), NoRData(), NoRData()
159160
end
@@ -275,7 +276,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
275276
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
276277
copy!(S, diagview(nS))
277278
function svd_vals_adjoint(::NoRData)
278-
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
279+
svd_vals_pullback!(dA, A, (U, nS, Vᴴ), dS)
279280
MatrixAlgebraKit.zero!(dS)
280281
return NoRData(), NoRData(), NoRData(), NoRData()
281282
end
@@ -294,7 +295,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
294295
S_codual = CoDual(diagview(S), Mooncake.fdata(Mooncake.zero_tangent(diagview(S))))
295296
function svd_vals_adjoint(::NoRData)
296297
S, dS = arrayify(S_codual)
297-
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
298+
svd_vals_pullback!(dA, A, (U, S, Vᴴ), dS)
298299
MatrixAlgebraKit.zero!(dS)
299300
return NoRData(), NoRData(), NoRData()
300301
end

0 commit comments

Comments
 (0)