Skip to content

Commit bca623f

Browse files
committed
More Robust QR for LM
1 parent 960dad5 commit bca623f

File tree

5 files changed

+131
-54
lines changed

5 files changed

+131
-54
lines changed

docs/src/solvers/NonlinearLeastSquaresSolvers.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@ Solves the nonlinear least squares problem defined by `prob` using the algorithm
1919
handling of sparse matrices via colored automatic differentiation and preconditioned
2020
linear solvers. Designed for large-scale and numerically-difficult nonlinear least squares
2121
problems.
22-
- `SimpleNewtonRaphson()`: Newton Raphson implementation that uses Linear Least Squares
23-
solution at every step to compute the descent direction. **WARNING**: This method is not
24-
a robust solver for nonlinear least squares problems. The computed delta step might not
25-
be the correct descent direction!
22+
- `SimpleNewtonRaphson()`: Simple Gauss-Newton implementation that uses `QRFactorization` to
23+
solve a linear least squares problem at each step.
2624

2725
## Example usage
2826

src/NonlinearSolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,10 @@ import PrecompileTools
8888
for T in (Float32, Float64)
8989
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
9090

91-
precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
92-
PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
91+
# precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
92+
# PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
93+
# DON'T MERGE
94+
precompile_algs = ()
9395

9496
for alg in precompile_algs
9597
solve(prob, alg, abstol = T(1e-2))

src/levenberg.jl

Lines changed: 116 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ An advanced Levenberg-Marquardt implementation with the improvements suggested i
1010
algorithm for nonlinear least-squares minimization". Designed for large-scale and
1111
numerically-difficult nonlinear systems.
1212
13+
If no `linsolve` is provided or a variant of `QR` is provided, then we will use an efficient
14+
routine for the factorization without constructing `JᵀJ` and `Jᵀf`. For more details see
15+
"Chapter 10: Implementation of the Levenberg-Marquardt Method" of
16+
["Numerical Optimization" by Jorge Nocedal & Stephen J. Wright](https://link.springer.com/book/10.1007/978-0-387-40065-5).
17+
1318
### Keyword Arguments
1419
1520
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
@@ -104,7 +109,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
104109
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
105110
end
106111

107-
@concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
112+
@concrete mutable struct LevenbergMarquardtCache{iip, fastqr} <:
113+
AbstractNonlinearSolveCache{iip}
108114
f
109115
alg
110116
u
@@ -144,6 +150,8 @@ end
144150
u_tmp
145151
Jv
146152
mat_tmp
153+
rhs_tmp
154+
J²
147155
stats::NLStats
148156
end
149157

@@ -155,8 +163,26 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
155163
@unpack f, u0, p = prob
156164
u = alias_u0 ? u0 : deepcopy(u0)
157165
fu1 = evaluate_f(prob, u)
158-
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p, Val(iip);
159-
linsolve_kwargs, linsolve_with_JᵀJ = Val(true))
166+
167+
# Use QR (no `JᵀJ`) when `linsolve` is unset or a QR variant, and `u` is not a scalar
168+
if (alg.linsolve === nothing || alg.linsolve isa QRFactorization ||
169+
alg.linsolve isa FastQRFactorization) && !(u isa Number)
170+
linsolve_with_JᵀJ = Val(false)
171+
else
172+
linsolve_with_JᵀJ = Val(true)
173+
end
174+
175+
if _unwrap_val(linsolve_with_JᵀJ)
176+
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p,
177+
Val(iip); linsolve_kwargs, linsolve_with_JᵀJ)
178+
J² = nothing
179+
else
180+
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
181+
linsolve_kwargs, linsolve_with_JᵀJ)
182+
JᵀJ = similar(u)
183+
J² = similar(J)
184+
v = similar(du)
185+
end
160186

161187
λ = convert(eltype(u), alg.damping_initial)
162188
λ_factor = convert(eltype(u), alg.damping_increase_factor)
@@ -182,16 +208,26 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
182208
δ = _mutable_zero(u)
183209
make_new_J = true
184210
fu_tmp = zero(fu1)
185-
mat_tmp = zero(JᵀJ)
186211

187-
return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
212+
if _unwrap_val(linsolve_with_JᵀJ)
213+
mat_tmp = zero(JᵀJ)
214+
rhs_tmp = nothing
215+
else
216+
mat_tmp = similar(JᵀJ, length(fu1) + length(u), length(u))
217+
fill!(mat_tmp, zero(eltype(u)))
218+
rhs_tmp = similar(mat_tmp, length(fu1) + length(u))
219+
fill!(rhs_tmp, zero(eltype(u)))
220+
end
221+
222+
return LevenbergMarquardtCache{iip, !_unwrap_val(linsolve_with_JᵀJ)}(f, alg, u, fu1,
223+
fu2, du, p, uf, linsolve, J,
188224
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD,
189225
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
190226
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
191-
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0))
227+
zero(u), zero(fu1), mat_tmp, rhs_tmp, J², NLStats(1, 0, 0, 0, 0))
192228
end
193229

194-
function perform_step!(cache::LevenbergMarquardtCache{true})
230+
function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fastqr}
195231
@unpack fu1, f, make_new_J = cache
196232
if iszero(fu1)
197233
cache.force_stop = true
@@ -200,35 +236,57 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
200236

201237
if make_new_J
202238
jacobian!!(cache.J, cache)
203-
__matmul!(cache.JᵀJ, cache.J', cache.J)
204-
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
239+
if fastqr
240+
cache.J² .= cache.J .^ 2
241+
sum!(cache.JᵀJ', cache.J²)
242+
cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ)
243+
else
244+
__matmul!(cache.JᵀJ, cache.J', cache.J)
245+
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
246+
end
205247
cache.make_new_J = false
206248
cache.stats.njacs += 1
207249
end
208250
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
209251

210252
# Usual Levenberg-Marquardt step ("velocity").
211253
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
212-
mul!(_vec(cache.u_tmp), J', _vec(fu1))
213-
@. cache.mat_tmp = JᵀJ + λ * DᵀD
214-
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
215-
b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
216-
cache.linsolve = linres.cache
217-
_vec(cache.v) .= -1 .* _vec(cache.du)
254+
if fastqr
255+
cache.mat_tmp[1:length(fu1), :] .= cache.J
256+
cache.mat_tmp[(length(fu1) + 1):end, :] .= λ .* cache.DᵀD
257+
cache.rhs_tmp[1:length(fu1)] .= _vec(fu1)
258+
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
259+
b = cache.rhs_tmp, linu = _vec(cache.du), p = p, reltol = cache.abstol)
260+
_vec(cache.v) .= -_vec(cache.du)
261+
else
262+
mul!(_vec(cache.u_tmp), J', _vec(fu1))
263+
@. cache.mat_tmp = JᵀJ + λ * DᵀD
264+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
265+
b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
266+
cache.linsolve = linres.cache
267+
_vec(cache.v) .= -_vec(cache.du)
268+
end
218269

219270
# Geodesic acceleration (step_size = v + a / 2).
220271
@unpack v, α_geodesic, h = cache
221-
f(cache.fu_tmp, _restructure(u, _vec(u) .+ h .* _vec(v)), p)
272+
cache.u_tmp .= _restructure(cache.u_tmp, _vec(u) .+ h .* _vec(v))
273+
f(cache.fu_tmp, cache.u_tmp, p)
222274

223275
# The following lines do: cache.a = -J \ cache.fu_tmp
276+
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
224277
mul!(_vec(cache.Jv), J, _vec(v))
225278
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
226-
mul!(_vec(cache.u_tmp), J', _vec(cache.fu_tmp))
227-
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
228-
linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_tmp),
229-
linu = _vec(cache.du), p = p, reltol = cache.abstol)
230-
cache.linsolve = linres.cache
231-
@. cache.a = -cache.du
279+
if fastqr
280+
cache.rhs_tmp[1:length(fu1)] .= _vec(cache.fu_tmp)
281+
linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.du),
282+
p = p, reltol = cache.abstol)
283+
else
284+
mul!(_vec(cache.u_tmp), J', _vec(cache.fu_tmp))
285+
linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_tmp),
286+
linu = _vec(cache.du), p = p, reltol = cache.abstol)
287+
cache.linsolve = linres.cache
288+
@. cache.a = -cache.du
289+
end
232290
cache.stats.nsolve += 2
233291
cache.stats.nfactors += 2
234292

@@ -263,7 +321,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
263321
return nothing
264322
end
265323

266-
function perform_step!(cache::LevenbergMarquardtCache{false})
324+
function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fastqr}
267325
@unpack fu1, f, make_new_J = cache
268326
if iszero(fu1)
269327
cache.force_stop = true
@@ -272,40 +330,55 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
272330

273331
if make_new_J
274332
cache.J = jacobian!!(cache.J, cache)
275-
cache.JᵀJ = cache.J' * cache.J
276-
if cache.JᵀJ isa Number
277-
cache.DᵀD = max(cache.DᵀD, cache.JᵀJ)
333+
if fastqr
334+
cache.JᵀJ = _vec(sum(cache.J .^ 2; dims = 1))
335+
cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ)
278336
else
279-
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
337+
cache.JᵀJ = cache.J' * cache.J
338+
if cache.JᵀJ isa Number
339+
cache.DᵀD = max(cache.DᵀD, cache.JᵀJ)
340+
else
341+
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
342+
end
280343
end
281344
cache.make_new_J = false
282345
cache.stats.njacs += 1
283346
end
284347
@unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache
285348

286-
cache.mat_tmp = JᵀJ + λ * DᵀD
287349
# Usual Levenberg-Marquardt step ("velocity").
288-
if linsolve === nothing
289-
cache.v = -cache.mat_tmp \ (J' * fu1)
350+
if fastqr
351+
cache.mat_tmp = vcat(J, λ .* cache.DᵀD)
352+
cache.rhs_tmp[1:length(fu1)] .= -_vec(fu1)
353+
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
354+
b = cache.rhs_tmp, linu = _vec(cache.v), p = p, reltol = cache.abstol)
290355
else
291-
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
292-
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
293-
cache.linsolve = linres.cache
356+
cache.mat_tmp = JᵀJ + λ * DᵀD
357+
if linsolve === nothing
358+
cache.v = -cache.mat_tmp \ (J' * fu1)
359+
else
360+
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
361+
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
362+
cache.linsolve = linres.cache
363+
end
294364
end
295365

296366
@unpack v, h, α_geodesic = cache
297367
# Geodesic acceleration (step_size = v + a / 2).
298-
if linsolve === nothing
299-
cache.a = -cache.mat_tmp \
300-
_vec(J' * ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))
301-
else
368+
rhs_term = _vec(((2 / h) .* ((_vec(f(u .+ h .* _restructure(u, v), p)) .-
369+
_vec(fu1)) ./ h .- J * _vec(v))))
370+
if fastqr
371+
cache.rhs_tmp[1:length(fu1)] .= -_vec(rhs_term)
302372
linres = dolinsolve(alg.precs, linsolve;
303-
b = _mutable(_vec(J' * #((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
304-
_vec(((2 / h) .*
305-
((_vec(f(u .+ h .* _restructure(u, v), p)) .-
306-
_vec(fu1)) ./ h .- J * _vec(v)))))),
307-
linu = _vec(cache.a), p, reltol = cache.abstol)
308-
cache.linsolve = linres.cache
373+
b = cache.rhs_tmp, linu = _vec(cache.a), p = p, reltol = cache.abstol)
374+
else
375+
if linsolve === nothing
376+
cache.a = -cache.mat_tmp \ _vec(J' * rhs_term)
377+
else
378+
linres = dolinsolve(alg.precs, linsolve; b = _mutable(_vec(J' * rhs_term)),
379+
linu = _vec(cache.a), p, reltol = cache.abstol)
380+
cache.linsolve = linres.cache
381+
end
309382
end
310383
cache.stats.nsolve += 1
311384
cache.stats.nfactors += 1

src/utils.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ function value_derivative(f::F, x::R) where {F, R}
6161
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
6262
end
6363

64-
# Todo: improve this dispatch
6564
function value_derivative(f::F, x::SVector) where {F}
6665
f(x), ForwardDiff.jacobian(f, x)
6766
end
@@ -206,8 +205,7 @@ function __get_concrete_algorithm(alg, prob)
206205
# Use Finite Differencing
207206
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
208207
else
209-
use_sparse_ad ? AutoSparseForwardDiff() :
210-
AutoForwardDiff{ForwardDiff.pickchunksize(length(prob.u0)), Nothing}(nothing)
208+
use_sparse_ad ? AutoSparseForwardDiff() : AutoForwardDiff{nothing, Nothing}(nothing)
211209
end
212210
return set_ad(alg, ad)
213211
end

test/nonlinear_least_squares.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,14 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
2727
resid_prototype = zero(y_target)), θ_init, x)
2828

2929
nlls_problems = [prob_oop, prob_iip]
30-
solvers = [GaussNewton(), LevenbergMarquardt(), LeastSquaresOptimJL(:lm),
31-
LeastSquaresOptimJL(:dogleg)]
30+
solvers = [
31+
GaussNewton(),
32+
GaussNewton(; linsolve = CholeskyFactorization()),
33+
LevenbergMarquardt(),
34+
LevenbergMarquardt(; linsolve = CholeskyFactorization()),
35+
LeastSquaresOptimJL(:lm),
36+
LeastSquaresOptimJL(:dogleg),
37+
]
3238

3339
for prob in nlls_problems, solver in solvers
3440
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)

0 commit comments

Comments
 (0)