@@ -120,10 +120,9 @@ for eig in (:eig, :eigh)
120120 alg:: TruncatedAlgorithm
121121 )
122122 Ac = copy_input($ eig_f, A)
123- D, V = $ (eig_f!)(Ac, DV, alg. alg)
124- ind = findtruncated(diagview(D), alg. trunc)
125- return (Diagonal(diagview(D)[ind]), V[:, ind]),
126- $ (_make_eig_t_pb)(A, (D, V), ind)
123+ DV = $ (eig_f!)(Ac, DV, alg. alg)
124+ DV′, ind = MatrixAlgebraKit. truncate($ eig_t!, DV, alg. trunc)
125+ return DV′, $ (_make_eig_t_pb)(A, DV, ind)
127126 end
128127 function $ (_make_eig_t_pb)(A, DV, ind)
129128 function $ eig_t_pb(ΔDV)
@@ -163,10 +162,9 @@ function ChainRulesCore.rrule(
163162 alg:: TruncatedAlgorithm
164163 )
165164 Ac = copy_input(svd_compact, A)
166- U, S, Vᴴ = svd_compact!(Ac, USVᴴ, alg. alg)
167- ind = findtruncated_svd(diagview(S), alg. trunc)
168- return (U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]),
169- _make_svd_trunc_pullback(A, (U, S, Vᴴ), ind)
165+ USVᴴ = svd_compact!(Ac, USVᴴ, alg. alg)
166+ USVᴴ′, ind = MatrixAlgebraKit. truncate(svd_trunc!, USVᴴ, alg. trunc)
167+ return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind)
170168end
171169function _make_svd_trunc_pullback(A, USVᴴ, ind)
172170 function svd_trunc_pullback(ΔUSVᴴ)
0 commit comments