Skip to content

Commit 23835fc

Browse files
committed
(test) AD Tests
1 parent 9aa63f4 commit 23835fc

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1515
[extras]
1616
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
1717
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
18+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1819

1920
[targets]
20-
test = ["BenchmarkTools", "Test"]
21+
test = ["BenchmarkTools", "Test", "ForwardDiff"]

src/solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ function DiffEqBase.solve(prob::NonlinearProblem,
22
alg::AbstractNonlinearSolveAlgorithm, args...;
33
kwargs...)
44
solver = DiffEqBase.init(prob, alg, args...; kwargs...)
5-
solve!(solver)
6-
return solver.sol
5+
sol = solve!(solver)
6+
return sol
77
end
88

99
function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;

test/runtests.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,36 @@ sol = benchmark_scalar(sf, su0)
3535
@test (@ballocated benchmark_immutable($f, $u0)) == 0
3636
@test (@ballocated benchmark_mutable($f, $u0)) < 200
3737
@test (@ballocated benchmark_scalar($sf, $su0)) == 0
38+
39+
# AD Tests
40+
using ForwardDiff
41+
42+
# Immutable
43+
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]
44+
45+
g = function (p)
46+
probN = NonlinearProblem{false}(f, u0, p)
47+
sol = solve(probN, NewtonRaphson(), immutable = true, tol = 1e-9)
48+
return sol.u[end]
49+
end
50+
51+
for p in 1.0:0.1:100.0
52+
@test g(p) sqrt(p)
53+
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
54+
end
55+
56+
# Scalar
57+
f, u0 = (u, p) -> u * u - p, 1.0
58+
59+
g = function (p)
60+
probN = NonlinearProblem{false}(f, u0, p)
61+
sol = solve(probN, NewtonRaphson())
62+
return sol.u
63+
end
64+
65+
@test_broken ForwardDiff.derivative(g, 1.0) 0.5
66+
67+
for p in 1.1:0.1:100.0
68+
@test g(p) sqrt(p)
69+
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
70+
end

0 commit comments

Comments
 (0)