Skip to content

Commit 34937a4

Browse files
committed
fix up adjoint tests
1 parent 7ec5a82 commit 34937a4

File tree

3 files changed

+6
-23
lines changed

3 files changed

+6
-23
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
137137
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
138138
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
139139
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
140+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
140141
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
141142
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
142143
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
@@ -154,6 +155,8 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
154155
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
155156
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
156157
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
158+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
159+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
157160
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
158161
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
159162
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
@@ -162,7 +165,8 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
162165
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
163166
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
164167
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
168+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
165169
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
166170

167171
[targets]
168-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
172+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLSensitivity", "Enzyme"]

test/adjoint_tests.jl

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testitem "Adjoint Tests" tags = [:adjoint] begin
2-
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, NonlinearSolve, Test
2+
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme
33

44
ff(u, p) = u .^ 2 .- p
55

@@ -20,22 +20,3 @@
2020
@test ∂p_zygote ∂p_tracker ∂p_reversediff ∂p_enzyme
2121
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff ∂p_enzyme
2222
end
23-
24-
@testitem "Simple Adjoint Test" tags=[:adjoint] begin
25-
using ForwardDiff, Zygote, BracketingNonlinearSolve
26-
27-
ff(u, p) = u^2 .- p[1]
28-
29-
function solve_nlprob(p)
30-
prob = IntervalNonlinearProblem{false}(ff, (1.0, 3.0), p)
31-
sol = solve(prob, Bisection())
32-
res = sol isa AbstractArray ? sol : sol.u
33-
return sum(abs2, res)
34-
end
35-
36-
p = [2.0, 2.0]
37-
38-
∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
39-
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
40-
@test ∂p_zygote ∂p_forwarddiff
41-
end

test/runtests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ if GROUP == "all" || GROUP == "cuda"
4242
end
4343
end
4444

45-
(GROUP == "all" || GROUP == "adjoint") && Pkg.add(["SciMLSensitivity"])
46-
4745
length(EXTRA_PKGS) 1 && Pkg.add(EXTRA_PKGS)
4846

4947
# Use sequential execution for wrapper tests to avoid parallel initialization issues

0 commit comments

Comments
 (0)