Skip to content

Commit f5aefbe

Browse files
committed
test: enable runtime activity for now
1 parent ef14bab commit f5aefbe

File tree

1 file changed

+17
-25
lines changed

1 file changed

+17
-25
lines changed

test/enzyme.jl

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -177,47 +177,39 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
177177
A = rand(n, n);
178178
dA = zeros(n, n);
179179
b1 = rand(n);
180-
for alg in (
180+
181+
function fnice(A, b, alg)
182+
prob = LinearProblem(A, b)
183+
sol1 = solve(prob, alg)
184+
return sum(sol1.u)
185+
end
186+
187+
@testset for alg in (
181188
LUFactorization(),
182189
RFLUFactorization() # KrylovJL_GMRES(), fails
183190
)
184-
@show alg
185-
function fb(b)
186-
prob = LinearProblem(A, b)
187-
188-
sol1 = solve(prob, alg)
191+
fb_closure = b -> fnice(A, b, alg)
189192

190-
sum(sol1.u)
191-
end
192-
fb(b1)
193-
194-
fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
193+
fd_jac = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec
195194
@show fd_jac
196195

197196
en_jac = map(onehot(b1)) do db1
198-
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
199-
eres[1]
197+
return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice,
198+
Const(A), Duplicated(b1, db1), Const(alg)))
200199
end |> collect
201200
@show en_jac
202201

203202
@test en_jacfd_jac rtol=1e-4
204203

205-
function fA(A)
206-
prob = LinearProblem(A, b1)
207-
208-
sol1 = solve(prob, alg)
204+
fA_closure = A -> fnice(A, b1, alg)
209205

210-
sum(sol1.u)
211-
end
212-
fA(A)
213-
214-
fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
206+
fd_jac = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
215207
@show fd_jac
216208

217209
en_jac = map(onehot(A)) do dA
218-
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
219-
eres[1]
220-
end |> collect
210+
return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice,
211+
Duplicated(A, dA), Const(b1), Const(alg)))
212+
end |> collect |> (x -> reshape(x, n, n))
221213
@show en_jac
222214

223215
@test en_jacfd_jac rtol=1e-4

0 commit comments

Comments
 (0)