Skip to content

Commit c14aeb2

Browse files
sharanryChrisRackauckas
authored andcommitted
Add tests for other algs and handle cases of algs currently not supported
1 parent 60afb5d commit c14aeb2

File tree

2 files changed

+40
-30
lines changed

2 files changed

+40
-30
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
3333
if RT <: Const
3434
return res
3535
end
36-
36+
if linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
37+
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
38+
end
3739
b = deepcopy(linsolve.val.b)
3840

3941
db = linsolve.dval.b

test/enzyme.jl

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Enzyme, ForwardDiff
22
using LinearSolve, LinearAlgebra, Test
33
using FiniteDiff
4+
using SafeTestsets
45

56
n = 4
67
A = rand(n, n);
@@ -164,46 +165,53 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
164165
@test db2 ≈ db22
165166
=#
166167

167-
168168
A = rand(n, n);
169169
dA = zeros(n, n);
170170
b1 = rand(n);
171-
function fb(b; alg = LUFactorization())
172-
prob = LinearProblem(A, b)
171+
for alg in (
172+
LUFactorization(),
173+
RFLUFactorization(),
174+
# KrylovJL_GMRES(), fails
175+
)
176+
alg_str = string(alg)
177+
@show alg_str
178+
function fb(b)
179+
prob = LinearProblem(A, b)
173180

174-
sol1 = solve(prob, alg)
181+
sol1 = solve(prob, alg)
175182

176-
sum(sol1.u)
177-
end
178-
fb(b1)
183+
sum(sol1.u)
184+
end
185+
fb(b1)
179186

180-
fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
181-
@show fd_jac
187+
fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
188+
@show fd_jac
182189

183-
en_jac = map(onehot(b1)) do db1
184-
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
185-
eres[1]
186-
end |> collect
187-
@show en_jac
190+
en_jac = map(onehot(b1)) do db1
191+
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
192+
eres[1]
193+
end |> collect
194+
@show en_jac
188195

189-
@test en_jac fd_jac rtol=1e-6
196+
@test en_jac fd_jac rtol=1e-6
190197

191-
function fA(A; alg = LUFactorization())
192-
prob = LinearProblem(A, b1)
198+
function fA(A)
199+
prob = LinearProblem(A, b1)
193200

194-
sol1 = solve(prob, alg)
201+
sol1 = solve(prob, alg)
195202

196-
sum(sol1.u)
197-
end
198-
fA(A)
203+
sum(sol1.u)
204+
end
205+
fA(A)
199206

200-
fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
201-
@show fd_jac
207+
fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
208+
@show fd_jac
202209

203-
en_jac = map(onehot(A)) do dA
204-
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
205-
eres[1]
206-
end |> collect
207-
@show en_jac
210+
en_jac = map(onehot(A)) do dA
211+
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
212+
eres[1]
213+
end |> collect
214+
@show en_jac
208215

209-
@test en_jac fd_jac rtol=1e-6
216+
@test en_jac fd_jac rtol=1e-6
217+
end

0 commit comments

Comments
 (0)