Skip to content

Commit cbb5f1d

Browse files
fix multiple solve handling
1 parent 3b39753 commit cbb5f1d

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,13 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
8383
(dr.u for dr in dres)
8484
end
8585

86-
cache = (res, resvals, linsolve.val)
86+
cache = (res, resvals)
8787
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
8888
end
8989

9090
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
91-
y, dys, _linsolve = cache
91+
y, dys = cache
92+
_linsolve = linsolve.val
9293

9394
@assert !(typeof(linsolve) <: Const)
9495
@assert !(typeof(linsolve) <: Active)
@@ -110,8 +111,8 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
110111
end
111112

112113
for (dA, db, dy) in zip(dAs, dbs, dys)
113-
z = if linsolve.cacheval isa Factorization
114-
linsolve.cacheval' \ dy
114+
z = if _linsolve.cacheval isa Factorization
115+
_linsolve.cacheval' \ dy
115116
elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization
116117
linsolve.cacheval[1]' \ dy
117118
elseif linsolve.alg isa AbstractKrylovSubspaceMethod

0 commit comments

Comments
 (0)