Skip to content

Commit 767b6b5

Browse files
committed
use truncate in pullbacks
1 parent 32edf68 commit 767b6b5

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,9 @@ for eig in (:eig, :eigh)
120120
alg::TruncatedAlgorithm
121121
)
122122
Ac = copy_input($eig_f, A)
123-
D, V = $(eig_f!)(Ac, DV, alg.alg)
124-
ind = findtruncated(diagview(D), alg.trunc)
125-
return (Diagonal(diagview(D)[ind]), V[:, ind]),
126-
$(_make_eig_t_pb)(A, (D, V), ind)
123+
DV = $(eig_f!)(Ac, DV, alg.alg)
124+
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
125+
return DV′, $(_make_eig_t_pb)(A, DV, ind)
127126
end
128127
function $(_make_eig_t_pb)(A, DV, ind)
129128
function $eig_t_pb(ΔDV)
@@ -163,10 +162,9 @@ function ChainRulesCore.rrule(
163162
alg::TruncatedAlgorithm
164163
)
165164
Ac = copy_input(svd_compact, A)
166-
U, S, Vᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
167-
ind = findtruncated_svd(diagview(S), alg.trunc)
168-
return (U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]),
169-
_make_svd_trunc_pullback(A, (U, S, Vᴴ), ind)
165+
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
166+
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
167+
return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind)
170168
end
171169
function _make_svd_trunc_pullback(A, USVᴴ, ind)
172170
function svd_trunc_pullback(ΔUSVᴴ)

0 commit comments

Comments
 (0)