Skip to content

Commit a08386d

Browse files
add a test for Enzyme rule correctness
1 parent ce7ffc0 commit a08386d

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ julia = "1.6"
8181

8282
[extras]
8383
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
84+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
85+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
8486
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8587
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
8688
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -98,4 +100,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
98100
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
99101

100102
[targets]
101-
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
103+
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]

test/enzyme.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using Enzyme, FiniteDiff
2+
using LinearSolve, LinearAlgebra, Test
3+
4+
n = 4
5+
A = rand(n, n);
6+
dA = zeros(n, n);
7+
b1 = rand(n);
8+
db1 = zeros(n);
9+
b2 = rand(n);
10+
db2 = zeros(n);
11+
12+
function f(A, b1, b2; alg = LUFactorization())
13+
prob = LinearProblem(A, b1)
14+
15+
sol1 = solve(prob, alg)
16+
17+
s1 = sol1.u
18+
norm(s1)
19+
end
20+
21+
f(A, b1, b2) # Uses BLAS
22+
23+
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
24+
25+
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1, b2), copy(A))
26+
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1))
27+
28+
@test dA dA2
29+
@test db1 db12
30+
@test db2 == zeros(4)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ if GROUP == "All" || GROUP == "Core"
1313
@time @safetestset "Non-Square Tests" include("nonsquare.jl")
1414
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
1515
@time @safetestset "Default Alg Tests" include("default_algs.jl")
16+
VERSION >= v"1.9" && @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
1617
@time @safetestset "Traits" include("traits.jl")
1718
end
1819

0 commit comments

Comments
 (0)