diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index abd2232e1..e86518632 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -5,6 +5,8 @@ using LinearSolve.LinearAlgebra using EnzymeCore using EnzymeCore: EnzymeRules +@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:LinearSolve.SciMLLinearSolveAlgorithm}) = true + function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @@ -223,10 +225,10 @@ function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod # Doesn't modify `A`, so it's safe to just reuse it invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy) - solve(invprob, _linearsolve.alg; - abstol = _linsolve.val.abstol, - reltol = _linsolve.val.reltol, - verbose = _linsolve.val.verbose) + solve(invprob, _linsolve.alg; + abstol = _linsolve.abstol, + reltol = _linsolve.reltol, + verbose = _linsolve.verbose) elseif _linsolve.alg isa LinearSolve.DefaultLinearSolver LinearSolve.defaultalg_adjoint_eval(_linsolve, dy) else diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index 6f8a7e244..6eb45d55f 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -327,7 +327,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) else cache.u = convert(typeof(cache.u), cacheval.x) end - - return SciMLBase.build_linear_solution(alg, cache.u, resid, cache; + + return SciMLBase.build_linear_solution(alg, cache.u, Ref(resid), cache; iters = stats.niter, retcode, stats) end diff --git a/test/enzyme.jl b/test/enzyme.jl index d523036e5..f4f8dc64b 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -157,7 +157,6 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), @test db1 ≈ db12 @test db2 ≈ db22 -#= function f3(A, b1, b2; alg = KrylovJL_GMRES()) prob = LinearProblem(A, b1) cache = init(prob, alg) @@ -167,12 +166,14 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES()) norm(s1 + s2) end -Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) +dA = zeros(n, n); +db1 = zeros(n); +db2 = zeros(n); +Enzyme.autodiff(set_runtime_activity(Reverse), f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) @test dA ≈ dA2 atol=5e-5 @test db1 ≈ db12 @test db2 ≈ db22 -=# A = rand(n, n); dA = zeros(n, n);