@@ -134,7 +134,9 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
134
134
trustType, suType, su2Type, tmpType}
135
135
f:: fType
136
136
alg:: algType
137
+ u_prev:: uType
137
138
u:: uType
139
+ fu_prev:: resType
138
140
fu:: resType
139
141
p:: pType
140
142
uf:: ufType
@@ -172,7 +174,8 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
172
174
ϵ:: floatType
173
175
stats:: NLStats
174
176
175
- function TrustRegionCache {iip} (f:: fType , alg:: algType , u:: uType , fu:: resType , p:: pType ,
177
+ function TrustRegionCache {iip} (f:: fType , alg:: algType , u_prev:: uType , u:: uType ,
178
+ fu_prev:: resType , fu:: resType , p:: pType ,
176
179
uf:: ufType , linsolve:: L , J:: jType , jac_config:: JC ,
177
180
force_stop:: Bool , maxiters:: Int , internalnorm:: INType ,
178
181
retcode:: SciMLBase.ReturnCode.T , abstol:: tolType ,
@@ -194,7 +197,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
194
197
suType, su2Type, tmpType}
195
198
new{iip, fType, algType, uType, resType, pType,
196
199
INType, tolType, probType, ufType, L, jType, JC, floatType,
197
- trustType, suType, su2Type, tmpType}(f, alg, u , fu, p, uf, linsolve, J,
200
+ trustType, suType, su2Type, tmpType}(f, alg, u_prev, u, fu_prev , fu, p, uf, linsolve, J,
198
201
jac_config, force_stop,
199
202
maxiters, internalnorm, retcode,
200
203
abstol, prob, radius_update_scheme,
@@ -246,6 +249,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
246
249
else
247
250
u = deepcopy (prob. u0)
248
251
end
252
+ u_prev = zero (u)
249
253
f = prob. f
250
254
p = prob. p
251
255
if iip
@@ -254,6 +258,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
254
258
else
255
259
fu = f (u, p)
256
260
end
261
+ fu_prev = zero (fu)
257
262
258
263
loss = get_loss (fu)
259
264
uf, linsolve, J, u_tmp, jac_config = jacobian_caches (alg, f, u, p, Val (iip))
@@ -325,9 +330,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
325
330
p3 = convert (eltype (u), 12 ) # c6
326
331
p4 = convert (eltype (u), 1.0e18 ) # M
327
332
initial_trust_radius = convert (eltype (u), p1 * (norm (fu)^ 0.99 ))
333
+ elseif radius_update_scheme === RadiusUpdateSchemes. Bastin
334
+ step_threshold = convert (eltype (u), 0.05 )
335
+ shrink_threshold = convert (eltype (u), 0.05 )
336
+ expand_threshold = convert (eltype (u), 0.9 )
337
+ p1 = convert (eltype (u), 2.5 ) # alpha_1
338
+ p2 = convert (eltype (u), 0.25 ) # alpha_2
339
+ p3 = convert (eltype (u), 0 ) # not required
340
+ p4 = convert (eltype (u), 0 ) # not required
341
+ initial_trust_radius = convert (eltype (u), 1.0 )
328
342
end
329
343
330
- return TrustRegionCache {iip} (f, alg, u, fu, p, uf, linsolve, J, jac_config,
344
+ return TrustRegionCache {iip} (f, alg, u_prev, u, fu_prev, fu, p, uf, linsolve, J,
345
+ jac_config,
331
346
false , maxiters, internalnorm,
332
347
ReturnCode. Default, abstol, prob, radius_update_scheme,
333
348
initial_trust_radius,
@@ -388,6 +403,30 @@ function perform_step!(cache::TrustRegionCache{false})
388
403
return nothing
389
404
end
390
405
406
+ function retrospective_step! (cache:: TrustRegionCache{true} )
407
+ @unpack J, fu_prev, fu, u_prev, u = cache
408
+ jacobian! (J, cache)
409
+ mul! (cache. H, J, J)
410
+ mul! (cache. g, J, fu)
411
+ cache. stats. njacs += 1
412
+ @unpack H, g, step_size = cache
413
+
414
+ return - (get_loss (fu_prev) - get_loss (fu)) /
415
+ (step_size' * g + step_size' * H * step_size / 2 )
416
+ end
417
+
418
+ function retrospective_step! (cache:: TrustRegionCache{false} )
419
+ @unpack J, fu_prev, fu, u_prev, u, f = cache
420
+ J = jacobian (cache, f)
421
+ cache. H = J * J
422
+ cache. g = J * fu
423
+ cache. stats. njacs += 1
424
+ @unpack H, g, step_size = cache
425
+
426
+ return - (get_loss (fu_prev) - get_loss (fu)) /
427
+ (step_size' * g + step_size' * H * step_size / 2 )
428
+ end
429
+
391
430
function trust_region_step! (cache:: TrustRegionCache )
392
431
@unpack fu_new, step_size, g, H, loss, max_trust_r, radius_update_scheme = cache
393
432
cache. loss_new = get_loss (fu_new)
@@ -495,6 +534,23 @@ function trust_region_step!(cache::TrustRegionCache)
495
534
cache. internalnorm (g) < cache. ϵ
496
535
cache. force_stop = true
497
536
end
537
+ elseif radius_update_scheme === RadiusUpdateSchemes. Bastin
538
+ if r > cache. step_threshold
539
+ take_step! (cache)
540
+ cache. loss = cache. loss_new
541
+ cache. make_new_J = true
542
+ if retrospective_step! (cache) >= cache. expand_threshold
543
+ cache. trust_r = max (cache. p1 * cache. internalnorm (step_size), cache. trust_r)
544
+ end
545
+
546
+ else
547
+ cache. make_new_J = false
548
+ cache. trust_r *= cache. p2
549
+ cache. shrink_counter += 1
550
+ end
551
+ if iszero (cache. fu) || cache. internalnorm (cache. fu) < cache. abstol
552
+ cache. force_stop = true
553
+ end
498
554
end
499
555
end
500
556
@@ -526,12 +582,16 @@ function dogleg!(cache::TrustRegionCache)
526
582
end
527
583
528
584
function take_step! (cache:: TrustRegionCache{true} )
585
+ cache. u_prev .= cache. u
529
586
cache. u .= cache. u_tmp
587
+ cache. fu_prev .= cache. fu
530
588
cache. fu .= cache. fu_new
531
589
end
532
590
533
591
function take_step! (cache:: TrustRegionCache{false} )
592
+ cache. u_prev = cache. u
534
593
cache. u = cache. u_tmp
594
+ cache. fu_prev = cache. fu
535
595
cache. fu = cache. fu_new
536
596
end
537
597
0 commit comments