@@ -10,10 +10,24 @@ An advanced Levenberg-Marquardt implementation with the improvements suggested i
10
10
algorithm for nonlinear least-squares minimization". Designed for large-scale and
11
11
numerically-difficult nonlinear systems.
12
12
13
- If no `linsolve` is provided or a variant of `QR` is provided, then we will use an efficient
14
- routine for the factorization without constructing `JᵀJ` and `Jᵀf`. For more details see
15
- "Chapter 10: Implementation of the Levenberg-Marquardt Method" of
16
- ["Numerical Optimization" by Jorge Nocedal & Stephen J. Wright](https://link.springer.com/book/10.1007/978-0-387-40065-5).
13
+ ### How to Choose the Linear Solver?
14
+
15
+ There are 2 ways to perform the LM Step
16
+
17
+ 1. Solve `(JᵀJ + λDᵀD) δx = Jᵀf` directly using a linear solver
18
+ 2. Solve for `Jδx = f` and `√λ⋅D δx = 0` simultaneously (to derive this simply compute the
19
+ normal form for this)
20
+
21
+ The second form tends to be more robust and can be solved using any Least Squares Solver.
22
+ If no `linsolve` or a least squares solver is provided, then we will solve the 2nd form.
23
+ However, in most cases, this means losing structure in `J` which is not ideal. Note that
24
+ whatever you do, do not specify solvers like `linsolve = NormalCholeskyFactorization()` or
25
+ any such solver which converts the equation to normal form before solving. These don't use
26
+ cache efficiently and we already support the normal form natively.
27
+
28
+ Additionally, note that the first form leads to a positive definite system, so we can use
29
+ more efficient solvers like `linsolve = CholeskyFactorization()`. If you know that the
30
+ problem is very well conditioned, then you might want to solve the normal form directly.
17
31
18
32
### Keyword Arguments
19
33
@@ -168,7 +182,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
168
182
T = eltype (u)
169
183
fu = evaluate_f (prob, u)
170
184
171
- fastls = ! __needs_square_A (alg, u0)
185
+ fastls = prob isa NonlinearProblem && ! __needs_square_A (alg, u0)
172
186
173
187
if ! fastls
174
188
uf, linsolve, J, fu_cache, jac_cache, du, JᵀJ, v = jacobian_caches (alg, f, u, p,
@@ -253,9 +267,9 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
253
267
if fastls
254
268
if setindex_trait (cache. mat_tmp) === CanSetindex ()
255
269
copyto! (@view (cache. mat_tmp[1 : length (cache. fu), :]), cache. J)
256
- cache. mat_tmp[(length (cache. fu) + 1 ): end , :] .= cache. λ .* cache. DᵀD
270
+ cache. mat_tmp[(length (cache. fu) + 1 ): end , :] .= sqrt .( cache. λ .* cache. DᵀD)
257
271
else
258
- cache. mat_tmp = _vcat (cache. J, cache. λ .* cache. DᵀD)
272
+ cache. mat_tmp = _vcat (cache. J, sqrt .( cache. λ .* cache. DᵀD) )
259
273
end
260
274
if setindex_trait (cache. rhs_tmp) === CanSetindex ()
261
275
cache. rhs_tmp[1 : length (cache. fu)] .= _vec (cache. fu)
@@ -283,7 +297,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
283
297
evaluate_f (cache, cache. u_cache_2, cache. p, Val (:fu_cache_2 ))
284
298
285
299
# The following lines do: cache.a = -cache.mat_tmp \ cache.fu_tmp
286
- # NOTE: Don't pass `A`` in again, since we want to reuse the previous solve
300
+ # NOTE: Don't pass `A` in again, since we want to reuse the previous solve
287
301
@bb cache. Jv = cache. J × vec (cache. v)
288
302
Jv = _restructure (cache. fu_cache_2, cache. Jv)
289
303
@bb @. cache. fu_cache_2 = (2 / cache. h) * ((cache. fu_cache_2 - cache. fu) / cache. h - Jv)
@@ -337,6 +351,33 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
337
351
return nothing
338
352
end
339
353
354
+ @inline __update_LM_diagonal!! (y:: Number , x:: Number ) = max (y, x)
355
+ @inline function __update_LM_diagonal!! (y:: Diagonal , x:: AbstractVector )
356
+ if setindex_trait (y. diag) === CanSetindex ()
357
+ @. y. diag = max (y. diag, x)
358
+ return y
359
+ else
360
+ return Diagonal (max .(y. diag, x))
361
+ end
362
+ end
363
+ @inline function __update_LM_diagonal!! (y:: Diagonal , x:: AbstractMatrix )
364
+ if setindex_trait (y. diag) === CanSetindex ()
365
+ if fast_scalar_indexing (y. diag)
366
+ @inbounds for i in axes (x, 1 )
367
+ y. diag[i] = max (y. diag[i], x[i, i])
368
+ end
369
+ return y
370
+ else
371
+ idxs = diagind (x)
372
+ @. . broadcast= false y. diag= max (y. diag, @view (x[idxs]))
373
+ return y
374
+ end
375
+ else
376
+ idxs = diagind (x)
377
+ return Diagonal (@. . broadcast= false max (y. diag, @view (x[idxs])))
378
+ end
379
+ end
380
+
340
381
function __reinit_internal! (cache:: LevenbergMarquardtCache ;
341
382
termination_condition = get_termination_mode (cache. tc_cache_1), kwargs... )
342
383
abstol, reltol, tc_cache_1 = init_termination_cache (cache. abstol, cache. reltol,
0 commit comments