Skip to content

Commit 7ec5a82

Browse files
committed
add adjoints test item
1 parent aa4bdc1 commit 7ec5a82

File tree

2 files changed

+39
-14
lines changed

2 files changed

+39
-14
lines changed

test/adjoint_tests.jl

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,41 @@
1-
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, NonlinearSolve, Test
1+
@testitem "Adjoint Tests" tags = [:adjoint] begin
2+
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, NonlinearSolve, Test
23

3-
ff(u, p) = u .^ 2 .- p
4+
ff(u, p) = u .^ 2 .- p
45

5-
function solve_nlprob(p)
6-
prob = NonlinearProblem{false}(ff, [1.0, 2.0], p)
7-
sol = solve(prob, NewtonRaphson())
8-
res = sol isa AbstractArray ? sol : sol.u
9-
return sum(abs2, res)
6+
function solve_nlprob(p)
7+
prob = NonlinearProblem{false}(ff, [1.0, 2.0], p)
8+
sol = solve(prob, NewtonRaphson())
9+
res = sol isa AbstractArray ? sol : sol.u
10+
return sum(abs2, res)
11+
end
12+
13+
p = [3.0, 2.0]
14+
15+
∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
16+
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
17+
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
18+
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)
19+
∂p_enzyme = Enzyme.gradient(Enzyme.Reverse, solve_nlprob, p)[1]
20+
@test ∂p_zygote ∂p_tracker ∂p_reversediff ∂p_enzyme
21+
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff ∂p_enzyme
1022
end
1123

12-
p = [3.0, 2.0]
24+
@testitem "Simple Adjoint Test" tags=[:adjoint] begin
25+
using ForwardDiff, Zygote, BracketingNonlinearSolve
26+
27+
ff(u, p) = u^2 .- p[1]
1328

14-
∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
15-
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
16-
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
17-
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)
18-
@test ∂p_zygote ∂p_tracker ∂p_reversediff
19-
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff
29+
function solve_nlprob(p)
30+
prob = IntervalNonlinearProblem{false}(ff, (1.0, 3.0), p)
31+
sol = solve(prob, Bisection())
32+
res = sol isa AbstractArray ? sol : sol.u
33+
return sum(abs2, res)
34+
end
35+
36+
p = [2.0, 2.0]
37+
38+
∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
39+
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
40+
@test ∂p_zygote ∂p_forwarddiff
41+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ if GROUP == "all" || GROUP == "cuda"
4141
push!(EXTRA_PKGS, Pkg.PackageSpec("CUDA"))
4242
end
4343
end
44+
45+
(GROUP == "all" || GROUP == "adjoint") && Pkg.add(["SciMLSensitivity"])
46+
4447
length(EXTRA_PKGS) 1 && Pkg.add(EXTRA_PKGS)
4548

4649
# Use sequential execution for wrapper tests to avoid parallel initialization issues

0 commit comments

Comments
 (0)