Skip to content

Commit 6f0d159

Browse files
specialize Newton on static arrays
```julia using StaticArrays using NonlinearSolve function f(x, _) F1 = (x[1] + 3) * (x[2]^3 - 7) + 18 F2 = sin(x[2] * exp(x[1]) - 1) SA[F1,F2] end function f!(F, x) F[1] = (x[1] + 3) * (x[2]^3 - 7) + 18 F[2] = sin(x[2] * exp(x[1]) - 1) nothing end x0 = [0.1; 1.2] x0s = SVector{size(x0)...}(x0) x0m = MVector{size(x0)...}(x0) prob = NonlinearProblem{false}(f, x0s) using NLsolve, BenchmarkTools @Btime sol = solve(prob,NewtonRaphson()); # 320.000 ns (2 allocations: 128 bytes) @Btime nlsolve(f!, x0m); # 1.460 μs (35 allocations: 1.36 KiB) ```
1 parent eb6b2e6 commit 6f0d159

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

src/scalar.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
1+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
22
f = Base.Fix2(prob.f, prob.p)
33
x = float(prob.u0)
44
fx = float(prob.u0)
55
T = typeof(x)
6-
atol = xatol !== nothing ? xatol : oneunit(T) * (eps(one(T)))^(4//5)
7-
rtol = xrtol !== nothing ? xrtol : eps(one(T))^(4//5)
6+
atol = xatol !== nothing ? xatol : oneunit(eltype(T)) * (eps(one(eltype(T))))^(4//5)
7+
rtol = xrtol !== nothing ? xrtol : eps(one(eltype(T)))^(4//5)
8+
9+
if typeof(x) <: Number
10+
xo = oftype(one(eltype(x)), Inf)
11+
else
12+
xo = map(x->oftype(one(eltype(x)), Inf),x)
13+
end
814

9-
xo = oftype(x, Inf)
1015
for i in 1:maxiters
1116
if alg_autodiff(alg)
1217
fx, dfx = value_derivative(f, x)
@@ -52,7 +57,7 @@ end
5257
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5358
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
5459
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
55-
60+
5661
end
5762
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5863
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ function value_derivative(f::F, x::R) where {F,R}
224224
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
225225
end
226226

227+
# Todo: improve this dispatch
228+
value_derivative(f::F, x::SVector) where F = f(x),ForwardDiff.jacobian(f, x)
229+
227230
value(x) = x
228231
value(x::Dual) = ForwardDiff.value(x)
229232
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

0 commit comments

Comments
 (0)