@@ -170,9 +170,6 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
170
170
p3:: floatType
171
171
p4:: floatType
172
172
ϵ:: floatType
173
- # p5::floatType
174
- # p6::floatType
175
- # p7::floatType
176
173
177
174
function TrustRegionCache {iip} (f:: fType , alg:: algType , u:: uType , fu:: resType , p:: pType ,
178
175
uf:: ufType , linsolve:: L , J:: jType ,
@@ -283,28 +280,28 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
283
280
r = loss
284
281
285
282
# Parameters for the Schemes
286
- p1 = 0
287
- p2 = 0
288
- p3 = 0
289
- p4 = 0
290
- ϵ = 1e-8
283
+ p1 = convert ( eltype (u), 0.0 )
284
+ p2 = convert ( eltype (u), 0.0 )
285
+ p3 = convert ( eltype (u), 0.0 )
286
+ p4 = convert ( eltype (u), 0.0 )
287
+ ϵ = convert ( eltype (u), 1.0e-8 )
291
288
if radius_update_scheme === RadiusUpdateSchemes. Hei
292
- step_threshold = 0
293
- shrink_threshold = 0.25
294
- expand_threshold = 0.25
295
- p1 = 5.0 # M
296
- p2 = 0.1 # β
297
- p3 = 0.15 # γ1
298
- p4 = 0.15 # γ2
289
+ step_threshold = convert ( eltype (u), 0.0 )
290
+ shrink_threshold = convert ( eltype (u), 0.25 )
291
+ expand_threshold = convert ( eltype (u), 0.25 )
292
+ p1 = convert ( eltype (u), 5.0 ) # M
293
+ p2 = convert ( eltype (u), 0.1 ) # β
294
+ p3 = convert ( eltype (u), 0.15 ) # γ1
295
+ p4 = convert ( eltype (u), 0.15 ) # γ2
299
296
elseif radius_update_scheme === RadiusUpdateSchemes. Yuan
300
- step_threshold = 0.0001
301
- shrink_threshold = 0.25
302
- expand_threshold = 0.25
303
- p1 = 2.0 # μ
304
- p2 = 1 / 6 # c5
305
- p3 = 6 # c6
306
- p4 = 0
307
- end
297
+ step_threshold = convert ( eltype (u), 0.0001 )
298
+ shrink_threshold = convert ( eltype (u), 0.25 )
299
+ expand_threshold = convert ( eltype (u), 0.25 )
300
+ p1 = convert ( eltype (u), 2.0 ) # μ
301
+ p2 = convert ( eltype (u), 1 / 6 ) # c5
302
+ p3 = convert ( eltype (u), 6.0 ) # c6
303
+ p4 = convert ( eltype (u), 0.0 )
304
+ end
308
305
309
306
return TrustRegionCache {iip} (f, alg, u, fu, p, uf, linsolve, J, jac_config,
310
307
1 , false , maxiters, internalnorm,
@@ -333,8 +330,6 @@ function perform_step!(cache::TrustRegionCache{true})
333
330
# Compute the potentially new u
334
331
cache. u_tmp .= u .+ cache. step_size
335
332
f (cache. fu_new, cache. u_tmp, p)
336
-
337
- @unpack radius_update_scheme = cache
338
333
trust_region_step! (cache)
339
334
return nothing
340
335
end
@@ -356,8 +351,6 @@ function perform_step!(cache::TrustRegionCache{false})
356
351
# Compute the potentially new u
357
352
cache. u_tmp = u .+ cache. step_size
358
353
cache. fu_new = f (cache. u_tmp, p)
359
-
360
- @unpack radius_update_scheme = cache
361
354
trust_region_step! (cache)
362
355
return nothing
363
356
end
@@ -398,48 +391,48 @@ function trust_region_step!(cache::TrustRegionCache)
398
391
end
399
392
400
393
elseif radius_update_scheme === RadiusUpdateSchemes. Hei
401
- if r > cache. step_threshold # parameters to be defined
394
+ if r > cache. step_threshold
402
395
take_step! (cache)
403
396
cache. loss = cache. loss_new
404
397
cache. make_new_J = true
405
398
else
406
399
cache. make_new_J = false
407
400
end
408
401
# Hei's radius update scheme
409
- @unpack shrink_threshold, p1, p2, p3, p4, ϵ = cache
402
+ @unpack shrink_threshold, p1, p2, p3, p4 = cache
403
+ if rfunc (r, shrink_threshold, p1, p3, p4, p2) * cache. internalnorm (step_size) < cache. trust_r
404
+ cache. shrink_counter += 1
405
+ end
410
406
cache. trust_r = rfunc (r, shrink_threshold, p1, p3, p4, p2) * cache. internalnorm (step_size) # parameters to be defined
411
407
412
- if iszero (fu) || cache. internalnorm (cache. fu) < cache. abstol || cache. internalnorm (g) < ϵ # parameters to be defined
408
+ if iszero (cache . fu) || cache. internalnorm (cache. fu) < cache. abstol || cache. internalnorm (g) < cache . ϵ # parameters to be defined
413
409
cache. force_stop = true
414
410
end
415
411
416
412
417
413
elseif radius_update_scheme === RadiusUpdateSchemes. Yuan
414
+ if r < cache. shrink_threshold
415
+ cache. p1 = cache. p2 * cache. p1
416
+ cache. shrink_counter += 1
417
+ elseif r >= cache. expand_threshold && cache. internalnorm (step_size) > cache. trust_r / 2
418
+ cache. p1 = cache. p3 * cache. p1
419
+ end
420
+ @unpack p1, fu, f, J = cache
421
+ # cache.trust_r = p1 * cache.internalnorm(jacobian!(J, cache) * fu) # we need the gradient at the new (k+1)th point WILL THIS BECOME ALLOCATING?
422
+
418
423
if r > cache. step_threshold
419
424
take_step! (cache)
420
425
cache. loss = cache. loss_new
421
426
cache. make_new_J = true
422
427
else
423
428
cache. make_new_J = false
424
429
end
425
- if r < cache. shrink_threshold
426
- cache. p1 = p2 * cache. p1
427
- elseif r >= cache. shrink_threshold && cache. internalnorm (step_size) > cache. trust_r / 2
428
- cache. p1 = p3 * cache. p1
429
- end
430
- @unpack p1 = cache. p1
431
-
432
- # yuan's scheme
433
- @unpack fu = cache
434
- cache. trust_r = p1 * cache. internalnorm (jacobian (cache, f) * fu) # we need the gradient at the new (k+1)th point WILL THIS BECOME ALLOCATING?
435
-
436
- if iszero (fu) || cache. internalnorm (cache. fu) < cache. abstol || cache. internalnorm (g) < ϵ # parameters to be defined
430
+
431
+ if iszero (cache. fu) || cache. internalnorm (cache. fu) < cache. abstol || cache. internalnorm (g) < cache. ϵ # parameters to be defined
437
432
cache. force_stop = true
438
433
end
439
434
440
435
# elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
441
-
442
-
443
436
end
444
437
end
445
438
0 commit comments