Skip to content

Commit c2ad2db

Browse files
push batch test updates
1 parent b0d228d commit c2ad2db

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

test/enzyme.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,20 @@ b1 = rand(n);
3333
db1 = zeros(n);
3434
db12 = zeros(n);
3535

36-
@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))
36+
function fbatch(y, A, b1; alg = LUFactorization())
37+
prob = LinearProblem(A, b1)
38+
39+
sol1 = solve(prob, alg)
40+
41+
s1 = sol1.u
42+
y[1] = norm(s1)
43+
nothing
44+
end
45+
46+
y = [0.0]
47+
dy1 = [1.0]
48+
dy2 = [1.0]
49+
Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))
3750

3851
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
3952
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
@@ -92,7 +105,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
92105
function f3(A, b1, b2; alg = KrylovJL_GMRES())
93106
prob = LinearProblem(A, b1)
94107
cache = init(prob, alg)
95-
s1 = solve!(cache).u
108+
s1 = copy(solve!(cache).u)
96109
cache.b = b2
97110
s2 = solve!(cache).u
98111
norm(s1 + s2)

0 commit comments

Comments
 (0)