Skip to content

Commit ac5693c

Browse files
committed
Adapt to pending Enzyme breaking change
1 parent 33911f6 commit ac5693c

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ ChainRulesCore = "1.22"
7474
ConcreteStructs = "0.2.3"
7575
DocStringExtensions = "0.9.3"
7676
EnumX = "1.0.4"
77-
Enzyme = "0.11.15, 0.12"
77+
Enzyme = "0.13"
7878
EnzymeCore = "0.6.5, 0.7"
7979
FastAlmostBandedMatrices = "0.1"
8080
FastLapackInterface = "2"

ext/LinearSolveEnzymeExt.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@ using Enzyme
88

99
using EnzymeCore
1010

11-
function EnzymeCore.EnzymeRules.forward(
11+
function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1},
1212
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
1313
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
1414
@assert !(prob isa Const)
1515
res = func.val(prob.val, alg.val; kwargs...)
1616
if RT <: Const
17-
return res
17+
if EnzymeRules.needs_primal(config)
18+
return res
19+
else
20+
return nothing
21+
end
1822
end
1923
dres = func.val(prob.dval, alg.val; kwargs...)
2024
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
@@ -25,17 +29,31 @@ function EnzymeCore.EnzymeRules.forward(
2529
return Duplicated(res, dres)
2630
end
2731
error("Unsupported return type $RT")
32+
33+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
34+
Duplicated(res, dres)
35+
elseif EnzymeRules.needs_shadow(config)
36+
dres
37+
elseif EnzymeRules.needs_primal(config)
38+
res
39+
else
40+
nothing
41+
end
2842
end
2943

30-
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
44+
function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)},
3145
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
3246
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
3347
@assert !(linsolve isa Const)
3448

3549
res = func.val(linsolve.val; kwargs...)
3650

3751
if RT <: Const
38-
return res
52+
if EnzymeRules.needs_primal(config)
53+
return res
54+
else
55+
return nothing
56+
end
3957
end
4058
if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
4159
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
@@ -50,13 +68,15 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
5068

5169
linsolve.val.b = b
5270

53-
if RT <: DuplicatedNoNeed
54-
return dres
55-
elseif RT <: Duplicated
56-
return Duplicated(res, dres)
71+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
72+
Duplicated(res, dres)
73+
elseif EnzymeRules.needs_shadow(config)
74+
dres
75+
elseif EnzymeRules.needs_primal(config)
76+
res
77+
else
78+
nothing
5779
end
58-
59-
return Duplicated(res, dres)
6080
end
6181

6282
function EnzymeCore.EnzymeRules.augmented_primal(

0 commit comments

Comments
 (0)