Skip to content

Commit 711bc53

Browse files
Fix forwarddiff overloads
1 parent 6f0d159 commit 711bc53

File tree

4 files changed

+197
-186
lines changed

4 files changed

+197
-186
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ julia = "1.6"
3434
[extras]
3535
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3636
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
37+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3738
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3839

3940
[targets]
40-
test = ["BenchmarkTools", "Test", "ForwardDiff"]
41+
test = ["BenchmarkTools", "SafeTestsets", "Test", "ForwardDiff"]

src/scalar.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}}, alg::N
1515
for i in 1:maxiters
1616
if alg_autodiff(alg)
1717
fx, dfx = value_derivative(f, x)
18+
elseif x isa AbstractArray
19+
fx = f(x)
20+
dfx = FiniteDiff.finite_difference_jacobian(f, x, alg.diff_type, eltype(x), fx)
1821
else
1922
fx = f(x)
2023
dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx)
@@ -54,12 +57,12 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
5457
return sol, partials
5558
end
5659

57-
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
60+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5861
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
5962
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
6063

6164
end
62-
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
65+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
6366
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
6467
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
6568
end

test/basictests.jl

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
using NonlinearSolve
2+
using StaticArrays
3+
using BenchmarkTools
4+
using Test
5+
6+
function benchmark_immutable(f, u0)
7+
probN = NonlinearProblem{false}(f, u0)
8+
solver = init(probN, NewtonRaphson(), tol = 1e-9)
9+
sol = solve!(solver)
10+
end
11+
12+
function benchmark_mutable(f, u0)
13+
probN = NonlinearProblem{false}(f, u0)
14+
solver = init(probN, NewtonRaphson(), tol = 1e-9)
15+
sol = (reinit!(solver, probN); solve!(solver))
16+
end
17+
18+
function benchmark_scalar(f, u0)
19+
probN = NonlinearProblem{false}(f, u0)
20+
sol = (solve(probN, NewtonRaphson()))
21+
end
22+
23+
function ff(u,p)
24+
u .* u .- 2
25+
end
26+
const cu0 = @SVector[1.0, 1.0]
27+
function sf(u,p)
28+
u * u - 2
29+
end
30+
const csu0 = 1.0
31+
32+
sol = benchmark_immutable(ff, cu0)
33+
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
34+
@test all(sol.u .* sol.u .- 2 .< 1e-9)
35+
sol = benchmark_mutable(ff, cu0)
36+
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
37+
@test all(sol.u .* sol.u .- 2 .< 1e-9)
38+
sol = benchmark_scalar(sf, csu0)
39+
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
40+
@test sol.u * sol.u - 2 < 1e-9
41+
42+
@test (@ballocated benchmark_immutable(ff, cu0)) == 0
43+
@test (@ballocated benchmark_mutable(ff, cu0)) < 200
44+
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
45+
46+
# AD Tests
47+
using ForwardDiff
48+
49+
# Immutable
50+
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]
51+
52+
g = function (p)
53+
probN = NonlinearProblem{false}(f, u0, p)
54+
sol = solve(probN, NewtonRaphson(), tol = 1e-9)
55+
return sol.u[end]
56+
end
57+
58+
for p in 1.0:0.1:100.0
59+
@test g(p) sqrt(p)
60+
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
61+
end
62+
63+
# Scalar
64+
f, u0 = (u, p) -> u * u - p, 1.0
65+
66+
# NewtonRaphson
67+
g = function (p)
68+
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
69+
sol = solve(probN, NewtonRaphson())
70+
return sol.u
71+
end
72+
73+
@test ForwardDiff.derivative(g, 1.0) 0.5
74+
75+
for p in 1.1:0.1:100.0
76+
@test g(p) sqrt(p)
77+
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
78+
end
79+
80+
u0 = (1.0, 20.0)
81+
# Falsi
82+
g = function (p)
83+
probN = NonlinearProblem{false}(f, typeof(p).(u0), p)
84+
sol = solve(probN, Falsi())
85+
return sol.left
86+
end
87+
88+
for p in 1.1:0.1:100.0
89+
@test g(p) sqrt(p)
90+
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
91+
end
92+
93+
f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
94+
t = (p) -> [sqrt(p[2] / p[1])]
95+
p = [0.9, 50.0]
96+
for alg in [Bisection(), Falsi()]
97+
global g, p
98+
g = function (p)
99+
probN = NonlinearProblem{false}(f, u0, p)
100+
sol = solve(probN, Bisection())
101+
return [sol.left]
102+
end
103+
104+
@test g(p) [sqrt(p[2] / p[1])]
105+
@test ForwardDiff.jacobian(g, p) ForwardDiff.jacobian(t, p)
106+
end
107+
108+
gnewton = function (p)
109+
probN = NonlinearProblem{false}(f, 0.5, p)
110+
sol = solve(probN, NewtonRaphson())
111+
return [sol.u]
112+
end
113+
@test gnewton(p) [sqrt(p[2] / p[1])]
114+
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
115+
116+
# Error Checks
117+
118+
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
119+
probN = NonlinearProblem(f, u0)
120+
121+
@test solve(probN, NewtonRaphson()).u[end] sqrt(2.0)
122+
@test solve(probN, NewtonRaphson(); immutable = false).u[end] sqrt(2.0)
123+
@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] sqrt(2.0)
124+
@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] sqrt(2.0)
125+
126+
for u0 in [1.0, [1, 1.0]]
127+
local f, probN, sol
128+
f = (u, p) -> u .* u .- 2.0
129+
probN = NonlinearProblem(f, u0)
130+
sol = sqrt(2) * u0
131+
132+
@test solve(probN, NewtonRaphson()).u sol
133+
@test solve(probN, NewtonRaphson()).u sol
134+
@test solve(probN, NewtonRaphson(;autodiff=false)).u sol
135+
end
136+
137+
# Bisection Tests
138+
f, u0 = (u, p) -> u .* u .- 2.0, (1.0, 2.0)
139+
probB = NonlinearProblem(f, u0)
140+
141+
# Falsi
142+
solver = init(probB, Falsi())
143+
sol = solve!(solver)
144+
@test sol.left sqrt(2.0)
145+
146+
# this should call the fast scalar overload
147+
@test solve(probB, Bisection()).left sqrt(2.0)
148+
149+
# these should call the iterator version
150+
solver = init(probB, Bisection())
151+
@test solver isa NonlinearSolve.BracketingImmutableSolver
152+
@test solve!(solver).left sqrt(2.0)
153+
154+
# Garuntee Tests for Bisection
155+
f = function (u, p)
156+
if u < 2.0
157+
return u - 2.0
158+
elseif u > 3.0
159+
return u - 3.0
160+
else
161+
return 0.0
162+
end
163+
end
164+
probB = NonlinearProblem(f, (0.0, 4.0))
165+
166+
solver = init(probB, Bisection(;exact_left = true))
167+
sol = solve!(solver)
168+
@test f(sol.left, nothing) < 0.0
169+
@test f(nextfloat(sol.left), nothing) >= 0.0
170+
171+
solver = init(probB, Bisection(;exact_right = true))
172+
sol = solve!(solver)
173+
@test f(sol.right, nothing) > 0.0
174+
@test f(prevfloat(sol.right), nothing) <= 0.0
175+
176+
solver = init(probB, Bisection(;exact_left = true, exact_right = true); immutable = false)
177+
sol = solve!(solver)
178+
@test f(sol.left, nothing) < 0.0
179+
@test f(nextfloat(sol.left), nothing) >= 0.0
180+
@test f(sol.right, nothing) > 0.0
181+
@test f(prevfloat(sol.right), nothing) <= 0.0

0 commit comments

Comments
 (0)