Skip to content

Commit 1a0e5ee

Browse files
Merge pull request #236 from avik-pal/ap/lsoptim
Improving NLS Solvers
2 parents a6af39c + 1c19fa7 commit 1a0e5ee

10 files changed

+305
-47
lines changed

Project.toml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "2.2.1"
4+
version = "2.3.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -24,15 +24,25 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
2424
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2525
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2626

27+
[weakdeps]
28+
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
29+
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
30+
31+
[extensions]
32+
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
33+
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
34+
2735
[compat]
2836
ADTypes = "0.2"
2937
ArrayInterface = "6.0.24, 7"
3038
ConcreteStructs = "0.2"
3139
DiffEqBase = "6.130"
3240
EnumX = "1"
3341
Enzyme = "0.11"
42+
FastLevenbergMarquardt = "0.1"
3443
FiniteDiff = "2"
3544
ForwardDiff = "0.10.3"
45+
LeastSquaresOptim = "0.8"
3646
LineSearches = "7"
3747
LinearSolve = "2"
3848
NonlinearProblemLibrary = "0.1"
@@ -50,7 +60,9 @@ julia = "1.9"
5060
[extras]
5161
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
5262
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
63+
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
5364
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
65+
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
5466
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5567
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
5668
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
@@ -64,4 +76,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6476
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6577

6678
[targets]
67-
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary"]
79+
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt"]
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
module NonlinearSolveFastLevenbergMarquardtExt

using ArrayInterface, NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import FastLevenbergMarquardt as FastLM

# Report the backend as available. The `FastLevenbergMarquardtSolver` constructor consults
# `extension_loaded` to decide whether to warn that the package has not been loaded.
NonlinearSolve.extension_loaded(::Val{:FastLevenbergMarquardt}) = true

# Construct the FastLevenbergMarquardt linear solver selected by the `linsolve` type
# parameter of the algorithm (`:cholesky` or `:qr`), using `x` for element type / size.
# Throws `ArgumentError` for any other symbol.
function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {linsolve}
    if linsolve == :cholesky
        return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
    elseif linsolve == :qr
        return FastLM.QRSolver(eltype(x), length(x))
    else
        throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))
    end
end

# Everything `solve!` needs to run `FastLM.lmsolve!`:
#   f!, J!       -- residual / Jacobian callables in the `(out, x, p)` form
#   prob, alg    -- the original problem and algorithm (used to build the solution)
#   lmworkspace  -- the preallocated `FastLM.LMWorkspace`
#   solver       -- the linear solver returned by `_fast_lm_solver`
#   kwargs       -- NamedTuple of keyword arguments forwarded to `FastLM.lmsolve!`
@concrete struct FastLMCache
    f!
    J!
    prob
    alg
    lmworkspace
    solver
    kwargs
end

# Adapter that exposes both in-place (`iip = true`) and out-of-place (`iip = false`)
# user functions through a single mutating `(out, x, p)` call signature.
@concrete struct InplaceFunction{iip} <: Function
    f
end

(f::InplaceFunction{true})(fx, x, p) = f.f(fx, x, p)
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))

# Build a `FastLMCache` for a `NonlinearLeastSquaresProblem`.
#
# `abstol`, `reltol` and `maxiters` are forwarded to FastLevenbergMarquardt as `xtol`,
# `ftol` and `maxit`; the tuning fields stored on `alg` and any extra `kwargs` are
# forwarded unchanged. `verbose` is accepted but not forwarded.
# An analytic Jacobian (`prob.f.jac`) is mandatory; an `AssertionError` is raised otherwise.
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
        alg::FastLevenbergMarquardtSolver, args...; abstol = 1e-8, reltol = 1e-8,
        verbose = false, maxiters = 1000, kwargs...)
    iip = SciMLBase.isinplace(prob)

    @assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!"

    f! = InplaceFunction{iip}(prob.f)
    J! = InplaceFunction{iip}(prob.f.jac)

    # Without a user-supplied prototype: evaluate `f` once when out-of-place, otherwise
    # fall back to a zeroed array shaped like `u0`. `zero(u0)` is used here because
    # `zeros` expects an element type / dimensions and has no method for an array.
    resid_prototype = prob.f.resid_prototype === nothing ?
                      (!iip ? prob.f(prob.u0, prob.p) : zero(prob.u0)) :
                      prob.f.resid_prototype

    J = similar(prob.u0, length(resid_prototype), length(prob.u0))

    solver = _fast_lm_solver(alg, prob.u0)
    LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J)

    return FastLMCache(f!, J!, prob, alg, LM, solver,
        (; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
            alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor,
            alg.maxfactor, kwargs...))
end

# Run `FastLM.lmsolve!` on the cached workspace and wrap the outcome in a SciML solution.
# `info == 1` is reported as `Success`, `info == -1` as `MaxIters`, and anything else as
# `Default`. The raw tuple returned by `lmsolve!` is kept in the solution's `original`.
function SciMLBase.solve!(cache::FastLMCache)
    res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
        cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...)
    # NOTE(review): the two `-1` entries look like placeholders for counters this backend
    # does not report -- confirm against the field order of `SciMLBase.NLStats`.
    stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
    retcode = info == 1 ? ReturnCode.Success :
              (info == -1 ? ReturnCode.MaxIters : ReturnCode.Default)
    return SciMLBase.build_solution(cache.prob, cache.alg, res, fx;
        retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
end

end
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
module NonlinearSolveLeastSquaresOptimExt

using NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import LeastSquaresOptim as LSO

# Report the backend as available. The `LSOptimSolver` constructor consults
# `extension_loaded` to decide whether to warn that the package has not been loaded.
NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true

# Translate the `alg` / `linsolve` type parameters of `LSOptimSolver` into the
# corresponding LeastSquaresOptim optimizer. An unrecognised `linsolve` (including
# `nothing`) is passed on as `nothing`; an unrecognised `alg` throws `ArgumentError`.
function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
    ls = linsolve == :qr ? LSO.QR() :
         (linsolve == :cholesky ? LSO.Cholesky() :
          (linsolve == :lsmr ? LSO.LSMR() : nothing))
    if alg == :lm
        return LSO.LevenbergMarquardt(ls)
    elseif alg == :dogleg
        return LSO.Dogleg(ls)
    else
        throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg"))
    end
end

# Everything `solve!` needs to run `LSO.optimize!`:
#   prob, alg       -- the original problem and algorithm (used to build the solution)
#   allocated_prob  -- the `LSO.LeastSquaresProblemAllocated` built in `__init`
#   kwargs          -- NamedTuple of keyword arguments forwarded to `LSO.optimize!`
@concrete struct LeastSquaresOptimCache
    prob
    alg
    allocated_prob
    kwargs
end

# Adapter that closes over the parameters `p` and exposes both in-place (`iip = true`)
# and out-of-place (`iip = false`) user functions through a mutating `(du, u)` call.
@concrete struct FunctionWrapper{iip}
    f
    p
end

(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p)
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))

# Build a `LeastSquaresOptimCache` for a `NonlinearLeastSquaresProblem`.
#
# `abstol`, `reltol`, `maxiters` and `verbose` are forwarded to LeastSquaresOptim as
# `x_tol`, `f_tol`, `iterations` and `show_trace`; any extra `kwargs` pass through as-is.
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...;
        abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
    iip = SciMLBase.isinplace(prob)

    f! = FunctionWrapper{iip}(prob.f, prob.p)
    # The analytic Jacobian is optional here; `nothing` is handed to LeastSquaresOptim
    # together with `alg.autodiff` when the user did not provide one.
    g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)

    # Without a user-supplied prototype: evaluate `f` once when out-of-place, otherwise
    # fall back to a zeroed array shaped like `u0`. `zero(u0)` is used here because
    # `zeros` expects an element type / dimensions and has no method for an array.
    resid_prototype = prob.f.resid_prototype === nothing ?
                      (!iip ? prob.f(prob.u0, prob.p) : zero(prob.u0)) :
                      prob.f.resid_prototype

    lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = resid_prototype, g!,
        J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype))
    allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))

    return LeastSquaresOptimCache(prob, alg, allocated_prob,
        (; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose,
            kwargs...))
end

# Run `LSO.optimize!` and wrap the outcome in a SciML solution.
# Any of the x/f/g convergence flags yields `Success`; otherwise hitting the iteration
# budget yields `MaxIters`, and everything else is a `ConvergenceFailure`.
# The solution's residual slot is filled with `res.ssr / 2`, and the raw LeastSquaresOptim
# result is kept in `original`.
function SciMLBase.solve!(cache::LeastSquaresOptimCache)
    res = LSO.optimize!(cache.allocated_prob; cache.kwargs...)
    maxiters = cache.kwargs[:iterations]
    converged = res.x_converged || res.f_converged || res.g_converged
    retcode = converged ? ReturnCode.Success :
              (res.iterations >= maxiters ? ReturnCode.MaxIters :
               ReturnCode.ConvergenceFailure)
    # NOTE(review): the two `-1` entries look like placeholders for counters this backend
    # does not report -- confirm against the field order of `SciMLBase.NLStats`.
    stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations)
    return SciMLBase.build_solution(cache.prob, cache.alg, res.minimizer, res.ssr / 2;
        retcode, original = res, stats)
end

end

src/NonlinearSolve.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm
3030

3131
abstract type AbstractNonlinearSolveCache{iip} end
3232

33+
extension_loaded(::Val) = false
34+
3335
isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
3436

3537
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
@@ -60,6 +62,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
6062
end
6163

6264
include("utils.jl")
65+
include("algorithms.jl")
6366
include("linesearch.jl")
6467
include("raphson.jl")
6568
include("trustRegion.jl")
@@ -93,6 +96,7 @@ end
9396
export RadiusUpdateSchemes
9497

9598
export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
99+
export LSOptimSolver, FastLevenbergMarquardtSolver
96100

97101
export LineSearch
98102

src/algorithms.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Define Algorithms extended via extensions
"""
    LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)

Solve a `NonlinearLeastSquaresProblem` by delegating to
[LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl).

## Arguments:

  - `alg`: Optimization algorithm, either `:lm` or `:dogleg`.
  - `linsolve`: Linear solver, one of `:qr`, `:cholesky` or `:lsmr`. With `nothing`,
    `LeastSquaresOptim.jl` picks the linear solver itself based on the Jacobian
    structure.
  - `autodiff`: Automatic differentiation / Finite Differences mode, either `:central`
    or `:forward`.

!!! note
    This algorithm is only available if `LeastSquaresOptim.jl` is installed.
"""
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
    autodiff::Symbol
end

function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
    # The assertion expressions are kept verbatim: `@assert` derives its message from
    # the expression text.
    @assert alg in (:lm, :dogleg)
    @assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
    @assert autodiff in (:central, :forward)

    # Constructing the solver is still allowed without the backend; only warn so the
    # user can load it before calling `solve`.
    extension_loaded(Val(:LeastSquaresOptim)) ||
        @warn "LeastSquaresOptim.jl is not loaded! It needs to be explicitly loaded \
               before `solve(prob, LSOptimSolver())` is called."

    return LSOptimSolver{alg, linsolve}(autodiff)
end
35+
36+
"""
37+
FastLevenbergMarquardtSolver(linsolve = :cholesky)
38+
39+
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
40+
`NonlinearLeastSquaresProblem`.
41+
42+
!!! warning
43+
This is not really the fastest solver. It is called that since the original package
44+
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
45+
46+
!!! warning
47+
This algorithm requires the jacobian function to be provided!
48+
49+
## Arguments:
50+
51+
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
52+
53+
!!! note
54+
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
55+
"""
56+
@concrete struct FastLevenbergMarquardtSolver{linsolve} <: AbstractNonlinearSolveAlgorithm
57+
factor
58+
factoraccept
59+
factorreject
60+
factorupdate::Symbol
61+
minscale
62+
maxscale
63+
minfactor
64+
maxfactor
65+
end
66+
67+
function FastLevenbergMarquardtSolver(linsolve::Symbol = :cholesky; factor = 1e-6,
68+
factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt,
69+
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32)
70+
@assert linsolve in (:qr, :cholesky)
71+
@assert factorupdate in (:marquardt, :nielson)
72+
73+
if !extension_loaded(Val(:FastLevenbergMarquardt))
74+
@warn "FastLevenbergMarquardt.jl is not loaded! It needs to be explicitly loaded \
75+
before `solve(prob, FastLevenbergMarquardtSolver())` is called."
76+
end
77+
78+
return FastLevenbergMarquardtSolver{linsolve}(factor, factoraccept, factorreject,
79+
factorupdate, minscale, maxscale, minfactor, maxfactor)
80+
end

src/gaussnewton.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
3-
adkwargs...)
2+
GaussNewton(; concrete_jac = nothing, linsolve = nothing,
3+
precs = DEFAULT_PRECS, adkwargs...)
44
55
An advanced GaussNewton implementation with support for efficient handling of sparse
66
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
@@ -41,7 +41,7 @@ for large-scale and numerically-difficult nonlinear least squares problems.
4141
precs
4242
end
4343

44-
function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
44+
function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
4545
precs = DEFAULT_PRECS, adkwargs...)
4646
ad = default_adargs_to_adtype(; adkwargs...)
4747
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
@@ -93,12 +93,12 @@ end
9393
function perform_step!(cache::GaussNewtonCache{true})
9494
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
9595
jacobian!!(J, cache)
96-
mul!(JᵀJ, J', J)
97-
mul!(Jᵀf, J', fu1)
96+
__matmul!(JᵀJ, J', J)
97+
__matmul!(Jᵀf, J', fu1)
9898

9999
# u = u - J \ fu
100-
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
101-
p, reltol = cache.abstol)
100+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
101+
linu = _vec(du), p, reltol = cache.abstol)
102102
cache.linsolve = linres.cache
103103
@. u = u - du
104104
f(cache.fu_new, u, p)
@@ -125,8 +125,8 @@ function perform_step!(cache::GaussNewtonCache{false})
125125
if linsolve === nothing
126126
cache.du = fu1 / cache.J
127127
else
128-
linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
129-
linu = _vec(cache.du), p, reltol = cache.abstol)
128+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),
129+
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
130130
cache.linsolve = linres.cache
131131
end
132132
cache.u = @. u - cache.du # `u` might not support mutation

src/jacobian.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
6565
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
6666
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
6767
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
68-
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
68+
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
6969
sd = sparsity_detection_alg(f, alg.ad)
7070
ad = alg.ad
7171
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
@@ -74,7 +74,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
7474
jac_cache = nothing
7575
end
7676

77-
J = if !(linsolve_needs_jac || alg_wants_jac)
77+
# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
78+
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
79+
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
7880
# We don't need to construct the Jacobian
7981
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
8082
else
@@ -93,14 +95,14 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
9395
Jᵀfu = J' * fu
9496
end
9597

96-
linprob = LinearProblem(needsJᵀJ ? JᵀJ : J, needsJᵀJ ? _vec(Jᵀfu) : _vec(fu);
97-
u0 = _vec(du))
98+
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
99+
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))
98100

99101
weight = similar(u)
100102
recursivefill!(weight, true)
101103

102-
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
103-
nothing)..., weight)
104+
Pl, Pr = wrapprecs(alg.precs(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J, nothing, u, p,
105+
nothing, nothing, nothing, nothing, nothing)..., weight)
104106
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
105107
linsolve_kwargs...)
106108

@@ -114,9 +116,15 @@ __get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
114116
__get_nonsparse_ad(ad) = ad
115117

116118
__init_JᵀJ(J::Number) = zero(J)
117-
__init_JᵀJ(J::AbstractArray) = zeros(eltype(J), size(J, 2), size(J, 2))
119+
__init_JᵀJ(J::AbstractArray) = J' * J
118120
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
119121

122+
__maybe_symmetric(x) = Symmetric(x)
123+
__maybe_symmetric(x::Number) = x
124+
# LinearSolve with `nothing` doesn't dispatch correctly here
125+
__maybe_symmetric(x::StaticArray) = x
126+
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
127+
120128
## Special Handling for Scalars
121129
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
122130
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),

0 commit comments

Comments
 (0)