@@ -80,10 +80,18 @@ for large-scale and numerically-difficult nonlinear systems.
80
80
Currently, the linear solver and chunk size choices only apply to in-place defined
81
81
`NonlinearProblem`s. That is expected to change in the future.
82
82
"""
83
# Selectable strategies for updating the trust-region radius of `TrustRegion`:
#   Simple - conventional shrink/expand update; this is the constructor's default.
#   Hei    - radius is set from `rfunc(r, ...)` times the norm of the last step.
#   Yuan   - a scaling parameter is shrunk/grown from the reduction ratio `r`.
#   Bastin - declared here, but no update rule for it is implemented in the visible code.
# NOTE(review): this definition sits between the docstring that closes just above and
# `struct TrustRegion`, so that docstring may no longer attach to the struct — consider
# moving the enum above the docstring; confirm against the rendered docs.
EnumX.@enumx RadiusUpdateSchemes Simple Hei Yuan Bastin
89
+
83
90
struct TrustRegion{CS, AD, FDT, L, P, ST, CJ, MTR} < :
84
91
AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
85
92
linsolve:: L
86
93
precs:: P
94
+ radius_update_scheme:: RadiusUpdateSchemes.T
87
95
max_trust_radius:: MTR
88
96
initial_trust_radius:: MTR
89
97
step_threshold:: MTR
@@ -98,6 +106,7 @@ function TrustRegion(; chunk_size = Val{0}(),
98
106
autodiff = Val {true} (),
99
107
standardtag = Val {true} (), concrete_jac = nothing ,
100
108
diff_type = Val{:forward }, linsolve = nothing , precs = DEFAULT_PRECS,
109
+ radius_update_scheme:: RadiusUpdateSchemes.T = RadiusUpdateSchemes. Simple, # defaults to conventional radius update
101
110
max_trust_radius:: Real = 0 // 1 ,
102
111
initial_trust_radius:: Real = 0 // 1 ,
103
112
step_threshold:: Real = 1 // 10 ,
@@ -109,7 +118,7 @@ function TrustRegion(; chunk_size = Val{0}(),
109
118
TrustRegion{_unwrap_val (chunk_size), _unwrap_val (autodiff), diff_type,
110
119
typeof (linsolve), typeof (precs), _unwrap_val (standardtag),
111
120
_unwrap_val (concrete_jac), typeof (max_trust_radius)
112
- }(linsolve, precs, max_trust_radius,
121
+ }(linsolve, precs, radius_update_scheme, max_trust_radius,
113
122
initial_trust_radius,
114
123
step_threshold,
115
124
shrink_threshold,
@@ -138,6 +147,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
138
147
retcode:: SciMLBase.ReturnCode.T
139
148
abstol:: tolType
140
149
prob:: probType
150
+ radius_update_scheme:: RadiusUpdateSchemes.T
141
151
trust_r:: trustType
142
152
max_trust_r:: trustType
143
153
step_threshold:: suType
@@ -155,20 +165,26 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
155
165
fu_new:: resType
156
166
make_new_J:: Bool
157
167
r:: floatType
168
+ p1:: floatType
169
+ p2:: floatType
170
+ p3:: floatType
171
+ p4:: floatType
172
+ ϵ:: floatType
158
173
159
174
function TrustRegionCache {iip} (f:: fType , alg:: algType , u:: uType , fu:: resType , p:: pType ,
160
175
uf:: ufType , linsolve:: L , J:: jType ,
161
176
jac_config:: JC , iter:: Int ,
162
177
force_stop:: Bool , maxiters:: Int , internalnorm:: INType ,
163
178
retcode:: SciMLBase.ReturnCode.T , abstol:: tolType ,
164
- prob:: probType , trust_r:: trustType ,
179
+ prob:: probType , radius_update_scheme :: RadiusUpdateSchemes.T , trust_r:: trustType ,
165
180
max_trust_r:: trustType , step_threshold:: suType ,
166
181
shrink_threshold:: trustType , expand_threshold:: trustType ,
167
182
shrink_factor:: trustType , expand_factor:: trustType ,
168
183
loss:: floatType , loss_new:: floatType , H:: jType ,
169
184
g:: resType , shrink_counter:: Int , step_size:: su2Type ,
170
185
u_tmp:: tmpType , fu_new:: resType , make_new_J:: Bool ,
171
- r:: floatType ) where {iip, fType, algType, uType,
186
+ r:: floatType , p1:: floatType , p2:: floatType , p3:: floatType ,
187
+ p4:: floatType , ϵ:: floatType ) where {iip, fType, algType, uType,
172
188
resType, pType, INType,
173
189
tolType, probType, ufType, L,
174
190
jType, JC, floatType, trustType,
@@ -178,13 +194,13 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
178
194
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
179
195
jac_config, iter, force_stop,
180
196
maxiters, internalnorm, retcode,
181
- abstol, prob, trust_r, max_trust_r,
197
+ abstol, prob, radius_update_scheme, trust_r, max_trust_r,
182
198
step_threshold, shrink_threshold,
183
199
expand_threshold, shrink_factor,
184
200
expand_factor, loss,
185
201
loss_new, H, g, shrink_counter,
186
202
step_size, u_tmp, fu_new,
187
- make_new_J, r)
203
+ make_new_J, r, p1, p2, p3, p4, ϵ )
188
204
end
189
205
end
190
206
@@ -238,6 +254,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
238
254
loss = get_loss (fu)
239
255
uf, linsolve, J, u_tmp, jac_config = jacobian_caches (alg, f, u, p, Val (iip))
240
256
257
+ radius_update_scheme = alg. radius_update_scheme
241
258
max_trust_radius = convert (eltype (u), alg. max_trust_radius)
242
259
initial_trust_radius = convert (eltype (u), alg. initial_trust_radius)
243
260
step_threshold = convert (eltype (u), alg. step_threshold)
@@ -262,13 +279,37 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
262
279
make_new_J = true
263
280
r = loss
264
281
282
+ # Parameters for the Schemes
283
+ p1 = convert (eltype (u), 0.0 )
284
+ p2 = convert (eltype (u), 0.0 )
285
+ p3 = convert (eltype (u), 0.0 )
286
+ p4 = convert (eltype (u), 0.0 )
287
+ ϵ = convert (eltype (u), 1.0e-8 )
288
+ if radius_update_scheme === RadiusUpdateSchemes. Hei
289
+ step_threshold = convert (eltype (u), 0.0 )
290
+ shrink_threshold = convert (eltype (u), 0.25 )
291
+ expand_threshold = convert (eltype (u), 0.25 )
292
+ p1 = convert (eltype (u), 5.0 ) # M
293
+ p2 = convert (eltype (u), 0.1 ) # β
294
+ p3 = convert (eltype (u), 0.15 ) # γ1
295
+ p4 = convert (eltype (u), 0.15 ) # γ2
296
+ elseif radius_update_scheme === RadiusUpdateSchemes. Yuan
297
+ step_threshold = convert (eltype (u), 0.0001 )
298
+ shrink_threshold = convert (eltype (u), 0.25 )
299
+ expand_threshold = convert (eltype (u), 0.25 )
300
+ p1 = convert (eltype (u), 2.0 ) # μ
301
+ p2 = convert (eltype (u), 1 / 6 ) # c5
302
+ p3 = convert (eltype (u), 6.0 ) # c6
303
+ p4 = convert (eltype (u), 0.0 )
304
+ end
305
+
265
306
return TrustRegionCache {iip} (f, alg, u, fu, p, uf, linsolve, J, jac_config,
266
307
1 , false , maxiters, internalnorm,
267
- ReturnCode. Default, abstol, prob, initial_trust_radius,
308
+ ReturnCode. Default, abstol, prob, radius_update_scheme, initial_trust_radius,
268
309
max_trust_radius, step_threshold, shrink_threshold,
269
310
expand_threshold, shrink_factor, expand_factor, loss,
270
311
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
271
- make_new_J, r)
312
+ make_new_J, r, p1, p2, p3, p4, ϵ )
272
313
end
273
314
274
315
function perform_step! (cache:: TrustRegionCache{true} )
@@ -289,7 +330,6 @@ function perform_step!(cache::TrustRegionCache{true})
289
330
# Compute the potentially new u
290
331
cache. u_tmp .= u .+ cache. step_size
291
332
f (cache. fu_new, cache. u_tmp, p)
292
-
293
333
trust_region_step! (cache)
294
334
return nothing
295
335
end
@@ -311,43 +351,88 @@ function perform_step!(cache::TrustRegionCache{false})
311
351
# Compute the potentially new u
312
352
cache. u_tmp = u .+ cache. step_size
313
353
cache. fu_new = f (cache. u_tmp, p)
314
-
315
354
trust_region_step! (cache)
316
355
return nothing
317
356
end
318
357
319
358
# Accept the trial point when the reduction ratio `r` exceeds `cache.step_threshold`;
# otherwise reject it. Returns `true` when the step was taken.
function _tr_accept_step!(cache::TrustRegionCache, r)
    accepted = r > cache.step_threshold
    if accepted
        take_step!(cache)
        cache.loss = cache.loss_new
    end
    # A fresh Jacobian is only requested when the iterate actually moved; after a
    # rejected step the same Jacobian is reused with the updated radius.
    cache.make_new_J = accepted
    return accepted
end

# `true` when the current residual is exactly zero or its norm is below `abstol`.
function _tr_residual_converged(cache::TrustRegionCache)
    return iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
end

"""
    trust_region_step!(cache::TrustRegionCache)

Judge the trial point whose residual is stored in `cache.fu_new` and update the
trust-region state in `cache` according to `cache.radius_update_scheme`.

The function stores the trial loss in `cache.loss_new` and the ratio of actual to
predicted reduction in `cache.r`. Depending on the scheme it then

  - accepts the step (`take_step!`, `cache.loss`, `cache.make_new_J = true`) or rejects
    it (`cache.make_new_J = false`),
  - updates `cache.trust_r` (`Simple`, `Hei`) or the scaling parameter `cache.p1` (`Yuan`),
  - maintains `cache.shrink_counter`, which is incremented on a shrinking update and
    reset to zero otherwise,
  - sets `cache.force_stop = true` once the residual (and, for `Hei`/`Yuan`, the norm of
    `g`) meets the stopping tolerance.

Returns nothing of interest; all results are written into `cache`.
"""
function trust_region_step!(cache::TrustRegionCache)
    @unpack fu_new, step_size, g, H, loss, max_trust_r, radius_update_scheme = cache
    cache.loss_new = get_loss(fu_new)

    # Compute the ratio of the actual reduction to the predicted reduction.
    cache.r = -(loss - cache.loss_new) /
              (step_size' * g + step_size' * H * step_size / 2)
    @unpack r = cache

    if radius_update_scheme === RadiusUpdateSchemes.Simple
        # Shrink the radius when the model predicted the reduction poorly.
        if r < cache.shrink_threshold
            cache.trust_r *= cache.shrink_factor
            cache.shrink_counter += 1
        else
            cache.shrink_counter = 0
        end

        # Expand the radius (capped at `max_trust_r`) only after an accepted step
        # with a sufficiently good ratio.
        if _tr_accept_step!(cache, r) && r > cache.expand_threshold
            cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
        end

        if _tr_residual_converged(cache)
            cache.force_stop = true
        end

    elseif radius_update_scheme === RadiusUpdateSchemes.Hei
        _tr_accept_step!(cache, r)

        # Hei's radius update scheme: the new radius is `rfunc(...)` times the norm
        # of the step. It is evaluated once (assumes `rfunc` is a pure function —
        # TODO confirm) and reused for both the comparison and the assignment.
        @unpack shrink_threshold, p1, p2, p3, p4 = cache
        new_trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) *
                      cache.internalnorm(step_size)
        if new_trust_r < cache.trust_r
            cache.shrink_counter += 1
        else
            # Mirror the `Simple` scheme: a non-shrinking update ends the streak.
            cache.shrink_counter = 0
        end
        cache.trust_r = new_trust_r

        if _tr_residual_converged(cache) || cache.internalnorm(g) < cache.ϵ
            cache.force_stop = true
        end

    elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
        if r < cache.shrink_threshold
            cache.p1 = cache.p2 * cache.p1
            cache.shrink_counter += 1
        else
            # Mirror the `Simple` scheme: a non-shrinking update ends the streak.
            cache.shrink_counter = 0
            if r >= cache.expand_threshold &&
               cache.internalnorm(step_size) > cache.trust_r / 2
                cache.p1 = cache.p3 * cache.p1
            end
        end
        # TODO: `cache.trust_r` itself is not updated under this scheme yet. The
        # intended rule was left disabled as
        # `cache.trust_r = p1 * cache.internalnorm(jacobian!(J, cache) * fu)`,
        # which needs the gradient at the new point without allocating.

        _tr_accept_step!(cache, r)

        if _tr_residual_converged(cache) || cache.internalnorm(g) < cache.ϵ
            cache.force_stop = true
        end

        # NOTE: `RadiusUpdateSchemes.Bastin` has no branch yet; selecting it leaves
        # the cache unchanged apart from `loss_new` and `r` computed above.
    end
end
353
438
0 commit comments