@@ -25,6 +25,13 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
25
25
"""
26
26
Simple
27
27
28
+ """
29
+ `RadiusUpdateSchemes.NLsolve`
30
+
31
+ The same updating rule as in NLsolve's trust region implementation
32
+ """
33
+ NLsolve
34
+
28
35
"""
29
36
`RadiusUpdateSchemes.Hei`
30
37
@@ -244,7 +251,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
244
251
p3 = convert (eltype (u), 0.0 )
245
252
p4 = convert (eltype (u), 0.0 )
246
253
ϵ = convert (eltype (u), 1.0e-8 )
247
- if radius_update_scheme === RadiusUpdateSchemes. Hei
254
+ if radius_update_scheme === RadiusUpdateSchemes. NLsolve
255
+ p1 = convert (eltype (u), 0.5 )
256
+ elseif radius_update_scheme === RadiusUpdateSchemes. Hei
248
257
step_threshold = convert (eltype (u), 0.0 )
249
258
shrink_threshold = convert (eltype (u), 0.25 )
250
259
expand_threshold = convert (eltype (u), 0.25 )
@@ -310,8 +319,9 @@ function perform_step!(cache::TrustRegionCache{true})
310
319
cache. stats. njacs += 1
311
320
end
312
321
313
- linres = dolinsolve (alg. precs, linsolve; A = cache. H, b = _vec (cache. g),
314
- linu = _vec (u_tmp), p, reltol = cache. abstol)
322
+ linres = dolinsolve (alg. precs, linsolve, A = J, b = _vec (fu), # cache.H, b = _vec(cache.g),
323
+ linu = _vec (u_tmp),
324
+ p = p, reltol = cache. abstol)
315
325
cache. linsolve = linres. cache
316
326
cache. u_tmp .= - 1 .* u_tmp
317
327
dogleg! (cache)
@@ -374,7 +384,7 @@ function trust_region_step!(cache::TrustRegionCache)
374
384
375
385
# Compute the ratio of the actual reduction to the predicted reduction.
376
386
cache. r = - (loss - cache. loss_new) / (dot (step_size, g) + dot (step_size, H, step_size) / 2 )
377
- @unpack r = cache
387
+ @unpack r = cache
378
388
379
389
if radius_update_scheme === RadiusUpdateSchemes. Simple
380
390
# Update the trust region radius.
@@ -403,6 +413,30 @@ function trust_region_step!(cache::TrustRegionCache)
403
413
cache. force_stop = true
404
414
end
405
415
416
+ elseif radius_update_scheme === RadiusUpdateSchemes. NLsolve
417
+ # accept/reject decision
418
+ if r > cache. step_threshold # accept
419
+ take_step! (cache)
420
+ cache. loss = cache. loss_new
421
+ cache. make_new_J = true
422
+ else # reject
423
+ cache. make_new_J = false
424
+ end
425
+
426
+ # trust region update
427
+ if r < cache. shrink_threshold # default 1 // 10
428
+ cache. trust_r *= cache. shrink_factor # default 1 // 2
429
+ elseif r >= cache. expand_threshold # default 9 // 10
430
+ cache. trust_r = cache. expand_factor * norm (cache. step_size) # default 2
431
+ elseif r >= cache. p1 # default 1 // 2
432
+ cache. trust_r = max (cache. trust_r, cache. expand_factor * norm (cache. step_size))
433
+ end
434
+
435
+ # convergence test
436
+ if iszero (cache. fu) || cache. internalnorm (cache. fu) < cache. abstol
437
+ cache. force_stop = true
438
+ end
439
+
406
440
elseif radius_update_scheme === RadiusUpdateSchemes. Hei
407
441
if r > cache. step_threshold
408
442
take_step! (cache)
0 commit comments