Skip to content

Commit 4aa4813

Browse files
committed
added hei's and yuan's schemes
1 parent 76185a8 commit 4aa4813

File tree

1 file changed

+56
-6
lines changed

1 file changed

+56
-6
lines changed

src/trustRegion.jl

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,14 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
165165
fu_new::resType
166166
make_new_J::Bool
167167
r::floatType
168+
p1::floatType
169+
p2::floatType
170+
p3::floatType
171+
p4::floatType
172+
ϵ::floatType
173+
# p5::floatType
174+
# p6::floatType
175+
# p7::floatType
168176

169177
function TrustRegionCache{iip}(f::fType, alg::algType, u::uType, fu::resType, p::pType,
170178
uf::ufType, linsolve::L, J::jType,
@@ -178,7 +186,8 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
178186
loss::floatType, loss_new::floatType, H::jType,
179187
g::resType, shrink_counter::Int, step_size::su2Type,
180188
u_tmp::tmpType, fu_new::resType, make_new_J::Bool,
181-
r::floatType) where {iip, fType, algType, uType,
189+
r::floatType, p1::floatType, p2::floatType, p3::floatType,
190+
p4::floatType, ϵ::floatType) where {iip, fType, algType, uType,
182191
resType, pType, INType,
183192
tolType, probType, ufType, L,
184193
jType, JC, floatType, trustType,
@@ -194,7 +203,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
194203
expand_factor, loss,
195204
loss_new, H, g, shrink_counter,
196205
step_size, u_tmp, fu_new,
197-
make_new_J, r)
206+
make_new_J, r, p1, p2, p3, p4, ϵ)
198207
end
199208
end
200209

@@ -273,13 +282,33 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
273282
make_new_J = true
274283
r = loss
275284

285+
# Parameters for the Schemes
286+
ϵ = 1e-8
287+
if radius_update_scheme === RadiusUpdateSchemes.Hei
288+
step_threshold = 0
289+
shrink_threshold = 0.25
290+
expand_threshold = 0.25
291+
p1 = 5.0 # M
292+
p2 = 0.1 # β
293+
p3 = 0.15 # γ1
294+
p4 = 0.15 # γ2
295+
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
296+
step_threshold = 0.0001
297+
shrink_threshold = 0.25
298+
expand_threshold = 0.25
299+
p1 = 2.0 # μ
300+
p2 = 1/6 # c5
301+
p3 = 6 # c6
302+
p4 = 0
303+
end
304+
276305
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
277306
1, false, maxiters, internalnorm,
278307
ReturnCode.Default, abstol, prob, radius_update_scheme, initial_trust_radius,
279308
max_trust_radius, step_threshold, shrink_threshold,
280309
expand_threshold, shrink_factor, expand_factor, loss,
281310
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
282-
make_new_J, r)
311+
make_new_J, r, p1, p2, p3, p4, ϵ)
283312
end
284313

285314
function perform_step!(cache::TrustRegionCache{true})
@@ -365,25 +394,46 @@ function trust_region_step!(cache::TrustRegionCache)
365394
end
366395

367396
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
368-
if r > c1 # parameters to be defined
397+
if r > cache.step_threshold # parameters to be defined
369398
take_step!(cache)
370399
cache.loss = cache.loss_new
371400
cache.make_new_J = true
372401
else
373402
cache.make_new_J = false
374403
end
375404
# Hei's radius update scheme
376-
cache.trust_r = rfunc(r, c2, M, γ1, γ2, β) * cache.internalnorm(step_size) # parameters to be defined
405+
@unpack shrink_threshold, p1, p2, p3, p4, ϵ = cache
406+
cache.trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size) # parameters to be defined
377407

378408
if iszero(fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < ϵ # parameters to be defined
379409
cache.force_stop = true
380410
end
381411

382412

383413
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
414+
if r > cache.step_threshold
415+
take_step!(cache)
416+
cache.loss = cache.loss_new
417+
cache.make_new_J = true
418+
else
419+
cache.make_new_J = false
420+
end
421+
if r < cache.shrink_threshold
422+
cache.p1 = p2 * cache.p1
423+
elseif r >= cache.shrink_threshold && cache.internalnorm(step_size) > cache.trust_r / 2
424+
cache.p1 = p3 * cache.p1
425+
end
426+
@unpack p1 = cache.p1
427+
428+
# yuan's scheme
429+
@unpack fu = cache
430+
cache.trust_r = p1 * cache.internalnorm(jacobian(cache, f) * fu) # we need the gradient at the new (k+1)th point WILL THIS BECOME ALLOCATING?
384431

432+
if iszero(fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < ϵ # parameters to be defined
433+
cache.force_stop = true
434+
end
385435

386-
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
436+
#elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
387437

388438

389439
end

0 commit comments

Comments
 (0)