Skip to content

Commit 0abdc34

Browse files
Merge pull request #191 from yash2798/ys/bastin_new
Bastin's radius update scheme
2 parents 8d009b9 + a8f4c33 commit 0abdc34

File tree

3 files changed

+159
-5
lines changed

3 files changed

+159
-5
lines changed

src/trustRegion.jl

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
134134
trustType, suType, su2Type, tmpType}
135135
f::fType
136136
alg::algType
137+
u_prev::uType
137138
u::uType
139+
fu_prev::resType
138140
fu::resType
139141
p::pType
140142
uf::ufType
@@ -172,7 +174,8 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
172174
ϵ::floatType
173175
stats::NLStats
174176

175-
function TrustRegionCache{iip}(f::fType, alg::algType, u::uType, fu::resType, p::pType,
177+
function TrustRegionCache{iip}(f::fType, alg::algType, u_prev::uType, u::uType,
178+
fu_prev::resType, fu::resType, p::pType,
176179
uf::ufType, linsolve::L, J::jType, jac_config::JC,
177180
force_stop::Bool, maxiters::Int, internalnorm::INType,
178181
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
@@ -194,7 +197,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
194197
suType, su2Type, tmpType}
195198
new{iip, fType, algType, uType, resType, pType,
196199
INType, tolType, probType, ufType, L, jType, JC, floatType,
197-
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
200+
trustType, suType, su2Type, tmpType}(f, alg, u_prev, u, fu_prev, fu, p, uf, linsolve, J,
198201
jac_config, force_stop,
199202
maxiters, internalnorm, retcode,
200203
abstol, prob, radius_update_scheme,
@@ -246,6 +249,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
246249
else
247250
u = deepcopy(prob.u0)
248251
end
252+
u_prev = zero(u)
249253
f = prob.f
250254
p = prob.p
251255
if iip
@@ -254,6 +258,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
254258
else
255259
fu = f(u, p)
256260
end
261+
fu_prev = zero(fu)
257262

258263
loss = get_loss(fu)
259264
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))
@@ -325,9 +330,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
325330
p3 = convert(eltype(u), 12) # c6
326331
p4 = convert(eltype(u), 1.0e18) # M
327332
initial_trust_radius = convert(eltype(u), p1 * (norm(fu)^0.99))
333+
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
334+
step_threshold = convert(eltype(u), 0.05)
335+
shrink_threshold = convert(eltype(u), 0.05)
336+
expand_threshold = convert(eltype(u), 0.9)
337+
p1 = convert(eltype(u), 2.5) #alpha_1
338+
p2 = convert(eltype(u), 0.25) # alpha_2
339+
p3 = convert(eltype(u), 0) # not required
340+
p4 = convert(eltype(u), 0) # not required
341+
initial_trust_radius = convert(eltype(u), 1.0)
328342
end
329343

330-
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
344+
return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu, p, uf, linsolve, J,
345+
jac_config,
331346
false, maxiters, internalnorm,
332347
ReturnCode.Default, abstol, prob, radius_update_scheme,
333348
initial_trust_radius,
@@ -388,6 +403,30 @@ function perform_step!(cache::TrustRegionCache{false})
388403
return nothing
389404
end
390405

406+
# Retrospective ratio used by the Bastin radius update.
#
# Called after a step has been accepted: the Jacobian is re-evaluated through
# `jacobian!`/`jacobian` (presumably at the new iterate `cache.u` — confirm), the
# quadratic model terms `cache.H` and `cache.g` are rebuilt from it, and the function
# returns
#
#     -(get_loss(fu_prev) - get_loss(fu)) / (s' * g + s' * H * s / 2),   s = cache.step_size
#
# Side effects: overwrites `cache.H` and `cache.g` and increments `cache.stats.njacs`.
#
# The model uses the Gauss-Newton forms H = J'J and g = J'fu. This assumes `get_loss`
# is a sum-of-squares loss of the residual — TODO confirm. The previous `J * J` /
# `J * fu` only coincided with these for symmetric (or scalar) Jacobians.
# NOTE(review): `perform_step!` (not visible here) should build H and g the same way.
function retrospective_step!(cache::TrustRegionCache{true})
    @unpack J, fu_prev, fu = cache
    jacobian!(J, cache)
    mul!(cache.H, J', J)   # H = JᵀJ, written into the existing buffer
    mul!(cache.g, J', fu)  # g = Jᵀfu, written into the existing buffer
    cache.stats.njacs += 1
    @unpack H, g, step_size = cache

    return -(get_loss(fu_prev) - get_loss(fu)) /
           (step_size' * g + step_size' * H * step_size / 2)
end

# Out-of-place variant: same ratio, but `J`, `cache.H` and `cache.g` are rebound to
# freshly computed values instead of being updated in place. The new Jacobian is only
# held in the local `J`; `cache.J` is not touched here.
function retrospective_step!(cache::TrustRegionCache{false})
    @unpack fu_prev, fu, f = cache
    J = jacobian(cache, f)
    cache.H = J' * J   # H = JᵀJ
    cache.g = J' * fu  # g = Jᵀfu
    cache.stats.njacs += 1
    @unpack H, g, step_size = cache

    return -(get_loss(fu_prev) - get_loss(fu)) /
           (step_size' * g + step_size' * H * step_size / 2)
end
429+
391430
function trust_region_step!(cache::TrustRegionCache)
392431
@unpack fu_new, step_size, g, H, loss, max_trust_r, radius_update_scheme = cache
393432
cache.loss_new = get_loss(fu_new)
@@ -495,6 +534,23 @@ function trust_region_step!(cache::TrustRegionCache)
495534
cache.internalnorm(g) < cache.ϵ
496535
cache.force_stop = true
497536
end
537+
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
538+
if r > cache.step_threshold
539+
take_step!(cache)
540+
cache.loss = cache.loss_new
541+
cache.make_new_J = true
542+
if retrospective_step!(cache) >= cache.expand_threshold
543+
cache.trust_r = max(cache.p1 * cache.internalnorm(step_size), cache.trust_r)
544+
end
545+
546+
else
547+
cache.make_new_J = false
548+
cache.trust_r *= cache.p2
549+
cache.shrink_counter += 1
550+
end
551+
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
552+
cache.force_stop = true
553+
end
498554
end
499555
end
500556

@@ -526,12 +582,16 @@ function dogleg!(cache::TrustRegionCache)
526582
end
527583

528584
# Accept the trial point (in-place problems): the current iterate and residual are
# copied into the `*_prev` buffers, then the trial values `u_tmp` / `fu_new` are copied
# into `u` / `fu`. All four updates write into existing arrays. Returns `cache.fu`.
function take_step!(cache::TrustRegionCache{true})
    @unpack u_prev, u, u_tmp, fu_prev, fu, fu_new = cache
    u_prev .= u
    u .= u_tmp
    fu_prev .= fu
    fu .= fu_new
    return fu
end
532590

533591
# Accept the trial point (out-of-place problems): the cache fields are rebound rather
# than copied, so `u_prev`/`fu_prev` take over the old `u`/`fu` objects and `u`/`fu`
# take over `u_tmp`/`fu_new`.
function take_step!(cache::TrustRegionCache{false})
    cache.u_prev, cache.u = cache.u, cache.u_tmp
    cache.fu_prev, cache.fu = cache.fu, cache.fu_new
    return cache.fu_new
end
537597

test/basictests.jl

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ end
193193

194194
u0 = [1.0, 1.0]
195195
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei,
196-
RadiusUpdateSchemes.Yuan,
197-
RadiusUpdateSchemes.Fan]
196+
RadiusUpdateSchemes.Yuan, RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin]
198197

199198
for radius_update_scheme in radius_update_schemes
200199
sol = benchmark_immutable(ff, cu0, radius_update_scheme)
@@ -286,6 +285,18 @@ for p in 1.1:0.1:100.0
286285
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
287286
end
288287

288+
# Bastin radius update: solve for each `p` and compare the last component of the
# solution, and its derivative with respect to `p`, against sqrt(p) and 1 / (2 sqrt(p)).
g = function (p)
    probN = NonlinearProblem{false}(f, csu0, p)
    sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin),
        abstol = 1e-9)
    return sol.u[end]
end

for p in 1.1:0.1:100.0
    @test g(p) ≈ sqrt(p)
    @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end
299+
289300
# Scalar
290301
f, u0 = (u, p) -> u * u - p, 1.0
291302

@@ -344,6 +355,20 @@ for p in 1.1:0.1:100.0
344355
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
345356
end
346357

358+
# Bastin radius update, scalar problem: `u0` is converted to the type of `p` so the
# solve can be differentiated with ForwardDiff; value and derivative are compared
# against sqrt(p) and 1 / (2 sqrt(p)).
g = function (p)
    probN = NonlinearProblem{false}(f, oftype(p, u0), p)
    sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin),
        abstol = 1e-10)
    return sol.u
end

@test ForwardDiff.derivative(g, 3.0) ≈ 1 / (2 * sqrt(3.0))

for p in 1.1:0.1:100.0
    @test g(p) ≈ sqrt(p)
    @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end
371+
347372
f = (u, p) -> p[1] * u * u - p[2]
348373
t = (p) -> [sqrt(p[2] / p[1])]
349374
p = [0.9, 50.0]
@@ -379,6 +404,14 @@ end
379404
@test gnewton(p) [sqrt(p[2] / p[1])]
380405
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
381406

407+
# Bastin radius update with a vector of parameters: the scalar solution is wrapped in a
# one-element vector so that value and Jacobian can be compared against `t(p)`.
gnewton = function (p)
    probN = NonlinearProblem{false}(f, 0.5, p)
    sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin))
    return [sol.u]
end
@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)
414+
382415
# Iterator interface
383416
f = (u, p) -> u * u - p
384417
g = function (p_range)
@@ -432,6 +465,11 @@ probN = NonlinearProblem(f, u0)
432465
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Fan, autodiff = false)).u[end]
433466
sqrt(2.0)
434467

468+
# Bastin radius update on `probN`, with the default settings and with `autodiff = false`;
# the last component of the solution must match sqrt(2).
@test solve(probN,
    TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin)).u[end] ≈ sqrt(2.0)
@test solve(probN,
    TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin,
        autodiff = false)).u[end] ≈ sqrt(2.0)
472+
435473
for u0 in [1.0, [1, 1.0]]
436474
local f, probN, sol
437475
f = (u, p) -> u .* u .- 2.0
@@ -475,6 +513,17 @@ u = g(p)
475513
f(u, p)
476514
@test all(abs.(f(u, p)) .< 1e-10)
477515

516+
# Bastin radius update with an all-zero parameter vector: every component of the
# residual at the returned solution must be below 1e-10 in magnitude.
g = p -> begin
    probN = NonlinearProblem{false}(f, u0, p)
    alg = TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Bastin)
    return solve(probN, alg; abstol = 1e-10).u
end
p = zeros(7)
u = g(p)
f(u, p)
@test all(x -> abs(x) < 1e-10, f(u, p))
526+
478527
# Test kwargs in `TrustRegion`
479528
max_trust_radius = [10.0, 100.0, 1000.0]
480529
initial_trust_radius = [10.0, 1.0, 0.1]
@@ -542,6 +591,11 @@ for maxiters in maxiterations
542591
@test iip == oop
543592
end
544593

594+
# Bastin radius update: the in-place and out-of-place results returned by `iip_oop`
# must be equal for every iteration cap in `maxiterations`.
for maxiters in maxiterations
    pair = iip_oop(ff, ffiip, u0, RadiusUpdateSchemes.Bastin, maxiters)
    @test pair[1] == pair[2]
end
598+
545599
# --- LevenbergMarquardt tests ---
546600

547601
function benchmark_immutable(f, u0)

test/convergencetests.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using NonlinearSolve
2+
using StaticArrays
3+
using BenchmarkTools
4+
using Test
5+
6+
using SciMLNLSolve
7+
8+
###-----Trust Region tests-----###
9+
10+
# some simple functions #
11+
# Out-of-place residual: elementwise `u * u - p`, returned as a new value.
f_oop(u, p) = @. u * u - p
14+
15+
# In-place residual: writes elementwise `u * u - p` into `du` and returns `du`.
function f_iip(du, u, p)
    @. du = u * u - p
    return du
end
18+
19+
# Scalar residual `u * u - p`.
f_scalar(u, p) = u * u - p
22+
23+
# Shared inputs for the convergence checks in this file.
u0 = ones(2)
csu0 = 1.0
p = fill(2.0, 2)
radius_update_scheme = RadiusUpdateSchemes.Simple
tol = 1e-9
28+
29+
# Solve the out-of-place problem `f(u, p) = 0` with `TrustRegion` using the given
# radius update scheme, going through the `init` / `solve!` cache interface.
#
# Arguments:
#   f, u0, p             - residual function, initial guess and parameters; `u0` is
#                          converted to the type of `p` before the problem is built.
#   radius_update_scheme - forwarded to `TrustRegion`.
#   abstol (keyword)     - absolute tolerance forwarded to `init`; defaults to 1e-9.
#
# Returns a tuple of
#   1. `cache.internalnorm(cache.u_prev - cache.u)`, the norm of the difference between
#      the previous and the final iterate stored in the cache,
#   2. `cache.iter`,
#   3. the solution's return code.
function convergence_test_oop(f, u0, p, radius_update_scheme; abstol = 1e-9)
    prob = NonlinearProblem{false}(f, oftype(p, u0), p)
    alg = TrustRegion(radius_update_scheme = radius_update_scheme)
    cache = init(prob, alg; abstol = abstol)
    sol = solve!(cache)
    last_step = cache.internalnorm(cache.u_prev - cache.u)
    return last_step, cache.iter, sol.retcode
end
37+
38+
# The solve must report success, and the norm of the final step must not exceed `tol`.
residual, iterations, return_code = convergence_test_oop(f_oop, u0, p,
    radius_update_scheme)
@test return_code === ReturnCode.Success
@test residual ≤ tol

0 commit comments

Comments
 (0)