@@ -158,7 +158,6 @@ mutable struct LevenbergMarquardtCache{iip, fType, algType, uType, duType, resTy
158
158
J:: jType
159
159
du_tmp:: duType
160
160
jac_config:: JC
161
- iter:: Int
162
161
force_stop:: Bool
163
162
maxiters:: Int
164
163
internalnorm:: INType
@@ -185,10 +184,11 @@ mutable struct LevenbergMarquardtCache{iip, fType, algType, uType, duType, resTy
185
184
make_new_J:: Bool
186
185
fu_tmp:: resType
187
186
mat_tmp:: jType
187
+ stats:: NLStats
188
188
189
189
function LevenbergMarquardtCache {iip} (f:: fType , alg:: algType , u:: uType , fu:: resType ,
190
190
p:: pType , uf:: ufType , linsolve:: L , J:: jType ,
191
- du_tmp:: duType , jac_config:: JC , iter :: Int ,
191
+ du_tmp:: duType , jac_config:: JC ,
192
192
force_stop:: Bool , maxiters:: Int ,
193
193
internalnorm:: INType ,
194
194
retcode:: SciMLBase.ReturnCode.T , abstol:: tolType ,
@@ -202,7 +202,7 @@ mutable struct LevenbergMarquardtCache{iip, fType, algType, uType, duType, resTy
202
202
norm_v_old:: lossType , δ:: uType ,
203
203
loss_old:: lossType , make_new_J:: Bool ,
204
204
fu_tmp:: resType ,
205
- mat_tmp:: jType ) where {
205
+ mat_tmp:: jType , stats :: NLStats ) where {
206
206
iip, fType, algType,
207
207
uType, duType, resType,
208
208
pType, INType, tolType,
@@ -213,15 +213,15 @@ mutable struct LevenbergMarquardtCache{iip, fType, algType, uType, duType, resTy
213
213
new{iip, fType, algType, uType, duType, resType,
214
214
pType, INType, tolType, probType, ufType, L,
215
215
jType, JC, DᵀDType, λType, lossType}(f, alg, u, fu, p, uf, linsolve, J, du_tmp,
216
- jac_config, iter, force_stop, maxiters,
216
+ jac_config, force_stop, maxiters,
217
217
internalnorm, retcode, abstol, prob, DᵀD,
218
218
JᵀJ, λ, λ_factor,
219
219
damping_increase_factor,
220
220
damping_decrease_factor, h,
221
221
α_geodesic, b_uphill, min_damping_D,
222
222
v, a, tmp_vec, v_old,
223
223
norm_v_old, δ, loss_old, make_new_J,
224
- fu_tmp, mat_tmp)
224
+ fu_tmp, mat_tmp, stats )
225
225
end
226
226
end
227
227
@@ -301,14 +301,13 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarq
301
301
mat_tmp = zero (J)
302
302
303
303
return LevenbergMarquardtCache {iip} (f, alg, u, fu, p, uf, linsolve, J, du_tmp,
304
- jac_config,
305
- 1 , false , maxiters, internalnorm,
304
+ jac_config, false , maxiters, internalnorm,
306
305
ReturnCode. Default, abstol, prob, DᵀD, JᵀJ,
307
306
λ, λ_factor, damping_increase_factor,
308
307
damping_decrease_factor, h,
309
308
α_geodesic, b_uphill, min_damping_D,
310
309
v, a, tmp_vec, v_old, loss, δ, loss, make_new_J,
311
- fu_tmp, mat_tmp)
310
+ fu_tmp, mat_tmp, NLStats ( 1 , 0 , 0 , 0 , 0 ) )
312
311
end
313
312
function perform_step! (cache:: LevenbergMarquardtCache{true} )
314
313
@unpack fu, f, make_new_J = cache
@@ -321,6 +320,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
321
320
mul! (cache. JᵀJ, cache. J' , cache. J)
322
321
cache. DᵀD .= max .(cache. DᵀD, Diagonal (cache. JᵀJ))
323
322
cache. make_new_J = false
323
+ cache. stats. njacs += 1
324
324
end
325
325
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
326
326
@@ -344,13 +344,16 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
344
344
linu = _vec (cache. du_tmp), p = p, reltol = cache. abstol)
345
345
cache. linsolve = linres. cache
346
346
@. cache. a = - cache. du_tmp
347
+ cache. stats. nsolve += 2
348
+ cache. stats. nfactors += 2
347
349
348
350
# Require acceptable steps to satisfy the following condition.
349
351
norm_v = norm (v)
350
352
if (2 * norm (cache. a) / norm_v) < α_geodesic
351
353
@. cache. δ = v + cache. a / 2
352
354
@unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
353
355
f (cache. fu_tmp, u .+ δ, p)
356
+ cache. stats. nf += 1
354
357
loss = cache. internalnorm (cache. fu_tmp)
355
358
356
359
# Condition to accept uphill steps (evaluates to `loss ≤ loss_old` in iteration 1).
@@ -390,6 +393,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
390
393
cache. DᵀD .= max .(cache. DᵀD, Diagonal (cache. JᵀJ))
391
394
end
392
395
cache. make_new_J = false
396
+ cache. stats. njacs += 1
393
397
end
394
398
@unpack u, p, λ, JᵀJ, DᵀD, J = cache
395
399
@@ -399,13 +403,16 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
399
403
@unpack v, h, α_geodesic = cache
400
404
# Geodesic acceleration (step_size = v + a / 2).
401
405
cache. a = - J \ ((2 / h) .* ((f (u .+ h .* v, p) .- fu) ./ h .- J * v))
406
+ cache. stats. nsolve += 1
407
+ cache. stats. nfactors += 1
402
408
403
409
# Require acceptable steps to satisfy the following condition.
404
410
norm_v = norm (v)
405
411
if (2 * norm (cache. a) / norm_v) < α_geodesic
406
412
cache. δ = v .+ cache. a ./ 2
407
413
@unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
408
414
fu_new = f (u .+ δ, p)
415
+ cache. stats. nf += 1
409
416
loss = cache. internalnorm (fu_new)
410
417
411
418
# Condition to accept uphill steps (evaluates to `loss ≤ loss_old` in iteration 1).
@@ -431,17 +438,17 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
431
438
end
432
439
433
440
function SciMLBase. solve! (cache:: LevenbergMarquardtCache )
434
- while ! cache. force_stop && cache. iter < cache. maxiters
441
+ while ! cache. force_stop && cache. stats . nsteps < cache. maxiters
435
442
perform_step! (cache)
436
- cache. iter += 1
443
+ cache. stats . nsteps += 1
437
444
end
438
445
439
- if cache. iter == cache. maxiters
446
+ if cache. stats . nsteps == cache. maxiters
440
447
cache. retcode = ReturnCode. MaxIters
441
448
else
442
449
cache. retcode = ReturnCode. Success
443
450
end
444
451
445
452
SciMLBase. build_solution (cache. prob, cache. alg, cache. u, cache. fu;
446
- retcode = cache. retcode)
453
+ retcode = cache. retcode, stats = cache . stats )
447
454
end
0 commit comments