@@ -203,13 +203,14 @@ end
203
203
shrink_counter:: Int
204
204
step_size
205
205
u_tmp
206
+ u_c
206
207
fu_new
207
208
make_new_J:: Bool
208
209
r:: floatType
209
- p1:: parType
210
- p2:: parType
211
- p3:: parType
212
- p4:: parType
210
+ p1:: floatType
211
+ p2:: floatType
212
+ p3:: floatType
213
+ p4:: floatType
213
214
ϵ:: floatType
214
215
stats:: NLStats
215
216
end
@@ -226,6 +227,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
226
227
loss = get_loss (fu1)
227
228
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches (alg, f, u, p, Val (iip);
228
229
linsolve_kwargs)
230
+ u_c = zero (u)
229
231
230
232
loss_new = loss
231
233
H = zero (J)
@@ -243,7 +245,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
243
245
trustType = Float64 # typeof(alg.initial_trust_radius)
244
246
max_trust_radius = convert (trustType, alg. max_trust_radius)
245
247
if iszero (max_trust_radius)
246
- max_trust_radius = convert (trustType, max (norm (fu ), maximum (u) - minimum (u)))
248
+ max_trust_radius = convert (trustType, max (norm (fu1 ), maximum (u) - minimum (u)))
247
249
end
248
250
initial_trust_radius = convert (trustType, alg. initial_trust_radius)
249
251
if iszero (initial_trust_radius)
@@ -256,30 +258,30 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
256
258
expand_factor = convert (trustType, alg. expand_factor)
257
259
258
260
# Parameters for the Schemes
259
- parType = Float64
260
- p1 = convert (parType , 0.0 )
261
- p2 = convert (parType , 0.0 )
262
- p3 = convert (parType , 0.0 )
263
- p4 = convert (parType , 0.0 )
264
- ϵ = convert (typeof (r) , 1.0e-8 )
261
+ floatType = typeof (r)
262
+ p1 = convert (floatType , 0.0 )
263
+ p2 = convert (floatType , 0.0 )
264
+ p3 = convert (floatType , 0.0 )
265
+ p4 = convert (floatType , 0.0 )
266
+ ϵ = convert (floatType , 1.0e-8 )
265
267
if radius_update_scheme === RadiusUpdateSchemes. NLsolve
266
- p1 = convert (parType , 0.5 )
268
+ p1 = convert (floatType , 0.5 )
267
269
elseif radius_update_scheme === RadiusUpdateSchemes. Hei
268
270
step_threshold = convert (trustType, 0.0 )
269
271
shrink_threshold = convert (trustType, 0.25 )
270
272
expand_threshold = convert (trustType, 0.25 )
271
- p1 = convert (parType , 5.0 ) # M
272
- p2 = convert (parType , 0.1 ) # β
273
- p3 = convert (parType , 0.15 ) # γ1
274
- p4 = convert (parType , 0.15 ) # γ2
273
+ p1 = convert (floatType , 5.0 ) # M
274
+ p2 = convert (floatType , 0.1 ) # β
275
+ p3 = convert (floatType , 0.15 ) # γ1
276
+ p4 = convert (floatType , 0.15 ) # γ2
275
277
initial_trust_radius = convert (trustType, 1.0 )
276
278
elseif radius_update_scheme === RadiusUpdateSchemes. Yuan
277
279
step_threshold = convert (trustType, 0.0001 )
278
280
shrink_threshold = convert (trustType, 0.25 )
279
281
expand_threshold = convert (trustType, 0.25 )
280
- p1 = convert (parType , 2.0 ) # μ
281
- p2 = convert (parType , 1 / 6 ) # c5
282
- p3 = convert (parType , 6.0 ) # c6
282
+ p1 = convert (floatType , 2.0 ) # μ
283
+ p2 = convert (floatType , 1 / 6 ) # c5
284
+ p3 = convert (floatType , 6.0 ) # c6
283
285
if iip
284
286
auto_jacvec! (g, (fu, x) -> f (fu, x, p), u, fu1)
285
287
else
@@ -294,25 +296,25 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
294
296
step_threshold = convert (trustType, 0.0001 )
295
297
shrink_threshold = convert (trustType, 0.25 )
296
298
expand_threshold = convert (trustType, 0.75 )
297
- p1 = convert (parType , 0.1 ) # μ
298
- p2 = convert (parType , 0.25 ) # c5
299
- p3 = convert (parType , 12.0 ) # c6
300
- p4 = convert (parType , 1.0e18 ) # M
299
+ p1 = convert (floatType , 0.1 ) # μ
300
+ p2 = convert (floatType , 0.25 ) # c5
301
+ p3 = convert (floatType , 12.0 ) # c6
302
+ p4 = convert (floatType , 1.0e18 ) # M
301
303
initial_trust_radius = convert (trustType, p1 * (norm (fu)^ 0.99 ))
302
304
elseif radius_update_scheme === RadiusUpdateSchemes. Bastin
303
305
step_threshold = convert (trustType, 0.05 )
304
306
shrink_threshold = convert (trustType, 0.05 )
305
307
expand_threshold = convert (trustType, 0.9 )
306
- p1 = convert (parType , 2.5 ) # alpha_1
307
- p2 = convert (parType , 0.25 ) # alpha_2
308
+ p1 = convert (floatType , 2.5 ) # alpha_1
309
+ p2 = convert (floatType , 0.25 ) # alpha_2
308
310
initial_trust_radius = convert (trustType, 1.0 )
309
311
end
310
312
311
313
return TrustRegionCache {iip} (f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J,
312
314
jac_cache, false , maxiters, internalnorm, ReturnCode. Default, abstol, prob,
313
315
radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold,
314
316
shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new,
315
- H, g, shrink_counter, step_size, du, fu_new, make_new_J, r, p1, p2, p3, p4, ϵ,
317
+ H, g, shrink_counter, step_size, du, u_c, fu_new, make_new_J, r, p1, p2, p3, p4, ϵ,
316
318
NLStats (1 , 0 , 0 , 0 , 0 ))
317
319
end
318
320
@@ -321,7 +323,7 @@ isinplace(::TrustRegionCache{iip}) where {iip} = iip
321
323
function perform_step! (cache:: TrustRegionCache{true} )
322
324
@unpack make_new_J, J, fu, f, u, p, u_tmp, alg, linsolve = cache
323
325
if cache. make_new_J
324
- jacobian! (J, cache)
326
+ jacobian!! (J, cache)
325
327
mul! (cache. H, J' , J)
326
328
mul! (cache. g, J' , fu)
327
329
cache. stats. njacs += 1
@@ -348,7 +350,7 @@ function perform_step!(cache::TrustRegionCache{false})
348
350
@unpack make_new_J, fu, f, u, p = cache
349
351
350
352
if make_new_J
351
- J = jacobian (cache, f )
353
+ J = jacobian!! (cache. J, cache )
352
354
cache. H = J' * J
353
355
cache. g = J' * fu
354
356
cache. stats. njacs += 1
@@ -373,11 +375,11 @@ function retrospective_step!(cache::TrustRegionCache)
373
375
@unpack J, fu_prev, fu, u_prev, u = cache
374
376
J = jacobian!! (deepcopy (J), cache)
375
377
if J isa Number
376
- cache. H = J * J
377
- cache. g = J * fu
378
+ cache. H = J' * J
379
+ cache. g = J' * fu
378
380
else
379
- mul! (cache. H, J, J)
380
- mul! (cache. g, J, fu)
381
+ mul! (cache. H, J' , J)
382
+ mul! (cache. g, J' , fu)
381
383
end
382
384
cache. stats. njacs += 1
383
385
@unpack H, g, step_size = cache
0 commit comments