Skip to content

Commit f9b0784

Browse files
committed
simplify test
1 parent 9d19db2 commit f9b0784

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

test/enzyme.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@ A = rand(n, n);
66
dA = zeros(n, n);
77
b1 = rand(n);
88
db1 = zeros(n);
9-
b2 = rand(n);
10-
db2 = zeros(n);
119

12-
function f(A, b1, b2; alg = LUFactorization())
10+
function f(A, b1; alg = LUFactorization())
1311
prob = LinearProblem(A, b1)
1412

1513
sol1 = solve(prob, alg)
@@ -18,16 +16,15 @@ function f(A, b1, b2; alg = LUFactorization())
1816
norm(s1)
1917
end
2018

21-
f(A, b1, b2) # Uses BLAS
19+
f(A, b1) # Uses BLAS
2220

23-
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
21+
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))
2422

2523
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1, b2), copy(A))
2624
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1))
2725

2826
@test dA dA2
2927
@test db1 db12
30-
@test db2 == zeros(4)
3128

3229
A = rand(n, n);
3330
dA = zeros(n, n);
@@ -36,9 +33,6 @@ b1 = rand(n);
3633
db1 = zeros(n);
3734
db12 = zeros(n);
3835

39-
b2 = rand(n);
40-
db2 = zeros(n);
41-
db22 = zeros(n);
42-
43-
@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12)), BatchDuplicated(copy(b2), (db2, db22)))
44-
@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
36+
# This is not legal, all args need to be batch'd at the same size
37+
@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12)))
38+
@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1))

0 commit comments

Comments
 (0)