Skip to content

Commit 9a6bdfb

Browse files
committed
Finish wrapping NLSolvers.jl
1 parent 36d4bb6 commit 9a6bdfb

File tree

7 files changed

+151
-61
lines changed

7 files changed

+151
-61
lines changed

docs/src/api/nlsolve.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using these solvers:
77
```julia
88
using Pkg
99
Pkg.add("NLsolve")
10-
using NLSolve, NonlinearSolve
10+
using NLsolve, NonlinearSolve
1111
```
1212

1313
## Solver API

docs/src/api/nlsolvers.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# NLSolvers.jl
2+
3+
This is a extension for importing solvers from NLSolvers.jl into the SciML interface. Note
4+
that these solvers do not come by default, and thus one needs to install the package before
5+
using these solvers:
6+
7+
```julia
8+
using Pkg
9+
Pkg.add("NLSolvers")
10+
using NLSolvers, NonlinearSolve
11+
```
12+
13+
## Solver API
14+
15+
```@docs
16+
NLSolversJL
17+
```

docs/src/solvers/nonlinear_system_solvers.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,12 @@ SIAMFANLEquations.jl is a wrapper for the methods in the SIAMFANLEquations.jl li
168168
Other solvers listed in [Fixed Point Solvers](@ref fixed_point_methods_full_list),
169169
[FastLevenbergMarquardt.jl](@ref fastlm_wrapper_summary) and
170170
[LeastSquaresOptim.jl](@ref lso_wrapper_summary) can also solve nonlinear systems.
171+
172+
### NLSolvers.jl
173+
174+
This is a wrapper package for importing solvers from NLSolvers.jl into the SciML interface.
175+
176+
- [`NLSolversJL()`](@ref): A wrapper for
177+
[NLSolvers.jl](https://github.com/JuliaNLSolvers/NLSolvers.jl)
178+
179+
For a list of possible solvers see the [NLSolvers.jl documentation](https://julianlsolvers.github.io/NLSolvers.jl/)

ext/NonlinearSolveNLSolversExt.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
module NonlinearSolveNLSolversExt
2+
3+
using ADTypes, FastClosures, NonlinearSolve, NLSolvers, SciMLBase, LinearAlgebra
4+
using FiniteDiff, ForwardDiff
5+
6+
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
7+
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0::Bool = false,
8+
termination_condition = nothing, kwargs...)
9+
NonlinearSolve.__test_termination_condition(termination_condition, :NLSolversJL)
10+
11+
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0))
12+
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(prob.u0))
13+
14+
options = NEqOptions(; maxiter = maxiters, f_abstol = abstol, f_reltol = reltol)
15+
16+
if prob.u0 isa Number
17+
f_scalar = @closure x -> prob.f(x, prob.p)
18+
19+
if alg.autodiff === nothing
20+
if ForwardDiff.can_dual(typeof(prob.u0))
21+
autodiff_concrete = :forwarddiff
22+
else
23+
autodiff_concrete = :finitediff
24+
end
25+
else
26+
if alg.autodiff isa AutoForwardDiff || alg.autodiff isa AutoPolyesterForwardDiff
27+
autodiff_concrete = :forwarddiff
28+
elseif alg.autodiff isa AutoFiniteDiff
29+
autodiff_concrete = :finitediff
30+
else
31+
error("Only ForwardDiff or FiniteDiff autodiff is supported.")
32+
end
33+
end
34+
35+
if autodiff_concrete === :forwarddiff
36+
fj_scalar = @closure (Jx, x) -> begin
37+
T = typeof(NonlinearSolve.NonlinearSolveTag())
38+
x_dual = ForwardDiff.Dual{T}(x, one(x))
39+
y = prob.f(x_dual, prob.p)
40+
return ForwardDiff.value(y), ForwardDiff.extract_derivative(T, y)
41+
end
42+
else
43+
fj_scalar = @closure (Jx, x) -> begin
44+
_f = Base.Fix2(prob.f, prob.p)
45+
return _f(x), FiniteDiff.finite_difference_derivative(_f, x)
46+
end
47+
end
48+
49+
prob_obj = NLSolvers.ScalarObjective(; f = f_scalar, fg = fj_scalar)
50+
prob_nlsolver = NEqProblem(prob_obj; inplace = false)
51+
res = NLSolvers.solve(prob_nlsolver, prob.u0, alg.method, options)
52+
53+
retcode = ifelse(norm(res.info.best_residual, Inf) abstol, ReturnCode.Success,
54+
ReturnCode.MaxIters)
55+
stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter)
56+
57+
return SciMLBase.build_solution(prob, alg, res.info.solution,
58+
res.info.best_residual; retcode, original = res, stats)
59+
end
60+
61+
f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
62+
63+
jac! = NonlinearSolve.__construct_extension_jac(prob, alg, u0, resid; alg.autodiff)
64+
65+
FJ_vector! = @closure (Fx, Jx, x) -> begin
66+
f!(Fx, x)
67+
jac!(Jx, x)
68+
return Fx, Jx
69+
end
70+
71+
prob_obj = NLSolvers.VectorObjective(; F = f!, FJ = FJ_vector!)
72+
prob_nlsolver = NEqProblem(prob_obj)
73+
74+
res = NLSolvers.solve(prob_nlsolver, u0, alg.method, options)
75+
76+
retcode = ifelse(norm(res.info.best_residual, Inf) abstol, ReturnCode.Success,
77+
ReturnCode.MaxIters)
78+
stats = SciMLBase.NLStats(-1, -1, -1, -1, res.info.iter)
79+
80+
return SciMLBase.build_solution(prob, alg, res.info.solution,
81+
res.info.best_residual; retcode, original = res, stats)
82+
end
83+
84+
end

ext/NonlinearSolveNLsolversExt.jl

Lines changed: 0 additions & 50 deletions
This file was deleted.

src/algorithms/extension_algs.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,28 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace =
281281
factor, autoscale, m, beta, show_trace)
282282
end
283283

284-
@concrete struct NLSolversJL <: AbstractNonlinearSolveExtensionAlgorithm
285-
method
286-
autodiff
284+
"""
285+
NLSolversJL(method; autodiff = nothing)
286+
NLSolversJL(; method, autodiff = nothing)
287+
288+
Wrapper over NLSolvers.jl Nonlinear Equation Solvers. We automatically construct the
289+
jacobian function and supply it to the solver.
290+
291+
### Arguments
292+
293+
- `method`: the choice of method for solving the nonlinear system. See the documentation
294+
for NLSolvers.jl for more information.
295+
- `autodiff`: the choice of method for generating the Jacobian. Defaults to `nothing`
296+
which means that a default is selected according to the problem specification. Can be
297+
any valid ADTypes.jl autodiff type (conditional on that backend being supported in
298+
NonlinearSolve.jl).
299+
"""
300+
struct NLSolversJL{M, AD} <: AbstractNonlinearSolveExtensionAlgorithm
301+
method::M
302+
autodiff::AD
287303

288304
function NLSolversJL(method, autodiff)
289-
if Base.get_extension(@__MODULE__, :NonlinearSolveNLsolversExt) === nothing
305+
if Base.get_extension(@__MODULE__, :NonlinearSolveNLSolversExt) === nothing
290306
error("NLSolversJL requires NLSolvers.jl to be loaded")
291307
end
292308

test/wrappers/rootfind_tests.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
@testsetup module WrapperRootfindImports
22
using Reexport
3-
@reexport using LinearAlgebra, NLsolve, SIAMFANLEquations, MINPACK
3+
@reexport using LinearAlgebra
4+
import NLSolvers, NLsolve, SIAMFANLEquations, MINPACK
5+
6+
export NLSolvers
47
end
58

69
@testitem "Steady State Problems" setup=[WrapperRootfindImports] begin
@@ -12,7 +15,8 @@ end
1215
u0 = zeros(2)
1316
prob_iip = SteadyStateProblem(f_iip, u0)
1417

15-
for alg in [NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
18+
for alg in [
19+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
1620
sol = solve(prob_iip, alg)
1721
@test SciMLBase.successful_retcode(sol.retcode)
1822
@test maximum(abs, sol.resid) < 1e-6
@@ -23,7 +27,8 @@ end
2327
u0 = zeros(2)
2428
prob_oop = SteadyStateProblem(f_oop, u0)
2529

26-
for alg in [NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
30+
for alg in [
31+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
2732
sol = solve(prob_oop, alg)
2833
@test SciMLBase.successful_retcode(sol.retcode)
2934
@test maximum(abs, sol.resid) < 1e-6
@@ -39,7 +44,8 @@ end
3944
u0 = zeros(2)
4045
prob_iip = NonlinearProblem{true}(f_iip, u0)
4146

42-
for alg in [NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
47+
for alg in [
48+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
4349
local sol
4450
sol = solve(prob_iip, alg)
4551
@test SciMLBase.successful_retcode(sol.retcode)
@@ -50,7 +56,8 @@ end
5056
f_oop(u, p) = [2 - 2u[1], u[1] - 4u[2]]
5157
u0 = zeros(2)
5258
prob_oop = NonlinearProblem{false}(f_oop, u0)
53-
for alg in [NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
59+
for alg in [
60+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())), NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
5461
local sol
5562
sol = solve(prob_oop, alg)
5663
@test SciMLBase.successful_retcode(sol.retcode)
@@ -61,7 +68,10 @@ end
6168
f_tol(u, p) = u^2 - 2
6269
prob_tol = NonlinearProblem(f_tol, 1.0)
6370
for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-15],
64-
alg in [NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL(; method = :newton),
71+
alg in [
72+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
73+
NLsolveJL(),
74+
CMINPACK(), SIAMFANLEquationsJL(; method = :newton),
6575
SIAMFANLEquationsJL(; method = :pseudotransient),
6676
SIAMFANLEquationsJL(; method = :secant)]
6777

@@ -107,6 +117,10 @@ end
107117

108118
sol = solve(ProbN, NLsolveJL(); abstol = 1e-8)
109119
@test maximum(abs, sol.resid) < 1e-6
120+
sol = solve(ProbN,
121+
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking()));
122+
abstol = 1e-8)
123+
@test maximum(abs, sol.resid) < 1e-6
110124
sol = solve(ProbN, SIAMFANLEquationsJL(; method = :newton); abstol = 1e-8)
111125
@test maximum(abs, sol.resid) < 1e-6
112126
sol = solve(ProbN, SIAMFANLEquationsJL(; method = :pseudotransient); abstol = 1e-8)

0 commit comments

Comments
 (0)