Skip to content

Commit 146dec9

Browse files
committed
add NLsolve trust region updating scheme and change GN step to -J\fu to avoid growing ill-conditioning
1 parent 6f3556e commit 146dec9

File tree

1 file changed

+38
-4
lines changed

1 file changed

+38
-4
lines changed

src/trustRegion.jl

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
2525
"""
2626
Simple
2727

28+
"""
29+
`RadiusUpdateSchemes.NLsolve`
30+
31+
The same updating rule as in NLsolve's trust region implementation
32+
"""
33+
NLsolve
34+
2835
"""
2936
`RadiusUpdateSchemes.Hei`
3037
@@ -244,7 +251,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
244251
p3 = convert(eltype(u), 0.0)
245252
p4 = convert(eltype(u), 0.0)
246253
ϵ = convert(eltype(u), 1.0e-8)
247-
if radius_update_scheme === RadiusUpdateSchemes.Hei
254+
if radius_update_scheme === RadiusUpdateSchemes.NLsolve
255+
p1 = convert(eltype(u), 0.5)
256+
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
248257
step_threshold = convert(eltype(u), 0.0)
249258
shrink_threshold = convert(eltype(u), 0.25)
250259
expand_threshold = convert(eltype(u), 0.25)
@@ -310,8 +319,9 @@ function perform_step!(cache::TrustRegionCache{true})
310319
cache.stats.njacs += 1
311320
end
312321

313-
linres = dolinsolve(alg.precs, linsolve; A = cache.H, b = _vec(cache.g),
314-
linu = _vec(u_tmp), p, reltol = cache.abstol)
322+
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu), # cache.H, b = _vec(cache.g),
323+
linu = _vec(u_tmp),
324+
p = p, reltol = cache.abstol)
315325
cache.linsolve = linres.cache
316326
cache.u_tmp .= -1 .* u_tmp
317327
dogleg!(cache)
@@ -374,7 +384,7 @@ function trust_region_step!(cache::TrustRegionCache)
374384

375385
# Compute the ratio of the actual reduction to the predicted reduction.
376386
cache.r = -(loss - cache.loss_new) / (dot(step_size, g) + dot(step_size, H, step_size) / 2)
377-
@unpack r = cache
387+
@unpack r = cache
378388

379389
if radius_update_scheme === RadiusUpdateSchemes.Simple
380390
# Update the trust region radius.
@@ -403,6 +413,30 @@ function trust_region_step!(cache::TrustRegionCache)
403413
cache.force_stop = true
404414
end
405415

416+
elseif radius_update_scheme === RadiusUpdateSchemes.NLsolve
417+
# accept/reject decision
418+
if r > cache.step_threshold # accept
419+
take_step!(cache)
420+
cache.loss = cache.loss_new
421+
cache.make_new_J = true
422+
else # reject
423+
cache.make_new_J = false
424+
end
425+
426+
# trust region update
427+
if r < cache.shrink_threshold # default 1 // 10
428+
cache.trust_r *= cache.shrink_factor # default 1 // 2
429+
elseif r >= cache.expand_threshold # default 9 // 10
430+
cache.trust_r = cache.expand_factor * norm(cache.step_size) # default 2
431+
elseif r >= cache.p1 # default 1 // 2
432+
cache.trust_r = max(cache.trust_r, cache.expand_factor * norm(cache.step_size))
433+
end
434+
435+
# convergence test
436+
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
437+
cache.force_stop = true
438+
end
439+
406440
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
407441
if r > cache.step_threshold
408442
take_step!(cache)

0 commit comments

Comments
 (0)