Skip to content

Commit d5e64bd

Browse files
sharanryChrisRackauckas
authored andcommitted
Add forward enzyme rules for init and solve
1 parent b9da6ac commit d5e64bd

File tree

3 files changed

+101
-2
lines changed

3 files changed

+101
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ DocStringExtensions = "0.9"
6666
EnumX = "1"
6767
EnzymeCore = "0.6"
6868
FastLapackInterface = "2"
69+
EnzymeTestUtils = "0.1"
6970
GPUArraysCore = "0.1"
7071
HYPRE = "1.4.0"
7172
InteractiveUtils = "1.6"
@@ -96,6 +97,7 @@ julia = "1.9"
9697
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
9798
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
9899
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
100+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
99101
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
100102
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
101103
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
@@ -114,4 +116,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
114116
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
115117

116118
[targets]
117-
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices"]
119+
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "EnzymeTestUtils", "FiniteDiff", "BandedMatrices"]

ext/LinearSolveEnzymeExt.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,47 @@ using Enzyme
99

1010
using EnzymeCore
1111

12+
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
13+
@assert !(prob isa Const)
14+
res = func.val(prob.val, alg.val; kwargs...)
15+
if RT <: Const
16+
return res
17+
end
18+
dres = func.val(prob.dval, alg.val; kwargs...)
19+
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
20+
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A
21+
if RT <: DuplicatedNoNeed
22+
return dres
23+
elseif RT <: Duplicated
24+
return Duplicated(res, dres)
25+
end
26+
end
27+
28+
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
29+
@assert !(linsolve isa Const)
30+
31+
A = deepcopy(linsolve.val.A) #mutates after function is applied
32+
res = func.val(linsolve.val; kwargs...)
33+
34+
if RT <: Const
35+
return res
36+
end
37+
38+
dres = deepcopy(res)
39+
invA = inv(A)
40+
db = linsolve.dval.b
41+
dA = linsolve.dval.A
42+
dres.u .= invA * (db - dA * res.u)
43+
44+
if RT <: DuplicatedNoNeed
45+
return dres
46+
elseif RT <: Duplicated
47+
return Duplicated(res, dres)
48+
end
49+
50+
return Duplicated(res, dres)
51+
end
52+
1253
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
1354
res = func.val(prob.val, alg.val; kwargs...)
1455
dres = if EnzymeRules.width(config) == 1

test/enzyme.jl

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Enzyme, ForwardDiff
22
using LinearSolve, LinearAlgebra, Test
3+
using FiniteDiff
34

45
n = 4
56
A = rand(n, n);
@@ -161,4 +162,59 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
161162
@test dA ≈ dA2 atol=5e-5
162163
@test db1 ≈ db12
163164
@test db2 ≈ db22
164-
=#
165+
=#
166+
167+
168+
A = rand(n, n);
169+
dA = zeros(n, n);
170+
b1 = rand(n);
171+
function fb(b; alg = LUFactorization())
172+
prob = LinearProblem(A, b)
173+
174+
sol1 = solve(prob, alg)
175+
176+
sum(sol1.u)
177+
end
178+
fb(b1)
179+
180+
manual_jac = map(onehot(b1)) do db
181+
y = A \ b1
182+
sum(inv(A) * (db - dA*y))
183+
end |> collect
184+
@show manual_jac
185+
186+
fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
187+
@show fd_jac
188+
189+
en_jac = map(onehot(b1)) do db1
190+
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
191+
eres[1]
192+
end |> collect
193+
@show en_jac
194+
195+
@test_broken en_jac manual_jac
196+
@test_broken en_jac fd_jac
197+
198+
function fA(A; alg = LUFactorization())
199+
prob = LinearProblem(A, b1)
200+
201+
sol1 = solve(prob, alg)
202+
203+
sum(sol1.u)
204+
end
205+
fA(A)
206+
207+
manual_jac = map(onehot(A)) do dA
208+
y = A \ b1
209+
sum(inv(A) * (db1 - dA*y))
210+
end |> collect
211+
212+
fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
213+
214+
en_jac = map(onehot(A)) do dA
215+
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
216+
eres[1]
217+
end |> collect
218+
219+
@test_broken en_jac manual_jac
220+
@test_broken en_jac fd_jac

0 commit comments

Comments
 (0)