1
- using Enzyme, FiniteDiff
1
+ using Enzyme, ForwardDiff
2
2
using LinearSolve, LinearAlgebra, Test
3
3
4
4
n = 4
@@ -20,8 +20,8 @@ f(A, b1) # Uses BLAS
20
20
21
21
Enzyme. autodiff (Reverse, f, Duplicated (copy (A), dA), Duplicated (copy (b1), db1))
22
22
23
- dA2 = FiniteDiff . finite_difference_gradient (x-> f (x,b1 ), copy (A))
24
- db12 = FiniteDiff . finite_difference_gradient (x-> f (A ,x), copy (b1))
23
+ dA2 = ForwardDiff . gradient (x-> f (x,eltype (x).(b1) ), copy (A))
24
+ db12 = ForwardDiff . gradient (x-> f (eltype (x).(A) ,x), copy (b1))
25
25
26
26
@test dA ≈ dA2
27
27
@test db1 ≈ db12
@@ -35,8 +35,8 @@ db12 = zeros(n);
35
35
36
36
@test_broken Enzyme. autodiff (Reverse, f, BatchDuplicated (copy (A), (dA, dA2)), BatchDuplicated (copy (b1), (db1, db12)))
37
37
38
- dA_2 = FiniteDiff . finite_difference_gradient (x-> f (x,b1 ), copy (A))
39
- db1_2 = FiniteDiff . finite_difference_gradient (x-> f (A ,x), copy (b1))
38
+ dA2 = ForwardDiff . gradient (x-> f (x,eltype (x).(b1) ), copy (A))
39
+ db12 = ForwardDiff . gradient (x-> f (eltype (x).(A) ,x), copy (b1))
40
40
41
41
@test_broken dA ≈ dA_2
42
42
@test_broken dA2 ≈ dA_2
@@ -45,9 +45,8 @@ db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))
45
45
46
46
function f (A, b1, b2; alg = LUFactorization ())
47
47
prob = LinearProblem (A, b1)
48
-
49
48
cache = init (prob, alg)
50
- s1 = solve! (cache). u
49
+ s1 = copy ( solve! (cache). u)
51
50
cache. b = b2
52
51
s2 = solve! (cache). u
53
52
norm (s1 + s2)
@@ -60,11 +59,46 @@ db1 = zeros(n);
60
59
b2 = rand (n);
61
60
db2 = zeros (n);
62
61
62
+ f (A, b1, b2)
63
63
Enzyme. autodiff (Reverse, f, Duplicated (copy (A), dA), Duplicated (copy (b1), db1), Duplicated (copy (b2), db2))
64
64
65
- dA2 = FiniteDiff. finite_difference_gradient (x-> f (x,b1,b2), copy (A))
66
- db12 = FiniteDiff. finite_difference_gradient (x-> f (A,x,b2), copy (b1))
67
- db22 = FiniteDiff. finite_difference_gradient (x-> f (A,b1,x), copy (b2))
65
+ dA2 = ForwardDiff. gradient (x-> f (x,eltype (x).(b1),eltype (x).(b2)), copy (A))
66
+ db12 = ForwardDiff. gradient (x-> f (eltype (x).(A),x,eltype (x).(b2)), copy (b1))
67
+ db22 = ForwardDiff. gradient (x-> f (eltype (x).(A),eltype (x).(b1),x), copy (b2))
68
+
69
+ @test dA ≈ dA2
70
+ @test db1 ≈ db12
71
+ @test db2 ≈ db22
72
+
73
+ function f2 (A, b1, b2; alg = RFLUFactorization ())
74
+ prob = LinearProblem (A, b1)
75
+ cache = init (prob, alg)
76
+ s1 = copy (solve! (cache). u)
77
+ cache. b = b2
78
+ s2 = solve! (cache). u
79
+ norm (s1 + s2)
80
+ end
81
+
82
+ f2 (A, b1, b2)
83
+ dA = zeros (n, n);
84
+ db1 = zeros (n);
85
+ db2 = zeros (n);
86
+ Enzyme. autodiff (Reverse, f2, Duplicated (copy (A), dA), Duplicated (copy (b1), db1), Duplicated (copy (b2), db2))
87
+
88
+ @test dA ≈ dA2
89
+ @test db1 ≈ db12
90
+ @test db2 ≈ db22
91
+
92
+ function f3 (A, b1, b2; alg = KrylovJL_GMRES ())
93
+ prob = LinearProblem (A, b1)
94
+ cache = init (prob, alg)
95
+ s1 = solve! (cache). u
96
+ cache. b = b2
97
+ s2 = solve! (cache). u
98
+ norm (s1 + s2)
99
+ end
100
+
101
+ Enzyme. autodiff (Reverse, f3, Duplicated (copy (A), dA), Duplicated (copy (b1), db1), Duplicated (copy (b2), db2))
68
102
69
103
@test dA ≈ dA2 atol= 5e-5
70
104
@test db1 ≈ db12
0 commit comments