Skip to content

Commit 1c19fa7

Browse files
committed
Wrap FastLM.jl
1 parent f151a0a commit 1c19fa7

File tree

6 files changed

+155
-16
lines changed

6 files changed

+155
-16
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2525
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2626

2727
[weakdeps]
28+
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
2829
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
2930

3031
[extensions]
32+
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
3133
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
3234

3335
[compat]
@@ -37,6 +39,7 @@ ConcreteStructs = "0.2"
3739
DiffEqBase = "6.130"
3840
EnumX = "1"
3941
Enzyme = "0.11"
42+
FastLevenbergMarquardt = "0.1"
4043
FiniteDiff = "2"
4144
ForwardDiff = "0.10.3"
4245
LeastSquaresOptim = "0.8"
@@ -57,6 +60,7 @@ julia = "1.9"
5760
[extras]
5861
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
5962
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
63+
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
6064
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6165
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
6266
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -72,4 +76,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7276
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7377

7478
[targets]
75-
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim"]
79+
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt"]
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
module NonlinearSolveFastLevenbergMarquardtExt
2+
3+
using ArrayInterface, NonlinearSolve, SciMLBase
4+
import ConcreteStructs: @concrete
5+
import FastLevenbergMarquardt as FastLM
6+
7+
NonlinearSolve.extension_loaded(::Val{:FastLevenbergMarquardt}) = true
8+
9+
function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {linsolve}
10+
if linsolve == :cholesky
11+
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
12+
elseif linsolve == :qr
13+
return FastLM.QRSolver(eltype(x), length(x))
14+
else
15+
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))
16+
end
17+
end
18+
19+
@concrete struct FastLMCache
20+
f!
21+
J!
22+
prob
23+
alg
24+
lmworkspace
25+
solver
26+
kwargs
27+
end
28+
29+
@concrete struct InplaceFunction{iip} <: Function
30+
f
31+
end
32+
33+
(f::InplaceFunction{true})(fx, x, p) = f.f(fx, x, p)
34+
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))
35+
36+
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
37+
alg::FastLevenbergMarquardtSolver, args...; abstol = 1e-8, reltol = 1e-8,
38+
verbose = false, maxiters = 1000, kwargs...)
39+
iip = SciMLBase.isinplace(prob)
40+
41+
@assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!"
42+
43+
f! = InplaceFunction{iip}(prob.f)
44+
J! = InplaceFunction{iip}(prob.f.jac)
45+
46+
resid_prototype = prob.f.resid_prototype === nothing ?
47+
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) :
48+
prob.f.resid_prototype
49+
50+
J = similar(prob.u0, length(resid_prototype), length(prob.u0))
51+
52+
solver = _fast_lm_solver(alg, prob.u0)
53+
LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J)
54+
55+
return FastLMCache(f!, J!, prob, alg, LM, solver,
56+
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
57+
alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor,
58+
alg.maxfactor, kwargs...))
59+
end
60+
61+
function SciMLBase.solve!(cache::FastLMCache)
62+
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
63+
cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...)
64+
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
65+
retcode = info == 1 ? ReturnCode.Success :
66+
(info == -1 ? ReturnCode.MaxIters : ReturnCode.Default)
67+
return SciMLBase.build_solution(cache.prob, cache.alg, res, fx;
68+
retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
69+
end
70+
71+
end

ext/NonlinearSolveLeastSquaresOptimExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using NonlinearSolve, SciMLBase
44
import ConcreteStructs: @concrete
55
import LeastSquaresOptim as LSO
66

7-
extension_loaded(::Val{:LeastSquaresOptim}) = true
7+
NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true
88

99
function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
1010
ls = linsolve == :qr ? LSO.QR() :

src/NonlinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ end
9595

9696
export RadiusUpdateSchemes
9797

98-
export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton, LSOptimSolver
98+
export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
99+
export LSOptimSolver, FastLevenbergMarquardtSolver
99100

100101
export LineSearch
101102

src/algorithms.jl

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,70 @@ Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquare
1111
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
1212
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
1313
on the Jacobian structure.
14+
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`.
1415
1516
!!! note
1617
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
1718
"""
1819
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
1920
autodiff::Symbol
21+
end
22+
23+
function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
24+
@assert alg in (:lm, :dogleg)
25+
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
26+
@assert autodiff in (:central, :forward)
27+
28+
if !extension_loaded(Val(:LeastSquaresOptim))
29+
@warn "LeastSquaresOptim.jl is not loaded! It needs to be explicitly loaded \
30+
before `solve(prob, LSOptimSolver())` is called."
31+
end
32+
33+
return LSOptimSolver{alg, linsolve}(autodiff)
34+
end
35+
36+
"""
37+
FastLevenbergMarquardtSolver(linsolve = :cholesky)
38+
39+
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
40+
`NonlinearLeastSquaresProblem`.
2041
21-
function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
22-
@assert alg in (:lm, :dogleg)
23-
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
24-
@assert autodiff in (:central, :forward)
42+
!!! warning
43+
This is not really the fastest solver. It is called that since the original package
44+
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
2545
26-
return new{alg, linsolve}(autodiff)
46+
!!! warning
47+
This algorithm requires the jacobian function to be provided!
48+
49+
## Arguments:
50+
51+
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
52+
53+
!!! note
54+
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
55+
"""
56+
@concrete struct FastLevenbergMarquardtSolver{linsolve} <: AbstractNonlinearSolveAlgorithm
57+
factor
58+
factoraccept
59+
factorreject
60+
factorupdate::Symbol
61+
minscale
62+
maxscale
63+
minfactor
64+
maxfactor
65+
end
66+
67+
function FastLevenbergMarquardtSolver(linsolve::Symbol = :cholesky; factor = 1e-6,
68+
factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt,
69+
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32)
70+
@assert linsolve in (:qr, :cholesky)
71+
@assert factorupdate in (:marquardt, :nielson)
72+
73+
if !extension_loaded(Val(:FastLevenbergMarquardt))
74+
@warn "FastLevenbergMarquardt.jl is not loaded! It needs to be explicitly loaded \
75+
before `solve(prob, FastLevenbergMarquardtSolver())` is called."
2776
end
77+
78+
return FastLevenbergMarquardtSolver{linsolve}(factor, factoraccept, factorreject,
79+
factorupdate, minscale, maxscale, minfactor, maxfactor)
2880
end

test/nonlinear_least_squares.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random
2-
import LeastSquaresOptim
1+
using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random, ForwardDiff
2+
import FastLevenbergMarquardt, LeastSquaresOptim
33

44
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
55
true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))
@@ -27,15 +27,26 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
2727
resid_prototype = zero(y_target)), θ_init, x)
2828

2929
nlls_problems = [prob_oop, prob_iip]
30-
solvers = [
31-
GaussNewton(),
32-
LevenbergMarquardt(),
33-
LSOptimSolver(:lm),
34-
LSOptimSolver(:dogleg),
35-
]
30+
solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)]
3631

3732
for prob in nlls_problems, solver in solvers
3833
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
3934
@test SciMLBase.successful_retcode(sol)
4035
@test norm(sol.resid) < 1e-6
4136
end
37+
38+
function jac!(J, θ, p)
39+
ForwardDiff.jacobian!(J, resid -> loss_function(resid, θ, p), θ)
40+
return J
41+
end
42+
43+
prob = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
44+
resid_prototype = zero(y_target), jac = jac!), θ_init, x)
45+
46+
solvers = [FastLevenbergMarquardtSolver(:cholesky), FastLevenbergMarquardtSolver(:qr)]
47+
48+
for solver in solvers
49+
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
50+
@test SciMLBase.successful_retcode(sol)
51+
@test norm(sol.resid) < 1e-6
52+
end

0 commit comments

Comments
 (0)