Skip to content

Commit 364d7f4

Browse files
Merge pull request #237 from avik-pal/ap/ext
Rename wrappers to be consistent with other SciML packages
2 parents 1a0e5ee + 42d741b commit 364d7f4

File tree

6 files changed

+33
-39
lines changed

6 files changed

+33
-39
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "2.3.0"
4+
version = "2.4.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/NonlinearSolveFastLevenbergMarquardtExt.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@ using ArrayInterface, NonlinearSolve, SciMLBase
44
import ConcreteStructs: @concrete
55
import FastLevenbergMarquardt as FastLM
66

7-
NonlinearSolve.extension_loaded(::Val{:FastLevenbergMarquardt}) = true
8-
9-
function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {linsolve}
7+
function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
108
if linsolve == :cholesky
119
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
1210
elseif linsolve == :qr
@@ -16,7 +14,7 @@ function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {lin
1614
end
1715
end
1816

19-
@concrete struct FastLMCache
17+
@concrete struct FastLevenbergMarquardtJLCache
2018
f!
2119
J!
2220
prob
@@ -34,7 +32,7 @@ end
3432
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))
3533

3634
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
37-
alg::FastLevenbergMarquardtSolver, args...; abstol = 1e-8, reltol = 1e-8,
35+
alg::FastLevenbergMarquardtJL, args...; abstol = 1e-8, reltol = 1e-8,
3836
verbose = false, maxiters = 1000, kwargs...)
3937
iip = SciMLBase.isinplace(prob)
4038

@@ -52,13 +50,13 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
5250
solver = _fast_lm_solver(alg, prob.u0)
5351
LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J)
5452

55-
return FastLMCache(f!, J!, prob, alg, LM, solver,
53+
return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver,
5654
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
5755
alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor,
5856
alg.maxfactor, kwargs...))
5957
end
6058

61-
function SciMLBase.solve!(cache::FastLMCache)
59+
function SciMLBase.solve!(cache::FastLevenbergMarquardtJLCache)
6260
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
6361
cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...)
6462
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)

ext/NonlinearSolveLeastSquaresOptimExt.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@ using NonlinearSolve, SciMLBase
44
import ConcreteStructs: @concrete
55
import LeastSquaresOptim as LSO
66

7-
NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true
8-
9-
function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
7+
function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve}
108
ls = linsolve == :qr ? LSO.QR() :
119
(linsolve == :cholesky ? LSO.Cholesky() :
1210
(linsolve == :lsmr ? LSO.LSMR() : nothing))
@@ -19,7 +17,7 @@ function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
1917
end
2018
end
2119

22-
@concrete struct LeastSquaresOptimCache
20+
@concrete struct LeastSquaresOptimJLCache
2321
prob
2422
alg
2523
allocated_prob
@@ -34,8 +32,8 @@ end
3432
(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p)
3533
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))
3634

37-
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...;
38-
abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
35+
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LeastSquaresOptimJL,
36+
args...; abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
3937
iip = SciMLBase.isinplace(prob)
4038

4139
f! = FunctionWrapper{iip}(prob.f, prob.p)
@@ -49,12 +47,12 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver
4947
J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype))
5048
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))
5149

52-
return LeastSquaresOptimCache(prob, alg, allocated_prob,
50+
return LeastSquaresOptimJLCache(prob, alg, allocated_prob,
5351
(; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose,
5452
kwargs...))
5553
end
5654

57-
function SciMLBase.solve!(cache::LeastSquaresOptimCache)
55+
function SciMLBase.solve!(cache::LeastSquaresOptimJLCache)
5856
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...)
5957
maxiters = cache.kwargs[:iterations]
6058
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success :

src/NonlinearSolve.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm
3030

3131
abstract type AbstractNonlinearSolveCache{iip} end
3232

33-
extension_loaded(::Val) = false
34-
3533
isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
3634

3735
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
@@ -62,7 +60,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
6260
end
6361

6462
include("utils.jl")
65-
include("algorithms.jl")
63+
include("extension_algs.jl")
6664
include("linesearch.jl")
6765
include("raphson.jl")
6866
include("trustRegion.jl")
@@ -96,7 +94,7 @@ end
9694
export RadiusUpdateSchemes
9795

9896
export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
99-
export LSOptimSolver, FastLevenbergMarquardtSolver
97+
export LeastSquaresOptimJL, FastLevenbergMarquardtJL
10098

10199
export LineSearch
102100

src/algorithms.jl renamed to src/extension_algs.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
# Define Algorithms extended via extensions
1+
# This file only include the algorithm struct to be exported by LinearSolve.jl. The main
2+
# functionality is implemented as package extensions
23
"""
3-
LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
4+
LeastSquaresOptimJL(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
45
5-
Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl) for solving
6-
`NonlinearLeastSquaresProblem`.
6+
Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl)
7+
for solving `NonlinearLeastSquaresProblem`.
78
89
## Arguments:
910
@@ -16,25 +17,24 @@ Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquare
1617
!!! note
1718
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
1819
"""
19-
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
20+
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
2021
autodiff::Symbol
2122
end
2223

23-
function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
24+
function LeastSquaresOptimJL(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
2425
@assert alg in (:lm, :dogleg)
2526
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
2627
@assert autodiff in (:central, :forward)
2728

28-
if !extension_loaded(Val(:LeastSquaresOptim))
29-
@warn "LeastSquaresOptim.jl is not loaded! It needs to be explicitly loaded \
30-
before `solve(prob, LSOptimSolver())` is called."
29+
if Base.get_extension(@__MODULE__, :NonlinearSolveLeastSquaresOptimExt) === nothing
30+
error("LeastSquaresOptimJL requires LeastSquaresOptim.jl to be loaded")
3131
end
3232

33-
return LSOptimSolver{alg, linsolve}(autodiff)
33+
return LeastSquaresOptimJL{alg, linsolve}(autodiff)
3434
end
3535

3636
"""
37-
FastLevenbergMarquardtSolver(linsolve = :cholesky)
37+
FastLevenbergMarquardtJL(linsolve = :cholesky)
3838
3939
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
4040
`NonlinearLeastSquaresProblem`.
@@ -53,7 +53,7 @@ Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenberg
5353
!!! note
5454
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
5555
"""
56-
@concrete struct FastLevenbergMarquardtSolver{linsolve} <: AbstractNonlinearSolveAlgorithm
56+
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm
5757
factor
5858
factoraccept
5959
factorreject
@@ -64,17 +64,16 @@ Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenberg
6464
maxfactor
6565
end
6666

67-
function FastLevenbergMarquardtSolver(linsolve::Symbol = :cholesky; factor = 1e-6,
67+
function FastLevenbergMarquardtJL(linsolve::Symbol = :cholesky; factor = 1e-6,
6868
factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt,
6969
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32)
7070
@assert linsolve in (:qr, :cholesky)
7171
@assert factorupdate in (:marquardt, :nielson)
7272

73-
if !extension_loaded(Val(:FastLevenbergMarquardt))
74-
@warn "FastLevenbergMarquardt.jl is not loaded! It needs to be explicitly loaded \
75-
before `solve(prob, FastLevenbergMarquardtSolver())` is called."
73+
if Base.get_extension(@__MODULE__, :NonlinearSolveFastLevenbergMarquardtExt) === nothing
74+
error("LeastSquaresOptimJL requires FastLevenbergMarquardt.jl to be loaded")
7675
end
7776

78-
return FastLevenbergMarquardtSolver{linsolve}(factor, factoraccept, factorreject,
77+
return FastLevenbergMarquardtJL{linsolve}(factor, factoraccept, factorreject,
7978
factorupdate, minscale, maxscale, minfactor, maxfactor)
8079
end

test/nonlinear_least_squares.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
2727
resid_prototype = zero(y_target)), θ_init, x)
2828

2929
nlls_problems = [prob_oop, prob_iip]
30-
solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)]
30+
solvers = [GaussNewton(), LevenbergMarquardt(), LeastSquaresOptimJL(:lm),
31+
LeastSquaresOptimJL(:dogleg)]
3132

3233
for prob in nlls_problems, solver in solvers
3334
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@@ -43,7 +44,7 @@ end
4344
prob = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
4445
resid_prototype = zero(y_target), jac = jac!), θ_init, x)
4546

46-
solvers = [FastLevenbergMarquardtSolver(:cholesky), FastLevenbergMarquardtSolver(:qr)]
47+
solvers = [FastLevenbergMarquardtJL(:cholesky), FastLevenbergMarquardtJL(:qr)]
4748

4849
for solver in solvers
4950
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)

0 commit comments

Comments
 (0)