Skip to content

Commit b731362

Browse files
Merge pull request #379 from SciML/ap/nlsolvers
Wrapper for NLSolvers.jl
2 parents d383660 + 9a6bdfb commit b731362

File tree

8 files changed

+170
-10
lines changed

8 files changed

+170
-10
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.5.6"
4+
version = "3.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -36,6 +36,7 @@ FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
3636
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
3737
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
3838
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
39+
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
3940
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
4041
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
4142
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
@@ -48,6 +49,7 @@ NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
4849
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
4950
NonlinearSolveMINPACKExt = "MINPACK"
5051
NonlinearSolveNLsolveExt = "NLsolve"
52+
NonlinearSolveNLSolversExt = "NLSolvers"
5153
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
5254
NonlinearSolveSpeedMappingExt = "SpeedMapping"
5355
NonlinearSolveSymbolicsExt = "Symbolics"
@@ -77,6 +79,7 @@ LinearSolve = "2.21"
7779
MINPACK = "1.2"
7880
MaybeInplace = "0.1.1"
7981
NLsolve = "4.5"
82+
NLSolvers = "0.5"
8083
NaNMath = "1"
8184
NonlinearProblemLibrary = "0.1.2"
8285
OrdinaryDiffEq = "6.63"
@@ -120,6 +123,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
120123
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
121124
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
122125
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
126+
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
123127
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
124128
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
125129
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
@@ -139,4 +143,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
139143
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
140144

141145
[targets]
142-
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations", "Sundials", "ReTestItems", "Reexport", "CUDA"]
146+
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations", "Sundials", "ReTestItems", "Reexport", "CUDA", "NLSolvers"]

docs/src/api/nlsolve.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using these solvers:
77
```julia
88
using Pkg
99
Pkg.add("NLsolve")
10-
using NLSolve, NonlinearSolve
10+
using NLsolve, NonlinearSolve
1111
```
1212

1313
## Solver API

docs/src/api/nlsolvers.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# NLSolvers.jl
2+
3+
This is a extension for importing solvers from NLSolvers.jl into the SciML interface. Note
4+
that these solvers do not come by default, and thus one needs to install the package before
5+
using these solvers:
6+
7+
```julia
8+
using Pkg
9+
Pkg.add("NLSolvers")
10+
using NLSolvers, NonlinearSolve
11+
```
12+
13+
## Solver API
14+
15+
```@docs
16+
NLSolversJL
17+
```

docs/src/solvers/nonlinear_system_solvers.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,12 @@ SIAMFANLEquations.jl is a wrapper for the methods in the SIAMFANLEquations.jl li
168168
Other solvers listed in [Fixed Point Solvers](@ref fixed_point_methods_full_list),
169169
[FastLevenbergMarquardt.jl](@ref fastlm_wrapper_summary) and
170170
[LeastSquaresOptim.jl](@ref lso_wrapper_summary) can also solve nonlinear systems.
171+
172+
### NLSolvers.jl
173+
174+
This is a wrapper package for importing solvers from NLSolvers.jl into the SciML interface.
175+
176+
- [`NLSolversJL()`](@ref): A wrapper for
177+
[NLSolvers.jl](https://github.com/JuliaNLSolvers/NLSolvers.jl)
178+
179+
For a list of possible solvers see the [NLSolvers.jl documentation](https://julianlsolvers.github.io/NLSolvers.jl/)

ext/NonlinearSolveNLSolversExt.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
module NonlinearSolveNLSolversExt
2+
3+
using ADTypes, FastClosures, NonlinearSolve, NLSolvers, SciMLBase, LinearAlgebra
4+
using FiniteDiff, ForwardDiff
5+
6+
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
7+
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0::Bool = false,
8+
termination_condition = nothing, kwargs...)
9+
NonlinearSolve.__test_termination_condition(termination_condition, :NLSolversJL)
10+
11+
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0))
12+
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(prob.u0))
13+
14+
options = NEqOptions(; maxiter = maxiters, f_abstol = abstol, f_reltol = reltol)
15+
16+
if prob.u0 isa Number
17+
f_scalar = @closure x -> prob.f(x, prob.p)
18+
19+
if alg.autodiff === nothing
20+
if ForwardDiff.can_dual(typeof(prob.u0))
21+
autodiff_concrete = :forwarddiff
22+
else
23+
autodiff_concrete = :finitediff
24+
end
25+
else
26+
if alg.autodiff isa AutoForwardDiff || alg.autodiff isa AutoPolyesterForwardDiff
27+
autodiff_concrete = :forwarddiff
28+
elseif alg.autodiff isa AutoFiniteDiff
29+
autodiff_concrete = :finitediff
30+
else
31+
error("Only ForwardDiff or FiniteDiff autodiff is supported.")
32+
end
33+
end
34+
35+
if autodiff_concrete === :forwarddiff
36+
fj_scalar = @closure (Jx, x) -> begin
37+
T = typeof(NonlinearSolve.NonlinearSolveTag())
38+
x_dual = ForwardDiff.Dual{T}(x, one(x))
39+
y = prob.f(x_dual, prob.p)
40+
return ForwardDiff.value(y), ForwardDiff.extract_derivative(T, y)
41+
end
42+
else
43+
fj_scalar = @closure (Jx, x) -> begin
44+
_f = Base.Fix2(prob.f, prob.p)
45+
return _f(x), FiniteDiff.finite_difference_derivative(_f, x)
46+
end
47+
end
48+
49+
prob_obj = NLSolvers.ScalarObjective(; f = f_scalar, fg = fj_scalar)
50+
prob_nlsolver = NEqProblem(prob_obj; inplace = false)
51+
res = NLSolvers.solve(prob_nlsolver, prob.u0, alg.method, options)
52+
53+
retcode = ifelse(norm(res.info.best_residual, Inf) abstol, ReturnCode.Success,
54+
ReturnCode.MaxIters)
55+
stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter)
56+
57+
return SciMLBase.build_solution(prob, alg, res.info.solution,
58+
res.info.best_residual; retcode, original = res, stats)
59+
end
60+
61+
f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
62+
63+
jac! = NonlinearSolve.__construct_extension_jac(prob, alg, u0, resid; alg.autodiff)
64+
65+
FJ_vector! = @closure (Fx, Jx, x) -> begin
66+
f!(Fx, x)
67+
jac!(Jx, x)
68+
return Fx, Jx
69+
end
70+
71+
prob_obj = NLSolvers.VectorObjective(; F = f!, FJ = FJ_vector!)
72+
prob_nlsolver = NEqProblem(prob_obj)
73+
74+
res = NLSolvers.solve(prob_nlsolver, u0, alg.method, options)
75+
76+
retcode = ifelse(norm(res.info.best_residual, Inf) abstol, ReturnCode.Success,
77+
ReturnCode.MaxIters)
78+
stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter)
79+
80+
return SciMLBase.build_solution(prob, alg, res.info.solution,
81+
res.info.best_residual; retcode, original = res, stats)
82+
end
83+
84+
end

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPoly
144144
FastShortcutNLLSPolyalg
145145

146146
# Extension Algorithms
147-
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL,
147+
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
148148
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
149149

150150
# Advanced Algorithms -- Without Bells and Whistles

src/algorithms/extension_algs.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,38 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace =
279279
linsolve, factor, autoscale, m, beta, show_trace)
280280
end
281281

282+
"""
283+
NLSolversJL(method; autodiff = nothing)
284+
NLSolversJL(; method, autodiff = nothing)
285+
286+
Wrapper over NLSolvers.jl Nonlinear Equation Solvers. We automatically construct the
287+
jacobian function and supply it to the solver.
288+
289+
### Arguments
290+
291+
- `method`: the choice of method for solving the nonlinear system. See the documentation
292+
for NLSolvers.jl for more information.
293+
- `autodiff`: the choice of method for generating the Jacobian. Defaults to `nothing`
294+
which means that a default is selected according to the problem specification. Can be
295+
any valid ADTypes.jl autodiff type (conditional on that backend being supported in
296+
NonlinearSolve.jl).
297+
"""
298+
struct NLSolversJL{M, AD} <: AbstractNonlinearSolveExtensionAlgorithm
299+
method::M
300+
autodiff::AD
301+
302+
function NLSolversJL(method, autodiff)
303+
if Base.get_extension(@__MODULE__, :NonlinearSolveNLSolversExt) === nothing
304+
error("NLSolversJL requires NLSolvers.jl to be loaded")
305+
end
306+
307+
return new{typeof(method), typeof(autodiff)}(method, autodiff)
308+
end
309+
end
310+
311+
NLSolversJL(method; autodiff = nothing) = NLSolversJL(method, autodiff)
312+
NLSolversJL(; method, autodiff = nothing) = NLSolversJL(method, autodiff)
313+
282314
"""
283315
SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false,
284316
orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000)

test/wrappers/rootfind_tests.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
@testsetup module WrapperRootfindImports
22
using Reexport
3-
@reexport using LinearAlgebra, NLsolve, SIAMFANLEquations, MINPACK
3+
@reexport using LinearAlgebra
4+
import NLSolvers, NLsolve, SIAMFANLEquations, MINPACK
5+
6+
export NLSolvers
47
end
58

69
@testitem "Steady State Problems" setup=[WrapperRootfindImports] begin
@@ -12,7 +15,8 @@ end
1215
u0 = zeros(2)
1316
prob_iip = SteadyStateProblem(f_iip, u0)
1417

15-
for alg in [NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
18+
for alg in [
19+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
1620
sol = solve(prob_iip, alg)
1721
@test SciMLBase.successful_retcode(sol.retcode)
1822
@test maximum(abs, sol.resid) < 1e-6
@@ -23,7 +27,8 @@ end
2327
u0 = zeros(2)
2428
prob_oop = SteadyStateProblem(f_oop, u0)
2529

26-
for alg in [NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
30+
for alg in [
31+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
2732
sol = solve(prob_oop, alg)
2833
@test SciMLBase.successful_retcode(sol.retcode)
2934
@test maximum(abs, sol.resid) < 1e-6
@@ -39,7 +44,8 @@ end
3944
u0 = zeros(2)
4045
prob_iip = NonlinearProblem{true}(f_iip, u0)
4146

42-
for alg in [NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
47+
for alg in [
48+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
4349
local sol
4450
sol = solve(prob_iip, alg)
4551
@test SciMLBase.successful_retcode(sol.retcode)
@@ -50,7 +56,8 @@ end
5056
f_oop(u, p) = [2 - 2u[1], u[1] - 4u[2]]
5157
u0 = zeros(2)
5258
prob_oop = NonlinearProblem{false}(f_oop, u0)
53-
for alg in [NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
59+
for alg in [
60+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
5461
local sol
5562
sol = solve(prob_oop, alg)
5663
@test SciMLBase.successful_retcode(sol.retcode)
@@ -61,7 +68,10 @@ end
6168
f_tol(u, p) = u^2 - 2
6269
prob_tol = NonlinearProblem(f_tol, 1.0)
6370
for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-15],
64-
alg in [NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL(; method = :newton),
71+
alg in [
72+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
73+
NLsolveJL(),
74+
CMINPACK(), SIAMFANLEquationsJL(; method = :newton),
6575
SIAMFANLEquationsJL(; method = :pseudotransient),
6676
SIAMFANLEquationsJL(; method = :secant)]
6777

@@ -107,6 +117,10 @@ end
107117

108118
sol = solve(ProbN, NLsolveJL(); abstol = 1e-8)
109119
@test maximum(abs, sol.resid) < 1e-6
120+
sol = solve(ProbN,
121+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking()));
122+
abstol = 1e-8)
123+
@test maximum(abs, sol.resid) < 1e-6
110124
sol = solve(ProbN, SIAMFANLEquationsJL(; method = :newton); abstol = 1e-8)
111125
@test maximum(abs, sol.resid) < 1e-6
112126
sol = solve(ProbN, SIAMFANLEquationsJL(; method = :pseudotransient); abstol = 1e-8)

0 commit comments

Comments
 (0)