Skip to content

Commit 9f8d18f

Browse files
wsmosesChrisRackauckas
authored andcommitted
Add actual file
1 parent bb6d623 commit 9f8d18f

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
module LinearSolveEnzymeExt
2+
3+
using LinearSolve
4+
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
5+
6+
7+
using Enzyme
8+
9+
using EnzymeCore
10+
11+
# y=inv(A) B
12+
# dA −= z y^T
13+
# dB += z, where z = inv(A^T) dy
14+
function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
15+
res = func.val(prob.val, alg.val; kwargs...)
16+
dres = deepcopy(res)
17+
dres.u .= 0
18+
cache = (copy(prob.val.A), res, dres.u)
19+
return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, typeof(cache)}(res, dres, cache)
20+
end
21+
22+
function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, cache, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
23+
A, y, dy = cache
24+
25+
dA = prob.dval.A
26+
db = prob.dval.b
27+
28+
invprob = LinearProblem(transpose(A), dy)
29+
30+
z = func.val(invprob, alg; kwargs...)
31+
32+
dA .-= z * transpose(y)
33+
db .+= z
34+
dy .= 0
35+
return (nothing, nothing)
36+
end
37+
38+
end

0 commit comments

Comments
 (0)