@@ -302,6 +302,14 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
302
302
p2 = convert (eltype (u), 1 / 6 ) # c5
303
303
p3 = convert (eltype (u), 6.0 ) # c6
304
304
p4 = convert (eltype (u), 0.0 )
305
+ elseif radius_update_scheme === RadiusUpdateSchemes. Fan
306
+ step_threshold = convert (eltype (u), 0.0001 )
307
+ shrink_threshold = convert (eltype (u), 0.25 )
308
+ expand_threshold = convert (eltype (u), 0.75 )
309
+ p1 = convert (eltype (u), 0.1 ) # μ
310
+ p2 = convert (eltype (u), 1 / 4 ) # c5
311
+ p3 = convert (eltype (u), 12 ) # c6
312
+ p4 = convert (eltype (u), 1.0e18 ) # M
305
313
end
306
314
307
315
return TrustRegionCache {iip} (f, alg, u, fu, p, uf, linsolve, J, jac_config,
@@ -435,8 +443,29 @@ function trust_region_step!(cache::TrustRegionCache)
435
443
if iszero (cache. fu) || cache. internalnorm (cache. fu) < cache. abstol || cache. internalnorm (g) < cache. ϵ
436
444
cache. force_stop = true
437
445
end
446
+ # Fan's update scheme
447
+ elseif radius_update_scheme === RadiusUpdateSchemes. Fan
448
+ if r < cache. shrink_threshold
449
+ cache. p1 *= cache. p2
450
+ cache. shrink_counter += 1
451
+ elseif r > cache. expand_threshold
452
+ cache. p1 = min (cache. p1* cache. p3, cache. p4)
453
+ cache. shrink_counter = 0
454
+ end
438
455
439
- # elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
456
+ if r > cache. step_threshold
457
+ take_step! (cache)
458
+ cache. loss = cache. loss_new
459
+ cache. make_new_J = true
460
+ else
461
+ cache. make_new_J = false
462
+ end
463
+
464
+ @unpack p1 = cache
465
+ cache. trust_r = p1 * (cache. internalnorm (cache. fu)^ 0.99 )
466
+ if iszero (cache. fu) || cache. internalnorm (cache. fu) < cache. abstol || cache. internalnorm (g) < cache. ϵ
467
+ cache. force_stop = true
468
+ end
440
469
end
441
470
end
442
471
0 commit comments