You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: ext/LinearSolveEnzymeExt.jl
+47Lines changed: 47 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -9,6 +9,53 @@ using Enzyme
9
9
10
10
using EnzymeCore
11
11
12
+
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <:LinearSolve.LinearProblem}
13
+
@assert!(prob isa Const)
14
+
res = func.val(prob.val, alg.val; kwargs...)
15
+
if RT <:Const
16
+
return res
17
+
end
18
+
dres = func.val(prob.dval, alg.val; kwargs...)
19
+
dres.b .= res.b == dres.b ?zero(dres.b) : dres.b
20
+
dres.A .= res.A == dres.A ?zero(dres.A) : dres.A
21
+
if RT <:DuplicatedNoNeed
22
+
return dres
23
+
elseif RT <:Duplicated
24
+
returnDuplicated(res, dres)
25
+
end
26
+
error("Unsupported return type $RT")
27
+
end
28
+
29
+
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <:LinearSolve.LinearCache}
30
+
@assert!(linsolve isa Const)
31
+
32
+
res = func.val(linsolve.val; kwargs...)
33
+
34
+
if RT <:Const
35
+
return res
36
+
end
37
+
if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
38
+
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
39
+
end
40
+
b =deepcopy(linsolve.val.b)
41
+
42
+
db = linsolve.dval.b
43
+
dA = linsolve.dval.A
44
+
45
+
linsolve.val.b = db - dA * res.u
46
+
dres = func.val(linsolve.val; kwargs...)
47
+
48
+
linsolve.val.b = b
49
+
50
+
if RT <:DuplicatedNoNeed
51
+
return dres
52
+
elseif RT <:Duplicated
53
+
returnDuplicated(res, dres)
54
+
end
55
+
56
+
returnDuplicated(res, dres)
57
+
end
58
+
12
59
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <:LinearSolve.LinearProblem}
0 commit comments