Skip to content

Commit 0669d08

Browse files
committed
Chainrules fixes
1 parent dc6a765 commit 0669d08

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,10 @@ for eig in (:eig, :eigh)
9494
eig_pb = Symbol(eig, "_pullback")
9595
eig_t! = Symbol(eig, "_trunc!")
9696
eig_t_pb = Symbol(eig, "_trunc_pullback")
97+
eig_t_ne! = Symbol(eig, "_trunc_no_error!")
98+
eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback")
9799
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
100+
_make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb)
98101
eig_v = Symbol(eig, "_vals")
99102
eig_v! = Symbol(eig_v, "!")
100103
eig_v_pb = Symbol(eig_v, "_pullback")
@@ -136,6 +139,24 @@ for eig in (:eig, :eigh)
136139
end
137140
return $eig_t_pb
138141
end
142+
function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm)
143+
Ac = copy_input($eig_f, A)
144+
DV = $(eig_f!)(Ac, DV, alg.alg)
145+
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
146+
return DV′, $(_make_eig_t_ne_pb)(A, DV, ind)
147+
end
148+
function $(_make_eig_t_ne_pb)(A, DV, ind)
149+
function $eig_t_ne_pb(ΔDV)
150+
ΔA = zero(A)
151+
ΔD, ΔV = ΔDV
152+
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
153+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
154+
end
155+
function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
156+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
157+
end
158+
return $eig_t_ne_pb
159+
end
139160
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg)
140161
DV = $eig_f(A, alg)
141162
function $eig_v_pb(ΔD)

test/testsuite/ad_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ function ad_lq_full_setup(A)
138138
Q1 = view(Q, 1:minmn, 1:n)
139139
ΔQ = randn!(similar(A, T, n, n))
140140
ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n)
141-
ΔQ2 = (ΔQ2 * Q1') * Q1
141+
mul!(ΔQ2, ΔQ2 * Q1', Q1)
142142
ΔL = randn!(similar(A, T, m, n))
143143
return (L, Q), (ΔL, ΔQ)
144144
end

0 commit comments

Comments
 (0)