Skip to content

Commit 51c443e

Browse files
Merge pull request #50 from SciML/staticarrays
specialize Newton on static arrays
2 parents eb6b2e6 + 47bcc55 commit 51c443e

File tree

5 files changed

+211
-189
lines changed

5 files changed

+211
-189
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ julia = "1.6"
3434
[extras]
3535
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3636
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
37+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
38+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3739
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3840

3941
[targets]
40-
test = ["BenchmarkTools", "Test", "ForwardDiff"]
42+
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff"]

src/scalar.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
1-
function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
1+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
22
f = Base.Fix2(prob.f, prob.p)
33
x = float(prob.u0)
44
fx = float(prob.u0)
55
T = typeof(x)
6-
atol = xatol !== nothing ? xatol : oneunit(T) * (eps(one(T)))^(4//5)
7-
rtol = xrtol !== nothing ? xrtol : eps(one(T))^(4//5)
6+
atol = xatol !== nothing ? xatol : oneunit(eltype(T)) * (eps(one(eltype(T))))^(4//5)
7+
rtol = xrtol !== nothing ? xrtol : eps(one(eltype(T)))^(4//5)
8+
9+
if typeof(x) <: Number
10+
xo = oftype(one(eltype(x)), Inf)
11+
else
12+
xo = map(x->oftype(one(eltype(x)), Inf),x)
13+
end
814

9-
xo = oftype(x, Inf)
1015
for i in 1:maxiters
1116
if alg_autodiff(alg)
1217
fx, dfx = value_derivative(f, x)
18+
elseif x isa AbstractArray
19+
fx = f(x)
20+
dfx = FiniteDiff.finite_difference_jacobian(f, x, alg.diff_type, eltype(x), fx)
1321
else
1422
fx = f(x)
1523
dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx)
@@ -49,12 +57,12 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
4957
return sol, partials
5058
end
5159

52-
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
60+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5361
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
5462
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
55-
63+
5664
end
57-
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
65+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5866
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
5967
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
6068
end

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ function value_derivative(f::F, x::R) where {F,R}
224224
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
225225
end
226226

227+
# Todo: improve this dispatch
228+
value_derivative(f::F, x::SVector) where F = f(x),ForwardDiff.jacobian(f, x)
229+
227230
value(x) = x
228231
value(x::Dual) = ForwardDiff.value(x)
229232
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

test/basictests.jl

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
using NonlinearSolve
2+
using StaticArrays
3+
using BenchmarkTools
4+
using Test
5+
6+
function benchmark_immutable(f, u0)
7+
probN = NonlinearProblem{false}(f, u0)
8+
solver = init(probN, NewtonRaphson(), tol = 1e-9)
9+
sol = solve!(solver)
10+
end
11+
12+
function benchmark_mutable(f, u0)
13+
probN = NonlinearProblem{false}(f, u0)
14+
solver = init(probN, NewtonRaphson(), tol = 1e-9)
15+
sol = (reinit!(solver, probN); solve!(solver))
16+
end
17+
18+
function benchmark_scalar(f, u0)
19+
probN = NonlinearProblem{false}(f, u0)
20+
sol = (solve(probN, NewtonRaphson()))
21+
end
22+
23+
function ff(u,p)
24+
u .* u .- 2
25+
end
26+
const cu0 = @SVector[1.0, 1.0]
27+
function sf(u,p)
28+
u * u - 2
29+
end
30+
const csu0 = 1.0
31+
32+
sol = benchmark_immutable(ff, cu0)
33+
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
34+
@test all(sol.u .* sol.u .- 2 .< 1e-9)
35+
sol = benchmark_mutable(ff, cu0)
36+
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
37+
@test all(sol.u .* sol.u .- 2 .< 1e-9)
38+
sol = benchmark_scalar(sf, csu0)
39+
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
40+
@test sol.u * sol.u - 2 < 1e-9
41+
42+
@test (@ballocated benchmark_immutable(ff, cu0)) == 0
43+
@test (@ballocated benchmark_mutable(ff, cu0)) < 200
44+
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
45+
46+
# AD Tests
47+
using ForwardDiff
48+
49+
# Immutable
50+
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]
51+
52+
g = function (p)
53+
probN = NonlinearProblem{false}(f, csu0, p)
54+
sol = solve(probN, NewtonRaphson(), tol = 1e-9)
55+
return sol.u[end]
56+
end
57+
58+
for p in 1.0:0.1:100.0
59+
@test g(p) sqrt(p)
60+
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
61+
end
62+
63+
# Scalar
64+
f, u0 = (u, p) -> u * u - p, 1.0
65+
66+
# NewtonRaphson
67+
g = function (p)
68+
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
69+
sol = solve(probN, NewtonRaphson())
70+
return sol.u
71+
end
72+
73+
@test ForwardDiff.derivative(g, 1.0) 0.5
74+
75+
for p in 1.1:0.1:100.0
76+
@test g(p) sqrt(p)
77+
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
78+
end
79+
80+
u0 = (1.0, 20.0)
81+
# Falsi
82+
g = function (p)
83+
probN = NonlinearProblem{false}(f, typeof(p).(u0), p)
84+
sol = solve(probN, Falsi())
85+
return sol.left
86+
end
87+
88+
for p in 1.1:0.1:100.0
89+
@test g(p) sqrt(p)
90+
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
91+
end
92+
93+
f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
94+
t = (p) -> [sqrt(p[2] / p[1])]
95+
p = [0.9, 50.0]
96+
for alg in [Bisection(), Falsi()]
97+
global g, p
98+
g = function (p)
99+
probN = NonlinearProblem{false}(f, u0, p)
100+
sol = solve(probN, Bisection())
101+
return [sol.left]
102+
end
103+
104+
@test g(p) [sqrt(p[2] / p[1])]
105+
@test ForwardDiff.jacobian(g, p) ForwardDiff.jacobian(t, p)
106+
end
107+
108+
gnewton = function (p)
109+
probN = NonlinearProblem{false}(f, 0.5, p)
110+
sol = solve(probN, NewtonRaphson())
111+
return [sol.u]
112+
end
113+
@test gnewton(p) [sqrt(p[2] / p[1])]
114+
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
115+
116+
# Error Checks
117+
118+
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
119+
probN = NonlinearProblem(f, u0)
120+
121+
@test solve(probN, NewtonRaphson()).u[end] sqrt(2.0)
122+
@test solve(probN, NewtonRaphson(); immutable = false).u[end] sqrt(2.0)
123+
@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] sqrt(2.0)
124+
@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] sqrt(2.0)
125+
126+
for u0 in [1.0, [1, 1.0]]
127+
local f, probN, sol
128+
f = (u, p) -> u .* u .- 2.0
129+
probN = NonlinearProblem(f, u0)
130+
sol = sqrt(2) * u0
131+
132+
@test solve(probN, NewtonRaphson()).u sol
133+
@test solve(probN, NewtonRaphson()).u sol
134+
@test solve(probN, NewtonRaphson(;autodiff=false)).u sol
135+
end
136+
137+
# Bisection Tests
138+
f, u0 = (u, p) -> u .* u .- 2.0, (1.0, 2.0)
139+
probB = NonlinearProblem(f, u0)
140+
141+
# Falsi
142+
solver = init(probB, Falsi())
143+
sol = solve!(solver)
144+
@test sol.left sqrt(2.0)
145+
146+
# this should call the fast scalar overload
147+
@test solve(probB, Bisection()).left sqrt(2.0)
148+
149+
# these should call the iterator version
150+
solver = init(probB, Bisection())
151+
@test solver isa NonlinearSolve.BracketingImmutableSolver
152+
@test solve!(solver).left sqrt(2.0)
153+
154+
# Garuntee Tests for Bisection
155+
f = function (u, p)
156+
if u < 2.0
157+
return u - 2.0
158+
elseif u > 3.0
159+
return u - 3.0
160+
else
161+
return 0.0
162+
end
163+
end
164+
probB = NonlinearProblem(f, (0.0, 4.0))
165+
166+
solver = init(probB, Bisection(;exact_left = true))
167+
sol = solve!(solver)
168+
@test f(sol.left, nothing) < 0.0
169+
@test f(nextfloat(sol.left), nothing) >= 0.0
170+
171+
solver = init(probB, Bisection(;exact_right = true))
172+
sol = solve!(solver)
173+
@test f(sol.right, nothing) > 0.0
174+
@test f(prevfloat(sol.right), nothing) <= 0.0
175+
176+
solver = init(probB, Bisection(;exact_left = true, exact_right = true); immutable = false)
177+
sol = solve!(solver)
178+
@test f(sol.left, nothing) < 0.0
179+
@test f(nextfloat(sol.left), nothing) >= 0.0
180+
@test f(sol.right, nothing) > 0.0
181+
@test f(prevfloat(sol.right), nothing) <= 0.0

0 commit comments

Comments
 (0)