Skip to content

Commit c1267c8

Browse files
Merge pull request #178 from yash2798/ys/tr_fan
Trust Region - Fan's method
2 parents c698d16 + 05a3e99 commit c1267c8

File tree

2 files changed

+98
-4
lines changed

2 files changed

+98
-4
lines changed

src/trustRegion.jl

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
294294
p2 = convert(eltype(u), 0.1) # β
295295
p3 = convert(eltype(u), 0.15) # γ1
296296
p4 = convert(eltype(u), 0.15) # γ2
297+
initial_trust_radius = convert(eltype(u), 1.0)
297298
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
298299
step_threshold = convert(eltype(u), 0.0001)
299300
shrink_threshold = convert(eltype(u), 0.25)
@@ -302,6 +303,25 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
302303
p2 = convert(eltype(u), 1/6) # c5
303304
p3 = convert(eltype(u), 6.0) # c6
304305
p4 = convert(eltype(u), 0.0)
306+
if iip
307+
auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu)
308+
else
309+
if isa(u, Number)
310+
g = ForwardDiff.derivative(x -> f(x, p), u)
311+
else
312+
g = auto_jacvec(x -> f(x, p), u, fu)
313+
end
314+
end
315+
initial_trust_radius = convert(eltype(u), p1 * norm(g))
316+
elseif radius_update_scheme === RadiusUpdateSchemes.Fan
317+
step_threshold = convert(eltype(u), 0.0001)
318+
shrink_threshold = convert(eltype(u), 0.25)
319+
expand_threshold = convert(eltype(u), 0.75)
320+
p1 = convert(eltype(u), 0.1) # μ
321+
p2 = convert(eltype(u), 1/4) # c5
322+
p3 = convert(eltype(u), 12) # c6
323+
p4 = convert(eltype(u), 1.0e18) # M
324+
initial_trust_radius = convert(eltype(u), p1 * (norm(fu)^0.99))
305325
end
306326

307327
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
@@ -435,8 +455,29 @@ function trust_region_step!(cache::TrustRegionCache)
435455
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ
436456
cache.force_stop = true
437457
end
458+
#Fan's update scheme
459+
elseif radius_update_scheme === RadiusUpdateSchemes.Fan
460+
if r < cache.shrink_threshold
461+
cache.p1 *= cache.p2
462+
cache.shrink_counter += 1
463+
elseif r > cache.expand_threshold
464+
cache.p1 = min(cache.p1*cache.p3, cache.p4)
465+
cache.shrink_counter = 0
466+
end
438467

439-
#elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
468+
if r > cache.step_threshold
469+
take_step!(cache)
470+
cache.loss = cache.loss_new
471+
cache.make_new_J = true
472+
else
473+
cache.make_new_J = false
474+
end
475+
476+
@unpack p1 = cache
477+
cache.trust_r = p1 * (cache.internalnorm(cache.fu)^0.99)
478+
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ
479+
cache.force_stop = true
480+
end
440481
end
441482
end
442483

@@ -490,7 +531,7 @@ function jvp!(cache::TrustRegionCache{true})
490531
if isa(u, Number)
491532
return value_derivative(x -> f(x, p), u)
492533
end
493-
return auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu)
534+
auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu)
494535
g
495536
end
496537

test/basictests.jl

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ function sf(u, p=nothing)
189189
end
190190

191191
u0 = [1.0, 1.0]
192-
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan]
192+
radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan,
193+
RadiusUpdateSchemes.Fan]
193194

194195
for radius_update_scheme in radius_update_schemes
195196
sol = benchmark_immutable(ff, cu0, radius_update_scheme)
@@ -255,7 +256,6 @@ for p in 1.1:0.1:100.0
255256
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
256257
end
257258

258-
## FAIL BECAUSE JVP CANNOT ACCEPT PARAMETERS IN FUNCTIONS
259259
g = function (p)
260260
probN = NonlinearProblem{false}(f, csu0, p)
261261
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan), abstol = 1e-9)
@@ -267,6 +267,17 @@ for p in 1.1:0.1:100.0
267267
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
268268
end
269269

270+
g = function (p)
271+
probN = NonlinearProblem{false}(f, csu0, p)
272+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-9)
273+
return sol.u[end]
274+
end
275+
276+
for p in 1.1:0.1:100.0
277+
@test g(p) sqrt(p)
278+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
279+
end
280+
270281
# Scalar
271282
f, u0 = (u, p) -> u * u - p, 1.0
272283

@@ -309,6 +320,19 @@ for p in 1.1:0.1:100.0
309320
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
310321
end
311322

323+
g = function (p)
324+
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
325+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-10)
326+
return sol.u
327+
end
328+
329+
@test ForwardDiff.derivative(g, 3.0) 1 / (2 * sqrt(3.0))
330+
331+
for p in 1.1:0.1:100.0
332+
@test g(p) sqrt(p)
333+
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
334+
end
335+
312336
f = (u, p) -> p[1] * u * u - p[2]
313337
t = (p) -> [sqrt(p[2] / p[1])]
314338
p = [0.9, 50.0]
@@ -328,6 +352,22 @@ end
328352
@test gnewton(p) [sqrt(p[2] / p[1])]
329353
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
330354

355+
gnewton = function (p)
356+
probN = NonlinearProblem{false}(f, 0.5, p)
357+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan))
358+
return [sol.u]
359+
end
360+
@test gnewton(p) [sqrt(p[2] / p[1])]
361+
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
362+
363+
gnewton = function (p)
364+
probN = NonlinearProblem{false}(f, 0.5, p)
365+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan))
366+
return [sol.u]
367+
end
368+
@test gnewton(p) [sqrt(p[2] / p[1])]
369+
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
370+
331371
# Iterator interface
332372
f = (u, p) -> u * u - p
333373
g = function (p_range)
@@ -372,6 +412,9 @@ probN = NonlinearProblem(f, u0)
372412
@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Yuan)).u[end] sqrt(2.0)
373413
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Yuan, autodiff = false)).u[end] sqrt(2.0)
374414

415+
@test solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan)).u[end] sqrt(2.0)
416+
@test solve(probN, TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Fan, autodiff = false)).u[end] sqrt(2.0)
417+
375418
for u0 in [1.0, [1, 1.0]]
376419
local f, probN, sol
377420
f = (u, p) -> u .* u .- 2.0
@@ -404,6 +447,16 @@ u = g(p)
404447
f(u, p)
405448
@test all(abs.(f(u, p)) .< 1e-10)
406449

450+
g = function (p)
451+
probN = NonlinearProblem{false}(f, u0, p)
452+
sol = solve(probN, TrustRegion(radius_update_scheme = RadiusUpdateSchemes.Fan), abstol = 1e-10)
453+
return sol.u
454+
end
455+
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
456+
u = g(p)
457+
f(u, p)
458+
@test all(abs.(f(u, p)) .< 1e-10)
459+
407460
# Test kwars in `TrustRegion`
408461
max_trust_radius = [10.0, 100.0, 1000.0]
409462
initial_trust_radius = [10.0, 1.0, 0.1]

0 commit comments

Comments
 (0)