Skip to content

Commit 0cbc2fc

Browse files
committed
test: adjoints
1 parent fcca2f7 commit 0cbc2fc

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
@testitem "Simple Adjoint Test" tags=[:adjoint] begin
2+
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote
3+
4+
ff(u, p) = u .^ 2 .- p
5+
6+
function solve_nlprob(p)
7+
prob = NonlinearProblem{false}(ff, [1.0, 2.0], p)
8+
sol = solve(prob, SimpleNewtonRaphson())
9+
res = sol isa AbstractArray ? sol : sol.u
10+
return sum(abs2, res)
11+
end
12+
13+
p = [3.0, 2.0]
14+
15+
∂p_zygote = only(Zygote.gradient(solve_nlprob, p))
16+
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
17+
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
18+
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)
19+
@test ∂p_zygote ∂p_tracker ∂p_reversediff
20+
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff
21+
end

0 commit comments

Comments
 (0)