Skip to content

Commit 9273a20

Browse files
committed
Extend
1 parent a08386d commit 9273a20

File tree

1 file changed

+69
-16
lines changed

1 file changed

+69
-16
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,84 @@ using Enzyme
88

99
using EnzymeCore
1010

11+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
12+
res = func.val(prob.val, alg.val; kwargs...)
13+
dres = if EnzymeRules.width(config) == 1
14+
func.val(prob.dval, alg.val; kwargs...)
15+
else
16+
(func.val(dval, alg.val; kwargs...) for dval in prob.dval)
17+
end
18+
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, nothing)
19+
end
20+
21+
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
22+
return (nothing, nothing)
23+
end
24+
1125
# y=inv(A) B
1226
# dA −= z y^T
1327
# dB += z, where z = inv(A^T) dy
14-
function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
15-
res = func.val(prob.val, alg.val; kwargs...)
16-
dres = deepcopy(res)
17-
dres.u .= 0
18-
cache = (copy(prob.val.A), res, dres.u)
19-
return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, typeof(cache)}(res, dres, cache)
28+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
29+
res = func.val(linsolve.val; kwargs...)
30+
dres = if EnzymeRules.width(config) == 1
31+
deepcopy(res)
32+
else
33+
(deepcopy(res) for dval in linsolve.dval)
34+
end
35+
36+
if EnzymeRules.width(config) == 1
37+
dres.u .= 0
38+
else
39+
for dr in dres
40+
dr.u .= 0
41+
end
42+
end
43+
44+
resvals = if EnzymeRules.width(config) == 1
45+
dres.u
46+
else
47+
(dr.u for dr in dres)
48+
end
49+
50+
cache = (copy(linsolve.val.A), res, resvals)
51+
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
2052
end
2153

22-
function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, cache, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
23-
A, y, dy = cache
54+
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
55+
A, y, dys = cache
2456

25-
dA = prob.dval.A
26-
db = prob.dval.b
57+
@assert !(typeof(linsolve) <: Const)
58+
@assert !(typeof(linsolve) <: Active)
2759

28-
invprob = LinearProblem(transpose(A), dy)
60+
if EnzymeRules.width(config) == 1
61+
dys = (dys,)
62+
end
2963

30-
z = func.val(invprob, alg; kwargs...)
64+
dAs = if EnzymeRules.width(config) == 1
65+
(linsolve.dval.A,)
66+
else
67+
(dval.A for dval in linsolve.dval)
68+
end
3169

32-
dA .-= z * transpose(y)
33-
db .+= z
34-
dy .= 0
35-
return (nothing, nothing)
70+
dbs = if EnzymeRules.width(config) == 1
71+
(linsolve.dval.b,)
72+
else
73+
(dval.b for dval in linsolve.dval)
74+
end
75+
76+
for (dA, db, dy) in zip(dAs, dbs, dys)
77+
invprob = LinearSolve.LinearProblem(transpose(A), dy)
78+
z = solve(invprob;
79+
abstol = linsolve.val.abstol,
80+
reltol = linsolve.val.reltol,
81+
verbose = linsolve.val.verbose)
82+
83+
dA .-= z * transpose(y)
84+
db .+= z
85+
dy .= eltype(dy)(0)
86+
end
87+
88+
return (nothing,)
3689
end
3790

3891
end

0 commit comments

Comments
 (0)