Skip to content

Commit 9b540ba

Browse files
Merge pull request #416 from sharanry/enzyme_forward
Add forward enzyme rules for init and solve
2 parents b9da6ac + 0935919 commit 9b540ba

File tree

2 files changed

+100
-1
lines changed

2 files changed

+100
-1
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,53 @@ using Enzyme
99

1010
using EnzymeCore
1111

12+
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
13+
@assert !(prob isa Const)
14+
res = func.val(prob.val, alg.val; kwargs...)
15+
if RT <: Const
16+
return res
17+
end
18+
dres = func.val(prob.dval, alg.val; kwargs...)
19+
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
20+
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A
21+
if RT <: DuplicatedNoNeed
22+
return dres
23+
elseif RT <: Duplicated
24+
return Duplicated(res, dres)
25+
end
26+
error("Unsupported return type $RT")
27+
end
28+
29+
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
30+
@assert !(linsolve isa Const)
31+
32+
res = func.val(linsolve.val; kwargs...)
33+
34+
if RT <: Const
35+
return res
36+
end
37+
if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
38+
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
39+
end
40+
b = deepcopy(linsolve.val.b)
41+
42+
db = linsolve.dval.b
43+
dA = linsolve.dval.A
44+
45+
linsolve.val.b = db - dA * res.u
46+
dres = func.val(linsolve.val; kwargs...)
47+
48+
linsolve.val.b = b
49+
50+
if RT <: DuplicatedNoNeed
51+
return dres
52+
elseif RT <: Duplicated
53+
return Duplicated(res, dres)
54+
end
55+
56+
return Duplicated(res, dres)
57+
end
58+
1259
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
1360
res = func.val(prob.val, alg.val; kwargs...)
1461
dres = if EnzymeRules.width(config) == 1

test/enzyme.jl

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Enzyme, ForwardDiff
22
using LinearSolve, LinearAlgebra, Test
3+
using FiniteDiff
4+
using SafeTestsets
35

46
n = 4
57
A = rand(n, n);
@@ -161,4 +163,54 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
161163
@test dA ≈ dA2 atol=5e-5
162164
@test db1 ≈ db12
163165
@test db2 ≈ db22
164-
=#
166+
=#
167+
168+
A = rand(n, n);
169+
dA = zeros(n, n);
170+
b1 = rand(n);
171+
for alg in (
172+
LUFactorization(),
173+
RFLUFactorization(),
174+
# KrylovJL_GMRES(), fails
175+
)
176+
@show alg
177+
function fb(b)
178+
prob = LinearProblem(A, b)
179+
180+
sol1 = solve(prob, alg)
181+
182+
sum(sol1.u)
183+
end
184+
fb(b1)
185+
186+
fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
187+
@show fd_jac
188+
189+
en_jac = map(onehot(b1)) do db1
190+
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
191+
eres[1]
192+
end |> collect
193+
@show en_jac
194+
195+
@test en_jac fd_jac rtol=1e-4
196+
197+
function fA(A)
198+
prob = LinearProblem(A, b1)
199+
200+
sol1 = solve(prob, alg)
201+
202+
sum(sol1.u)
203+
end
204+
fA(A)
205+
206+
fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
207+
@show fd_jac
208+
209+
en_jac = map(onehot(A)) do dA
210+
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
211+
eres[1]
212+
end |> collect
213+
@show en_jac
214+
215+
@test en_jac fd_jac rtol=1e-4
216+
end

0 commit comments

Comments
 (0)