@@ -384,8 +384,7 @@ function retrospective_step!(cache::TrustRegionCache{iip}) where {iip}
384
384
__update_JᵀJ! (cache, J)
385
385
__update_Jᵀf! (cache, J)
386
386
387
- num = __trust_region_loss (cache, cache. fu) -
388
- __get_trust_region_loss (cache, cache. fu_cache)
387
+ num = __trust_region_loss (cache, cache. fu) - __trust_region_loss (cache, cache. fu_cache)
389
388
denom = dot (_vec (cache. du), _vec (cache. Jᵀf)) + __lr_mul (cache, cache. JᵀJ, cache. du) / 2
390
389
return num / denom
391
390
end
@@ -441,7 +440,7 @@ function trust_region_step!(cache::TrustRegionCache)
441
440
end
442
441
elseif radius_update_scheme === RadiusUpdateSchemes. Hei
443
442
@unpack shrink_threshold, p1, p2, p3, p4 = cache
444
- tr_new = __rfunc (r, shrink_threshold, p1, p3, p4, p2) * cache. internalnorm (du)
443
+ tr_new = __rfunc (r, shrink_threshold, p1, p3, p4, p2) * cache. internalnorm (cache . du)
445
444
if tr_new < cache. trust_r
446
445
cache. shrink_counter += 1
447
446
else
@@ -479,7 +478,7 @@ function trust_region_step!(cache::TrustRegionCache)
479
478
elseif radius_update_scheme === RadiusUpdateSchemes. Bastin
480
479
if r > cache. step_threshold
481
480
if retrospective_step! (cache) ≥ cache. expand_threshold
482
- cache. trust_r = max (cache. p1 * cache. internalnorm (du), cache. trust_r)
481
+ cache. trust_r = max (cache. p1 * cache. internalnorm (cache . du), cache. trust_r)
483
482
end
484
483
cache. shrink_counter = 0
485
484
else
0 commit comments