Skip to content

Commit 9d19db2

Browse files
committed
Cache before LU in place
1 parent bb93d68 commit 9d19db2

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ end
6060
# dA −= z y^T
6161
# dB += z, where z = inv(A^T) dy
6262
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
63+
@assert linsolve.val.isfresh
64+
A_cache = copy(linsolve.val.A)
65+
6366
res = func.val(linsolve.val; kwargs...)
67+
6468
dres = if EnzymeRules.width(config) == 1
6569
deepcopy(res)
6670
else
@@ -81,7 +85,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
8185
(dr.u for dr in dres)
8286
end
8387

84-
cache = (copy(linsolve.val.A), res, resvals)
88+
cache = (A_cache, res, resvals)
8589
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
8690
end
8791

0 commit comments

Comments
 (0)