Skip to content

Commit 36d4bb6

Browse files
committed
Start wrapping NLSolvers
1 parent c50c21a commit 36d4bb6

File tree

4 files changed

+73
-3
lines changed

4 files changed

+73
-3
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.5.6"
4+
version = "3.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -36,6 +36,7 @@ FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
3636
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
3737
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
3838
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
39+
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
3940
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
4041
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
4142
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
@@ -48,6 +49,7 @@ NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
4849
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
4950
NonlinearSolveMINPACKExt = "MINPACK"
5051
NonlinearSolveNLsolveExt = "NLsolve"
52+
NonlinearSolveNLSolversExt = "NLSolvers"
5153
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
5254
NonlinearSolveSpeedMappingExt = "SpeedMapping"
5355
NonlinearSolveSymbolicsExt = "Symbolics"
@@ -77,6 +79,7 @@ LinearSolve = "2.21"
7779
MINPACK = "1.2"
7880
MaybeInplace = "0.1.1"
7981
NLsolve = "4.5"
82+
NLSolvers = "0.5"
8083
NaNMath = "1"
8184
NonlinearProblemLibrary = "0.1.2"
8285
OrdinaryDiffEq = "6.63"
@@ -120,6 +123,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
120123
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
121124
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
122125
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
126+
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
123127
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
124128
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
125129
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
@@ -139,4 +143,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
139143
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
140144

141145
[targets]
142-
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations", "Sundials", "ReTestItems", "Reexport", "CUDA"]
146+
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations", "Sundials", "ReTestItems", "Reexport", "CUDA", "NLSolvers"]

ext/NonlinearSolveNLsolversExt.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
module NonlinearSolveNLSolversExt
2+
3+
using NonlinearSolve, NLSolversJL, SciMLBase
4+
5+
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
6+
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0::Bool = false,
7+
termination_condition = nothing, kwargs...) where {StT, ShT}
8+
NonlinearSolve.__test_termination_condition(termination_condition, :NLSolversJL)
9+
10+
if prob.u0 isa Number
11+
error("Scalar Inputs for NLsolversJL is not yet handled.")
12+
end
13+
14+
f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
15+
16+
# if prob.f.jac === nothing && alg.autodiff isa Symbol
17+
# df = OnceDifferentiable(f!, u0, resid; alg.autodiff)
18+
# else
19+
# jac! = NonlinearSolve.__construct_extension_jac(prob, alg, u0, resid; alg.autodiff)
20+
# if prob.f.jac_prototype === nothing
21+
# J = similar(u0, promote_type(eltype(u0), eltype(resid)), length(u0),
22+
# length(resid))
23+
# else
24+
# J = zero(prob.f.jac_prototype)
25+
# end
26+
# df = OnceDifferentiable(f!, jac!, vec(u0), vec(resid), J)
27+
# end
28+
29+
# abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
30+
# show_trace = ShT || alg.show_trace
31+
# store_trace = StT || alg.store_trace
32+
# extended_trace = !(trace_level isa TraceMinimal) || alg.extended_trace
33+
34+
# original = nlsolve(df, vec(u0); ftol = abstol, iterations = maxiters, alg.method,
35+
# store_trace, extended_trace, alg.linesearch, alg.linsolve, alg.factor,
36+
# alg.autoscale, alg.m, alg.beta, show_trace)
37+
38+
# f!(vec(resid), original.zero)
39+
# u = prob.u0 isa Number ? original.zero[1] : reshape(original.zero, size(prob.u0))
40+
# resid = prob.u0 isa Number ? resid[1] : resid
41+
42+
# retcode = original.x_converged || original.f_converged ? ReturnCode.Success :
43+
# ReturnCode.Failure
44+
# stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls,
45+
# original.g_calls, original.iterations)
46+
47+
# return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats)
48+
end
49+
50+
end

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ export NonlinearSolvePolyAlgorithm,
144144
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg
145145

146146
# Extension Algorithms
147-
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL,
147+
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
148148
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
149149

150150
# Advanced Algorithms -- Without Bells and Whistles

src/algorithms/extension_algs.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,22 @@ function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace =
281281
factor, autoscale, m, beta, show_trace)
282282
end
283283

284+
@concrete struct NLSolversJL <: AbstractNonlinearSolveExtensionAlgorithm
285+
method
286+
autodiff
287+
288+
function NLSolversJL(method, autodiff)
289+
if Base.get_extension(@__MODULE__, :NonlinearSolveNLsolversExt) === nothing
290+
error("NLSolversJL requires NLSolvers.jl to be loaded")
291+
end
292+
293+
return new{typeof(method), typeof(autodiff)}(method, autodiff)
294+
end
295+
end
296+
297+
NLSolversJL(method; autodiff = nothing) = NLSolversJL(method, autodiff)
298+
NLSolversJL(; method, autodiff = nothing) = NLSolversJL(method, autodiff)
299+
284300
"""
285301
SpeedMappingJL(; σ_min = 0.0, stabilize::Bool = false, check_obj::Bool = false,
286302
orders::Vector{Int} = [3, 3, 2], time_limit::Real = 1000)

0 commit comments

Comments
 (0)