@@ -94,7 +94,10 @@ for eig in (:eig, :eigh)
9494 eig_pb = Symbol(eig, " _pullback" )
9595 eig_t! = Symbol(eig, " _trunc!" )
9696 eig_t_pb = Symbol(eig, " _trunc_pullback" )
97+ eig_t_ne! = Symbol(eig, " _trunc_no_error!" )
98+ eig_t_ne_pb = Symbol(eig, " _trunc_no_error_pullback" )
9799 _make_eig_t_pb = Symbol(" _make_" , eig_t_pb)
100+ _make_eig_t_ne_pb = Symbol(" _make_" , eig_t_ne_pb)
98101 eig_v = Symbol(eig, " _vals" )
99102 eig_v! = Symbol(eig_v, " !" )
100103 eig_v_pb = Symbol(eig_v, " _pullback" )
@@ -136,6 +139,24 @@ for eig in (:eig, :eigh)
136139 end
137140 return $ eig_t_pb
138141 end
142+ function ChainRulesCore. rrule(:: typeof ($ eig_t_ne!), A, DV, alg:: TruncatedAlgorithm )
143+ Ac = copy_input($ eig_f, A)
144+ DV = $ (eig_f!)(Ac, DV, alg. alg)
145+ DV′, ind = MatrixAlgebraKit. truncate($ eig_t!, DV, alg. trunc)
146+ return DV′, $ (_make_eig_t_ne_pb)(A, DV, ind)
147+ end
148+ function $ (_make_eig_t_ne_pb)(A, DV, ind)
149+ function $ eig_t_ne_pb(ΔDV)
150+ ΔA = zero(A)
151+ ΔD, ΔV = ΔDV
152+ MatrixAlgebraKit.$ eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
153+ return NoTangent(), ΔA, ZeroTangent(), NoTangent()
154+ end
155+ function $ eig_t_ne_pb(:: Tuple{ZeroTangent, ZeroTangent, ZeroTangent} ) # is this extra definition useful?
156+ return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
157+ end
158+ return $ eig_t_ne_pb
159+ end
139160 function ChainRulesCore. rrule(:: typeof ($ eig_v!), A, D, alg)
140161 DV = $ eig_f(A, alg)
141162 function $ eig_v_pb(ΔD)
0 commit comments