diff --git a/Project.toml b/Project.toml index 58db07d37..3cf059417 100644 --- a/Project.toml +++ b/Project.toml @@ -74,8 +74,8 @@ ChainRulesCore = "1.22" ConcreteStructs = "0.2.3" DocStringExtensions = "0.9.3" EnumX = "1.0.4" -Enzyme = "0.11.15, 0.12, 0.13" -EnzymeCore = "0.6.5, 0.7, 0.8" +Enzyme = "0.13" +EnzymeCore = "0.8" FastAlmostBandedMatrices = "0.1" FastLapackInterface = "2" FiniteDiff = "2.22" diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index c8d89e874..84884c040 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -8,13 +8,17 @@ using Enzyme using EnzymeCore -function EnzymeCore.EnzymeRules.forward( +function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @assert !(prob isa Const) res = func.val(prob.val, alg.val; kwargs...) if RT <: Const - return res + if EnzymeRules.needs_primal(config) + return res + else + return nothing + end end dres = func.val(prob.dval, alg.val; kwargs...) dres.b .= res.b == dres.b ? zero(dres.b) : dres.b @@ -25,9 +29,19 @@ function EnzymeCore.EnzymeRules.forward( return Duplicated(res, dres) end error("Unsupported return type $RT") + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + Duplicated(res, dres) + elseif EnzymeRules.needs_shadow(config) + dres + elseif EnzymeRules.needs_primal(config) + res + else + nothing + end end -function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, +function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} @assert !(linsolve isa Const) @@ -35,7 +49,11 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, res = func.val(linsolve.val; kwargs...) if RT <: Const - return res + if EnzymeRules.needs_primal(config) + return res + else + return nothing + end end if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") @@ -50,13 +68,15 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, linsolve.val.b = b - if RT <: DuplicatedNoNeed - return dres - elseif RT <: Duplicated - return Duplicated(res, dres) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + Duplicated(res, dres) + elseif EnzymeRules.needs_shadow(config) + dres + elseif EnzymeRules.needs_primal(config) + res + else + nothing end - - return Duplicated(res, dres) end function EnzymeCore.EnzymeRules.augmented_primal(