@@ -10,6 +10,11 @@ An advanced Levenberg-Marquardt implementation with the improvements suggested i
10
10
algorithm for nonlinear least-squares minimization". Designed for large-scale and
11
11
numerically-difficult nonlinear systems.
12
12
13
+ If no `linsolve` is provided or a variant of `QR` is provided, then we will use an efficient
14
+ routine for the factorization without constructing `JᵀJ` and `Jᵀf`. For more details see
15
+ "Chapter 10: Implementation of the Levenberg-Marquardt Method" of
16
+ ["Numerical Optimization" by Jorge Nocedal & Stephen J. Wright](https://link.springer.com/book/10.1007/978-0-387-40065-5).
17
+
13
18
### Keyword Arguments
14
19
15
20
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
@@ -104,7 +109,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
104
109
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
105
110
end
106
111
107
- @concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
112
+ @concrete mutable struct LevenbergMarquardtCache{iip, fastqr} < :
113
+ AbstractNonlinearSolveCache{iip}
108
114
f
109
115
alg
110
116
u
144
150
u_tmp
145
151
Jv
146
152
mat_tmp
153
+ rhs_tmp
154
+ J²
147
155
stats:: NLStats
148
156
end
149
157
@@ -155,8 +163,26 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
155
163
@unpack f, u0, p = prob
156
164
u = alias_u0 ? u0 : deepcopy (u0)
157
165
fu1 = evaluate_f (prob, u)
158
- uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches (alg, f, u, p, Val (iip);
159
- linsolve_kwargs, linsolve_with_JᵀJ = Val (true ))
166
+
167
+ # Use QR if the user did not specify a linear solver
168
+ if (alg. linsolve === nothing || alg. linsolve isa QRFactorization ||
169
+ alg. linsolve isa FastQRFactorization) && ! (u isa Number)
170
+ linsolve_with_JᵀJ = Val (false )
171
+ else
172
+ linsolve_with_JᵀJ = Val (true )
173
+ end
174
+
175
+ if _unwrap_val (linsolve_with_JᵀJ)
176
+ uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches (alg, f, u, p,
177
+ Val (iip); linsolve_kwargs, linsolve_with_JᵀJ)
178
+ J² = nothing
179
+ else
180
+ uf, linsolve, J, fu2, jac_cache, du = jacobian_caches (alg, f, u, p, Val (iip);
181
+ linsolve_kwargs, linsolve_with_JᵀJ)
182
+ JᵀJ = similar (u)
183
+ J² = similar (J)
184
+ v = similar (du)
185
+ end
160
186
161
187
λ = convert (eltype (u), alg. damping_initial)
162
188
λ_factor = convert (eltype (u), alg. damping_increase_factor)
@@ -182,16 +208,26 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
182
208
δ = _mutable_zero (u)
183
209
make_new_J = true
184
210
fu_tmp = zero (fu1)
185
- mat_tmp = zero (JᵀJ)
186
211
187
- return LevenbergMarquardtCache {iip} (f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
212
+ if _unwrap_val (linsolve_with_JᵀJ)
213
+ mat_tmp = zero (JᵀJ)
214
+ rhs_tmp = nothing
215
+ else
216
+ mat_tmp = similar (JᵀJ, length (fu1) + length (u), length (u))
217
+ fill! (mat_tmp, zero (eltype (u)))
218
+ rhs_tmp = similar (mat_tmp, length (fu1) + length (u))
219
+ fill! (rhs_tmp, zero (eltype (u)))
220
+ end
221
+
222
+ return LevenbergMarquardtCache {iip, !_unwrap_val(linsolve_with_JᵀJ)} (f, alg, u, fu1,
223
+ fu2, du, p, uf, linsolve, J,
188
224
jac_cache, false , maxiters, internalnorm, ReturnCode. Default, abstol, prob, DᵀD,
189
225
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
190
226
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
191
- zero (u), zero (fu1), mat_tmp, NLStats (1 , 0 , 0 , 0 , 0 ))
227
+ zero (u), zero (fu1), mat_tmp, rhs_tmp, J², NLStats (1 , 0 , 0 , 0 , 0 ))
192
228
end
193
229
194
- function perform_step! (cache:: LevenbergMarquardtCache{true} )
230
+ function perform_step! (cache:: LevenbergMarquardtCache{true, fastqr} ) where {fastqr}
195
231
@unpack fu1, f, make_new_J = cache
196
232
if iszero (fu1)
197
233
cache. force_stop = true
@@ -200,35 +236,57 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
200
236
201
237
if make_new_J
202
238
jacobian!! (cache. J, cache)
203
- __matmul! (cache. JᵀJ, cache. J' , cache. J)
204
- cache. DᵀD .= max .(cache. DᵀD, Diagonal (cache. JᵀJ))
239
+ if fastqr
240
+ cache. J² .= cache. J .^ 2
241
+ sum! (cache. JᵀJ' , cache. J²)
242
+ cache. DᵀD. diag .= max .(cache. DᵀD. diag, cache. JᵀJ)
243
+ else
244
+ __matmul! (cache. JᵀJ, cache. J' , cache. J)
245
+ cache. DᵀD .= max .(cache. DᵀD, Diagonal (cache. JᵀJ))
246
+ end
205
247
cache. make_new_J = false
206
248
cache. stats. njacs += 1
207
249
end
208
250
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
209
251
210
252
# Usual Levenberg-Marquardt step ("velocity").
211
253
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
212
- mul! (_vec (cache. u_tmp), J' , _vec (fu1))
213
- @. cache. mat_tmp = JᵀJ + λ * DᵀD
214
- linres = dolinsolve (alg. precs, linsolve; A = __maybe_symmetric (cache. mat_tmp),
215
- b = _vec (cache. u_tmp), linu = _vec (cache. du), p = p, reltol = cache. abstol)
216
- cache. linsolve = linres. cache
217
- _vec (cache. v) .= - 1 .* _vec (cache. du)
254
+ if fastqr
255
+ cache. mat_tmp[1 : length (fu1), :] .= cache. J
256
+ cache. mat_tmp[(length (fu1) + 1 ): end , :] .= λ .* cache. DᵀD
257
+ cache. rhs_tmp[1 : length (fu1)] .= _vec (fu1)
258
+ linres = dolinsolve (alg. precs, linsolve; A = cache. mat_tmp,
259
+ b = cache. rhs_tmp, linu = _vec (cache. du), p = p, reltol = cache. abstol)
260
+ _vec (cache. v) .= - _vec (cache. du)
261
+ else
262
+ mul! (_vec (cache. u_tmp), J' , _vec (fu1))
263
+ @. cache. mat_tmp = JᵀJ + λ * DᵀD
264
+ linres = dolinsolve (alg. precs, linsolve; A = __maybe_symmetric (cache. mat_tmp),
265
+ b = _vec (cache. u_tmp), linu = _vec (cache. du), p = p, reltol = cache. abstol)
266
+ cache. linsolve = linres. cache
267
+ _vec (cache. v) .= - _vec (cache. du)
268
+ end
218
269
219
270
# Geodesic acceleration (step_size = v + a / 2).
220
271
@unpack v, α_geodesic, h = cache
221
- f (cache. fu_tmp, _restructure (u, _vec (u) .+ h .* _vec (v)), p)
272
+ cache. u_tmp .= _restructure (cache. u_tmp, _vec (u) .+ h .* _vec (v))
273
+ f (cache. fu_tmp, cache. u_tmp, p)
222
274
223
275
# The following lines do: cache.a = -J \ cache.fu_tmp
276
+ # NOTE: Don't pass `A` in again, since we want to reuse the previous solve
224
277
mul! (_vec (cache. Jv), J, _vec (v))
225
278
@. cache. fu_tmp = (2 / h) * ((cache. fu_tmp - fu1) / h - cache. Jv)
226
- mul! (_vec (cache. u_tmp), J' , _vec (cache. fu_tmp))
227
- # NOTE: Don't pass `A` in again, since we want to reuse the previous solve
228
- linres = dolinsolve (alg. precs, linsolve; b = _vec (cache. u_tmp),
229
- linu = _vec (cache. du), p = p, reltol = cache. abstol)
230
- cache. linsolve = linres. cache
231
- @. cache. a = - cache. du
279
+ if fastqr
280
+ cache. rhs_tmp[1 : length (fu1)] .= _vec (cache. fu_tmp)
281
+ linres = dolinsolve (alg. precs, linsolve; b = cache. rhs_tmp, linu = _vec (cache. du),
282
+ p = p, reltol = cache. abstol)
283
+ else
284
+ mul! (_vec (cache. u_tmp), J' , _vec (cache. fu_tmp))
285
+ linres = dolinsolve (alg. precs, linsolve; b = _vec (cache. u_tmp),
286
+ linu = _vec (cache. du), p = p, reltol = cache. abstol)
287
+ cache. linsolve = linres. cache
288
+ @. cache. a = - cache. du
289
+ end
232
290
cache. stats. nsolve += 2
233
291
cache. stats. nfactors += 2
234
292
@@ -263,7 +321,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
263
321
return nothing
264
322
end
265
323
266
- function perform_step! (cache:: LevenbergMarquardtCache{false} )
324
+ function perform_step! (cache:: LevenbergMarquardtCache{false, fastqr} ) where {fastqr}
267
325
@unpack fu1, f, make_new_J = cache
268
326
if iszero (fu1)
269
327
cache. force_stop = true
@@ -272,40 +330,55 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
272
330
273
331
if make_new_J
274
332
cache. J = jacobian!! (cache. J, cache)
275
- cache . JᵀJ = cache . J ' * cache . J
276
- if cache. JᵀJ isa Number
277
- cache. DᵀD = max (cache. DᵀD, cache. JᵀJ)
333
+ if fastqr
334
+ cache. JᵀJ = _vec ( sum (cache . J .^ 2 ; dims = 1 ))
335
+ cache. DᵀD. diag . = max . (cache. DᵀD. diag , cache. JᵀJ)
278
336
else
279
- cache. DᵀD .= max .(cache. DᵀD, Diagonal (cache. JᵀJ))
337
+ cache. JᵀJ = cache. J' * cache. J
338
+ if cache. JᵀJ isa Number
339
+ cache. DᵀD = max (cache. DᵀD, cache. JᵀJ)
340
+ else
341
+ cache. DᵀD .= max .(cache. DᵀD, Diagonal (cache. JᵀJ))
342
+ end
280
343
end
281
344
cache. make_new_J = false
282
345
cache. stats. njacs += 1
283
346
end
284
347
@unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache
285
348
286
- cache. mat_tmp = JᵀJ + λ * DᵀD
287
349
# Usual Levenberg-Marquardt step ("velocity").
288
- if linsolve === nothing
289
- cache. v = - cache. mat_tmp \ (J' * fu1)
350
+ if fastqr
351
+ cache. mat_tmp = vcat (J, λ .* cache. DᵀD)
352
+ cache. rhs_tmp[1 : length (fu1)] .= - _vec (fu1)
353
+ linres = dolinsolve (alg. precs, linsolve; A = cache. mat_tmp,
354
+ b = cache. rhs_tmp, linu = _vec (cache. v), p = p, reltol = cache. abstol)
290
355
else
291
- linres = dolinsolve (alg. precs, linsolve; A = - __maybe_symmetric (cache. mat_tmp),
292
- b = _vec (J' * _vec (fu1)), linu = _vec (cache. v), p, reltol = cache. abstol)
293
- cache. linsolve = linres. cache
356
+ cache. mat_tmp = JᵀJ + λ * DᵀD
357
+ if linsolve === nothing
358
+ cache. v = - cache. mat_tmp \ (J' * fu1)
359
+ else
360
+ linres = dolinsolve (alg. precs, linsolve; A = - __maybe_symmetric (cache. mat_tmp),
361
+ b = _vec (J' * _vec (fu1)), linu = _vec (cache. v), p, reltol = cache. abstol)
362
+ cache. linsolve = linres. cache
363
+ end
294
364
end
295
365
296
366
@unpack v, h, α_geodesic = cache
297
367
# Geodesic acceleration (step_size = v + a / 2).
298
- if linsolve === nothing
299
- cache . a = - cache . mat_tmp \
300
- _vec (J ' * (( 2 / h) .* (( f (u .+ h .* v, p) .- fu1) ./ h .- J * v)))
301
- else
368
+ rhs_term = _vec ((( 2 / h) .* (( _vec ( f (u .+ h .* _restructure (u, v), p)) .-
369
+ _vec (fu1)) ./ h .- J * _vec (v))))
370
+ if fastqr
371
+ cache . rhs_tmp[ 1 : length (fu1)] . = - _vec (rhs_term)
302
372
linres = dolinsolve (alg. precs, linsolve;
303
- b = _mutable (_vec (J' * # ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
304
- _vec (((2 / h) .*
305
- ((_vec (f (u .+ h .* _restructure (u, v), p)) .-
306
- _vec (fu1)) ./ h .- J * _vec (v)))))),
307
- linu = _vec (cache. a), p, reltol = cache. abstol)
308
- cache. linsolve = linres. cache
373
+ b = cache. rhs_tmp, linu = _vec (cache. a), p = p, reltol = cache. abstol)
374
+ else
375
+ if linsolve === nothing
376
+ cache. a = - cache. mat_tmp \ _vec (J' * rhs_term)
377
+ else
378
+ linres = dolinsolve (alg. precs, linsolve; b = _mutable (_vec (J' * rhs_term)),
379
+ linu = _vec (cache. a), p, reltol = cache. abstol)
380
+ cache. linsolve = linres. cache
381
+ end
309
382
end
310
383
cache. stats. nsolve += 1
311
384
cache. stats. nfactors += 1
0 commit comments