@@ -6,10 +6,8 @@ A = rand(n, n);
6
6
dA = zeros (n, n);
7
7
b1 = rand (n);
8
8
db1 = zeros (n);
9
- b2 = rand (n);
10
- db2 = zeros (n);
11
9
12
- function f (A, b1, b2 ; alg = LUFactorization ())
10
+ function f (A, b1; alg = LUFactorization ())
13
11
prob = LinearProblem (A, b1)
14
12
15
13
sol1 = solve (prob, alg)
@@ -18,16 +16,15 @@ function f(A, b1, b2; alg = LUFactorization())
18
16
norm (s1)
19
17
end
20
18
21
- f (A, b1, b2 ) # Uses BLAS
19
+ f (A, b1) # Uses BLAS
22
20
23
- Enzyme. autodiff (Reverse, f, Duplicated (copy (A), dA), Duplicated (copy (b1), db1), Duplicated ( copy (b2), db2) )
21
+ Enzyme. autodiff (Reverse, f, Duplicated (copy (A), dA), Duplicated (copy (b1), db1))
24
22
25
23
dA2 = FiniteDiff. finite_difference_gradient (x-> f (x,b1, b2), copy (A))
26
24
db12 = FiniteDiff. finite_difference_gradient (x-> f (A,x, b2), copy (b1))
27
25
28
26
@test dA ≈ dA2
29
27
@test db1 ≈ db12
30
- @test db2 == zeros (4 )
31
28
32
29
A = rand (n, n);
33
30
dA = zeros (n, n);
@@ -36,9 +33,6 @@ b1 = rand(n);
36
33
db1 = zeros (n);
37
34
db12 = zeros (n);
38
35
39
- b2 = rand (n);
40
- db2 = zeros (n);
41
- db22 = zeros (n);
42
-
43
- @test_broken Enzyme. autodiff (Reverse, f, Duplicated (copy (A), dA), BatchDuplicated (copy (b1), (db1, db12)), BatchDuplicated (copy (b2), (db2, db22)))
44
- @test_broken Enzyme. autodiff (Reverse, f, BatchDuplicated (copy (A), (dA, dA2)), Duplicated (copy (b1), db1), Duplicated (copy (b2), db2))
36
+ # This is not legal, all args need to be batch'd at the same size
37
+ @test_broken Enzyme. autodiff (Reverse, f, Duplicated (copy (A), dA), BatchDuplicated (copy (b1), (db1, db12)))
38
+ @test_broken Enzyme. autodiff (Reverse, f, BatchDuplicated (copy (A), (dA, dA2)), Duplicated (copy (b1), db1))
0 commit comments