1- using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, NonlinearSolve, Test
1+ @testitem " Adjoint Tests" tags = [:adjoint ] begin
2+ using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, NonlinearSolve, Test
23
3- ff (u, p) = u .^ 2 .- p
4+ ff (u, p) = u .^ 2 .- p
45
5- function solve_nlprob (p)
6- prob = NonlinearProblem {false} (ff, [1.0 , 2.0 ], p)
7- sol = solve (prob, NewtonRaphson ())
8- res = sol isa AbstractArray ? sol : sol. u
9- return sum (abs2, res)
6+ function solve_nlprob (p)
7+ prob = NonlinearProblem {false} (ff, [1.0 , 2.0 ], p)
8+ sol = solve (prob, NewtonRaphson ())
9+ res = sol isa AbstractArray ? sol : sol. u
10+ return sum (abs2, res)
11+ end
12+
13+ p = [3.0 , 2.0 ]
14+
15+ ∂p_zygote = only (Zygote. gradient (solve_nlprob, p))
16+ ∂p_forwarddiff = ForwardDiff. gradient (solve_nlprob, p)
17+ ∂p_tracker = Tracker. data (only (Tracker. gradient (solve_nlprob, p)))
18+ ∂p_reversediff = ReverseDiff. gradient (solve_nlprob, p)
19+ ∂p_enzyme = Enzyme. gradient (Enzyme. Reverse, solve_nlprob, p)[1 ]
20+ @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme
21+ @test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme
1022end
1123
12- p = [3.0 , 2.0 ]
24+ @testitem " Simple Adjoint Test" tags= [:adjoint ] begin
25+ using ForwardDiff, Zygote, BracketingNonlinearSolve
26+
27+ ff (u, p) = u^ 2 .- p[1 ]
1328
14- ∂p_zygote = only (Zygote. gradient (solve_nlprob, p))
15- ∂p_forwarddiff = ForwardDiff. gradient (solve_nlprob, p)
16- ∂p_tracker = Tracker. data (only (Tracker. gradient (solve_nlprob, p)))
17- ∂p_reversediff = ReverseDiff. gradient (solve_nlprob, p)
18- @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff
19- @test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff
29+ function solve_nlprob (p)
30+ prob = IntervalNonlinearProblem {false} (ff, (1.0 , 3.0 ), p)
31+ sol = solve (prob, Bisection ())
32+ res = sol isa AbstractArray ? sol : sol. u
33+ return sum (abs2, res)
34+ end
35+
36+ p = [2.0 , 2.0 ]
37+
38+ ∂p_zygote = only (Zygote. gradient (solve_nlprob, p))
39+ ∂p_forwarddiff = ForwardDiff. gradient (solve_nlprob, p)
40+ @test ∂p_zygote ≈ ∂p_forwarddiff
41+ end
0 commit comments