@@ -206,10 +206,10 @@ end
206
206
fu_new
207
207
make_new_J:: Bool
208
208
r:: floatType
209
- p1:: floatType
210
- p2:: floatType
211
- p3:: floatType
212
- p4:: floatType
209
+ p1:: parType
210
+ p2:: parType
211
+ p3:: parType
212
+ p4:: parType
213
213
ϵ:: floatType
214
214
stats:: NLStats
215
215
end
@@ -227,23 +227,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
227
227
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches (alg, f, u, p, Val (iip);
228
228
linsolve_kwargs)
229
229
230
- radius_update_scheme = alg. radius_update_scheme
231
- max_trust_radius = convert (eltype (u), alg. max_trust_radius)
232
- initial_trust_radius = convert (eltype (u), alg. initial_trust_radius)
233
- step_threshold = convert (eltype (u), alg. step_threshold)
234
- shrink_threshold = convert (eltype (u), alg. shrink_threshold)
235
- expand_threshold = convert (eltype (u), alg. expand_threshold)
236
- shrink_factor = convert (eltype (u), alg. shrink_factor)
237
- expand_factor = convert (eltype (u), alg. expand_factor)
238
-
239
- # Set default trust region radius if not specified
240
- if iszero (max_trust_radius)
241
- max_trust_radius = convert (eltype (u), max (norm (fu1), maximum (u) - minimum (u)))
242
- end
243
- if iszero (initial_trust_radius)
244
- initial_trust_radius = convert (eltype (u), max_trust_radius / 11 )
245
- end
246
-
247
230
loss_new = loss
248
231
H = zero (J)
249
232
g = _mutable_zero (fu1)
@@ -253,31 +236,50 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
253
236
make_new_J = true
254
237
r = loss
255
238
239
+ # set trust region update scheme
240
+ radius_update_scheme = alg. radius_update_scheme
241
+
242
+ # set default type for all trust region parameters
243
+ trustType = Float64 # typeof(alg.initial_trust_radius)
244
+ max_trust_radius = convert (trustType, alg. max_trust_radius)
245
+ if iszero (max_trust_radius)
246
+ max_trust_radius = convert (trustType, max (norm (fu), maximum (u) - minimum (u)))
247
+ end
248
+ initial_trust_radius = convert (trustType, alg. initial_trust_radius)
249
+ if iszero (initial_trust_radius)
250
+ initial_trust_radius = convert (trustType, max_trust_radius / 11 )
251
+ end
252
+ step_threshold = convert (trustType, alg. step_threshold)
253
+ shrink_threshold = convert (trustType, alg. shrink_threshold)
254
+ expand_threshold = convert (trustType, alg. expand_threshold)
255
+ shrink_factor = convert (trustType, alg. shrink_factor)
256
+ expand_factor = convert (trustType, alg. expand_factor)
257
+
256
258
# Parameters for the Schemes
257
- p1 = convert (eltype (u), 0.0 )
258
- p2 = convert (eltype (u), 0.0 )
259
- p3 = convert (eltype (u), 0.0 )
260
- p4 = convert (eltype (u), 0.0 )
261
- ϵ = convert (eltype (u), 1.0e-8 )
259
+ parType = Float64
260
+ p1 = convert (parType, 0.0 )
261
+ p2 = convert (parType, 0.0 )
262
+ p3 = convert (parType, 0.0 )
263
+ p4 = convert (parType, 0.0 )
264
+ ϵ = convert (typeof (r), 1.0e-8 )
262
265
if radius_update_scheme === RadiusUpdateSchemes. NLsolve
263
- p1 = convert (eltype (u) , 0.5 )
266
+ p1 = convert (parType , 0.5 )
264
267
elseif radius_update_scheme === RadiusUpdateSchemes. Hei
265
- step_threshold = convert (eltype (u) , 0.0 )
266
- shrink_threshold = convert (eltype (u) , 0.25 )
267
- expand_threshold = convert (eltype (u) , 0.25 )
268
- p1 = convert (eltype (u) , 5.0 ) # M
269
- p2 = convert (eltype (u) , 0.1 ) # β
270
- p3 = convert (eltype (u) , 0.15 ) # γ1
271
- p4 = convert (eltype (u) , 0.15 ) # γ2
272
- initial_trust_radius = convert (eltype (u) , 1.0 )
268
+ step_threshold = convert (trustType , 0.0 )
269
+ shrink_threshold = convert (trustType , 0.25 )
270
+ expand_threshold = convert (trustType , 0.25 )
271
+ p1 = convert (parType , 5.0 ) # M
272
+ p2 = convert (parType , 0.1 ) # β
273
+ p3 = convert (parType , 0.15 ) # γ1
274
+ p4 = convert (parType , 0.15 ) # γ2
275
+ initial_trust_radius = convert (trustType , 1.0 )
273
276
elseif radius_update_scheme === RadiusUpdateSchemes. Yuan
274
- step_threshold = convert (eltype (u), 0.0001 )
275
- shrink_threshold = convert (eltype (u), 0.25 )
276
- expand_threshold = convert (eltype (u), 0.25 )
277
- p1 = convert (eltype (u), 2.0 ) # μ
278
- p2 = convert (eltype (u), 1 / 6 ) # c5
279
- p3 = convert (eltype (u), 6.0 ) # c6
280
- p4 = convert (eltype (u), 0.0 )
277
+ step_threshold = convert (trustType, 0.0001 )
278
+ shrink_threshold = convert (trustType, 0.25 )
279
+ expand_threshold = convert (trustType, 0.25 )
280
+ p1 = convert (parType, 2.0 ) # μ
281
+ p2 = convert (parType, 1 / 6 ) # c5
282
+ p3 = convert (parType, 6.0 ) # c6
281
283
if iip
282
284
auto_jacvec! (g, (fu, x) -> f (fu, x, p), u, fu1)
283
285
else
@@ -287,25 +289,23 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
287
289
g = auto_jacvec (x -> f (x, p), u, fu1)
288
290
end
289
291
end
290
- initial_trust_radius = convert (eltype (u) , p1 * norm (g))
292
+ initial_trust_radius = convert (trustType , p1 * norm (g))
291
293
elseif radius_update_scheme === RadiusUpdateSchemes. Fan
292
- step_threshold = convert (eltype (u) , 0.0001 )
293
- shrink_threshold = convert (eltype (u) , 0.25 )
294
- expand_threshold = convert (eltype (u) , 0.75 )
295
- p1 = convert (eltype (u) , 0.1 ) # μ
296
- p2 = convert (eltype (u), 1 / 4 ) # c5
297
- p3 = convert (eltype (u) , 12 ) # c6
298
- p4 = convert (eltype (u) , 1.0e18 ) # M
299
- initial_trust_radius = convert (eltype (u) , p1 * (norm (fu1 )^ 0.99 ))
294
+ step_threshold = convert (trustType , 0.0001 )
295
+ shrink_threshold = convert (trustType , 0.25 )
296
+ expand_threshold = convert (trustType , 0.75 )
297
+ p1 = convert (parType , 0.1 ) # μ
298
+ p2 = convert (parType, 0.25 ) # c5
299
+ p3 = convert (parType , 12.0 ) # c6
300
+ p4 = convert (parType , 1.0e18 ) # M
301
+ initial_trust_radius = convert (trustType , p1 * (norm (fu )^ 0.99 ))
300
302
elseif radius_update_scheme === RadiusUpdateSchemes. Bastin
301
- step_threshold = convert (eltype (u), 0.05 )
302
- shrink_threshold = convert (eltype (u), 0.05 )
303
- expand_threshold = convert (eltype (u), 0.9 )
304
- p1 = convert (eltype (u), 2.5 ) # alpha_1
305
- p2 = convert (eltype (u), 0.25 ) # alpha_2
306
- p3 = convert (eltype (u), 0 ) # not required
307
- p4 = convert (eltype (u), 0 ) # not required
308
- initial_trust_radius = convert (eltype (u), 1.0 )
303
+ step_threshold = convert (trustType, 0.05 )
304
+ shrink_threshold = convert (trustType, 0.05 )
305
+ expand_threshold = convert (trustType, 0.9 )
306
+ p1 = convert (parType, 2.5 ) # alpha_1
307
+ p2 = convert (parType, 0.25 ) # alpha_2
308
+ initial_trust_radius = convert (trustType, 1.0 )
309
309
end
310
310
311
311
return TrustRegionCache {iip} (f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J,
0 commit comments