Skip to content

Commit 0d07b27

Browse files
Merge pull request #159 from yash2798/ys/tr_radius_schemes
Implementing some new Trust region radius update schemes
2 parents 60d3233 + b1f34dd commit 0d07b27

File tree

4 files changed

+125
-28
lines changed

4 files changed

+125
-28
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "1.5.0"
66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
9+
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
910
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1011
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/NonlinearSolve.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using ForwardDiff: Dual
1010
using LinearAlgebra
1111
using StaticArraysCore
1212
using RecursiveArrayTools
13+
import EnumX
1314
import ArrayInterface
1415
import LinearSolve
1516
using DiffEqBase
@@ -59,6 +60,8 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
5960
end
6061
end end
6162

63+
export RadiusUpdateSchemes
64+
6265
export NewtonRaphson, TrustRegion, LevenbergMarquardt
6366

6467
end # module

src/trustRegion.jl

Lines changed: 113 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,18 @@ for large-scale and numerically-difficult nonlinear systems.
8080
Currently, the linear solver and chunk size choice only applies to in-place defined
8181
`NonlinearProblem`s. That is expected to change in the future.
8282
"""
83+
# Strategies for updating the trust-region radius after each trial step:
#   Simple - conventional shrink/expand update (the default of `TrustRegion`)
#   Hei    - radius set to `rfunc(r, ...)` times the norm of the last step
#   Yuan   - rescales the cached factor `p1` based on the reduction ratio
#   Bastin - declared only; `trust_region_step!` has no branch for it yet
# NOTE(review): the docstring that closes directly above this macro call reads like the
# `TrustRegion` algorithm docstring; with the enum placed here it presumably binds to
# this definition instead of the struct -- confirm and relocate the enum if so.
EnumX.@enumx RadiusUpdateSchemes begin
    Simple = 0
    Hei = 1
    Yuan = 2
    Bastin = 3
end
89+
8390
struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR} <:
8491
AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
8592
linsolve::L
8693
precs::P
94+
radius_update_scheme::RadiusUpdateSchemes.T
8795
max_trust_radius::MTR
8896
initial_trust_radius::MTR
8997
step_threshold::MTR
@@ -98,6 +106,7 @@ function TrustRegion(; chunk_size = Val{0}(),
98106
autodiff = Val{true}(),
99107
standardtag = Val{true}(), concrete_jac = nothing,
100108
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
109+
radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple, #defaults to conventional radius update
101110
max_trust_radius::Real = 0 // 1,
102111
initial_trust_radius::Real = 0 // 1,
103112
step_threshold::Real = 1 // 10,
@@ -109,7 +118,7 @@ function TrustRegion(; chunk_size = Val{0}(),
109118
TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
110119
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
111120
_unwrap_val(concrete_jac), typeof(max_trust_radius)
112-
}(linsolve, precs, max_trust_radius,
121+
}(linsolve, precs, radius_update_scheme, max_trust_radius,
113122
initial_trust_radius,
114123
step_threshold,
115124
shrink_threshold,
@@ -138,6 +147,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
138147
retcode::SciMLBase.ReturnCode.T
139148
abstol::tolType
140149
prob::probType
150+
radius_update_scheme::RadiusUpdateSchemes.T
141151
trust_r::trustType
142152
max_trust_r::trustType
143153
step_threshold::suType
@@ -155,20 +165,26 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
155165
fu_new::resType
156166
make_new_J::Bool
157167
r::floatType
168+
p1::floatType
169+
p2::floatType
170+
p3::floatType
171+
p4::floatType
172+
ϵ::floatType
158173

159174
function TrustRegionCache{iip}(f::fType, alg::algType, u::uType, fu::resType, p::pType,
160175
uf::ufType, linsolve::L, J::jType,
161176
jac_config::JC, iter::Int,
162177
force_stop::Bool, maxiters::Int, internalnorm::INType,
163178
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
164-
prob::probType, trust_r::trustType,
179+
prob::probType, radius_update_scheme::RadiusUpdateSchemes.T, trust_r::trustType,
165180
max_trust_r::trustType, step_threshold::suType,
166181
shrink_threshold::trustType, expand_threshold::trustType,
167182
shrink_factor::trustType, expand_factor::trustType,
168183
loss::floatType, loss_new::floatType, H::jType,
169184
g::resType, shrink_counter::Int, step_size::su2Type,
170185
u_tmp::tmpType, fu_new::resType, make_new_J::Bool,
171-
r::floatType) where {iip, fType, algType, uType,
186+
r::floatType, p1::floatType, p2::floatType, p3::floatType,
187+
p4::floatType, ϵ::floatType) where {iip, fType, algType, uType,
172188
resType, pType, INType,
173189
tolType, probType, ufType, L,
174190
jType, JC, floatType, trustType,
@@ -178,13 +194,13 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
178194
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
179195
jac_config, iter, force_stop,
180196
maxiters, internalnorm, retcode,
181-
abstol, prob, trust_r, max_trust_r,
197+
abstol, prob, radius_update_scheme, trust_r, max_trust_r,
182198
step_threshold, shrink_threshold,
183199
expand_threshold, shrink_factor,
184200
expand_factor, loss,
185201
loss_new, H, g, shrink_counter,
186202
step_size, u_tmp, fu_new,
187-
make_new_J, r)
203+
make_new_J, r, p1, p2, p3, p4, ϵ)
188204
end
189205
end
190206

@@ -238,6 +254,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
238254
loss = get_loss(fu)
239255
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))
240256

257+
radius_update_scheme = alg.radius_update_scheme
241258
max_trust_radius = convert(eltype(u), alg.max_trust_radius)
242259
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
243260
step_threshold = convert(eltype(u), alg.step_threshold)
@@ -262,13 +279,37 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
262279
make_new_J = true
263280
r = loss
264281

282+
# Parameters for the Schemes
283+
p1 = convert(eltype(u), 0.0)
284+
p2 = convert(eltype(u), 0.0)
285+
p3 = convert(eltype(u), 0.0)
286+
p4 = convert(eltype(u), 0.0)
287+
ϵ = convert(eltype(u), 1.0e-8)
288+
if radius_update_scheme === RadiusUpdateSchemes.Hei
289+
step_threshold = convert(eltype(u), 0.0)
290+
shrink_threshold = convert(eltype(u), 0.25)
291+
expand_threshold = convert(eltype(u), 0.25)
292+
p1 = convert(eltype(u), 5.0) # M
293+
p2 = convert(eltype(u), 0.1) # β
294+
p3 = convert(eltype(u), 0.15) # γ1
295+
p4 = convert(eltype(u), 0.15) # γ2
296+
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
297+
step_threshold = convert(eltype(u), 0.0001)
298+
shrink_threshold = convert(eltype(u), 0.25)
299+
expand_threshold = convert(eltype(u), 0.25)
300+
p1 = convert(eltype(u), 2.0) # μ
301+
p2 = convert(eltype(u), 1/6) # c5
302+
p3 = convert(eltype(u), 6.0) # c6
303+
p4 = convert(eltype(u), 0.0)
304+
end
305+
265306
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
266307
1, false, maxiters, internalnorm,
267-
ReturnCode.Default, abstol, prob, initial_trust_radius,
308+
ReturnCode.Default, abstol, prob, radius_update_scheme, initial_trust_radius,
268309
max_trust_radius, step_threshold, shrink_threshold,
269310
expand_threshold, shrink_factor, expand_factor, loss,
270311
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
271-
make_new_J, r)
312+
make_new_J, r, p1, p2, p3, p4, ϵ)
272313
end
273314

274315
function perform_step!(cache::TrustRegionCache{true})
@@ -289,7 +330,6 @@ function perform_step!(cache::TrustRegionCache{true})
289330
# Compute the potentially new u
290331
cache.u_tmp .= u .+ cache.step_size
291332
f(cache.fu_new, cache.u_tmp, p)
292-
293333
trust_region_step!(cache)
294334
return nothing
295335
end
@@ -311,43 +351,88 @@ function perform_step!(cache::TrustRegionCache{false})
311351
# Compute the potentially new u
312352
cache.u_tmp = u .+ cache.step_size
313353
cache.fu_new = f(cache.u_tmp, p)
314-
315354
trust_region_step!(cache)
316355
return nothing
317356
end
318357

319358
"""
    trust_region_step!(cache::TrustRegionCache)

Evaluate the trial step currently stored in `cache` and update the solver state in place.

The loss at the trial point is written to `cache.loss_new` and the ratio of actual to
predicted reduction to `cache.r`. Depending on `cache.radius_update_scheme` the step is
then accepted via `take_step!` (when `r > cache.step_threshold`) or rejected, and
`cache.trust_r`, `cache.shrink_counter` and `cache.make_new_J` are updated. Finally
`cache.force_stop` is set once the scheme's termination test passes.

`cache.shrink_counter` is incremented on every iteration that shrinks the radius and
reset to zero on every iteration that does not, for all implemented schemes.
"""
function trust_region_step!(cache::TrustRegionCache)
    @unpack fu_new, step_size, g, H, loss, max_trust_r, radius_update_scheme = cache
    cache.loss_new = get_loss(fu_new)

    # Compute the ratio of the actual reduction to the predicted reduction.
    cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2)
    @unpack r = cache

    if radius_update_scheme === RadiusUpdateSchemes.Simple
        # Shrink the trust region when the model predicted the reduction poorly.
        if r < cache.shrink_threshold
            cache.trust_r *= cache.shrink_factor
            cache.shrink_counter += 1
        else
            cache.shrink_counter = 0
        end

        if r > cache.step_threshold
            take_step!(cache)
            cache.loss = cache.loss_new

            # Expand the trust region (capped at `max_trust_r`) after a very good step.
            if r > cache.expand_threshold
                cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
            end

            cache.make_new_J = true
        else
            # No step was taken, so J is still valid; retry with the smaller trust_r.
            cache.make_new_J = false
        end

        if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
            cache.force_stop = true
        end

    elseif radius_update_scheme === RadiusUpdateSchemes.Hei
        if r > cache.step_threshold
            take_step!(cache)
            cache.loss = cache.loss_new
            cache.make_new_J = true
        else
            cache.make_new_J = false
        end

        # Hei's radius update: the new radius is rfunc(r) times the norm of the step.
        # Argument order of `rfunc` is (r, c2, M, γ1, γ2, β) with M = p1, β = p2,
        # γ1 = p3, γ2 = p4.
        @unpack shrink_threshold, p1, p2, p3, p4 = cache
        new_trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) *
                      cache.internalnorm(step_size)
        if new_trust_r < cache.trust_r
            cache.shrink_counter += 1
        else
            # Reset on a non-shrinking iteration, as the Simple scheme does.
            cache.shrink_counter = 0
        end
        cache.trust_r = new_trust_r

        if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol ||
           cache.internalnorm(g) < cache.ϵ
            cache.force_stop = true
        end

    elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
        if r < cache.shrink_threshold
            cache.p1 = cache.p2 * cache.p1
            cache.shrink_counter += 1
        else
            # Reset on a non-shrinking iteration, as the Simple scheme does.
            cache.shrink_counter = 0
            if r >= cache.expand_threshold &&
               cache.internalnorm(step_size) > cache.trust_r / 2
                cache.p1 = cache.p3 * cache.p1
            end
        end
        # TODO: `cache.trust_r` is not updated by this scheme yet. The intended update
        # scales the norm of the gradient at the new (k+1)-th point by `cache.p1`; it
        # was left disabled pending a way to compute it without allocating.

        if r > cache.step_threshold
            take_step!(cache)
            cache.loss = cache.loss_new
            cache.make_new_J = true
        else
            cache.make_new_J = false
        end

        if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol ||
           cache.internalnorm(g) < cache.ϵ
            cache.force_stop = true
        end

        # NOTE(review): `RadiusUpdateSchemes.Bastin` has no branch here, so selecting it
        # neither takes a step nor updates the radius -- confirm whether it should error.
    end
end
353438

src/utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,11 @@ end
128128
# Loss used by the solvers: half of the squared norm of the residual `fu`.
get_loss(fu) = norm(fu)^2 / 2
131+
132+
"""
    rfunc(r, c2, M, γ1, γ2, β)

Evaluate the R-function of the adaptive trust-region method at the reduction ratio `r`.

The returned factor multiplies the norm of the last step to give the next radius.
It is non-decreasing in `r` and satisfies:

  - it tends to `β` as `r` goes to `-Inf`,
  - it approaches `1 - γ1` as `r` approaches `c2` from below,
  - it equals `1 + γ2` at `r == c2`,
  - it tends to `M` as `r` goes to `+Inf`.

All arguments must share the same real type `R`.
"""
function rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real}
    if r >= c2
        # Only the arctan term carries the 2/π factor; dividing the whole sum by π
        # would give (1 + γ2)/π at r == c2 and a limit different from M.
        return 2 * (M - 1 - γ2) * atan(r - c2) / π + (1 + γ2)
    else
        return (1 - γ1 - β) * (exp(r - c2) + β / (1 - γ1 - β))
    end
end

0 commit comments

Comments
 (0)