Skip to content

Commit 0ba652b

Browse files
committed
finish rebase to master
1 parent 439415b commit 0ba652b

File tree

1 file changed

+34
-32
lines changed

1 file changed

+34
-32
lines changed

src/trustRegion.jl

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,14 @@ end
203203
shrink_counter::Int
204204
step_size
205205
u_tmp
206+
u_c
206207
fu_new
207208
make_new_J::Bool
208209
r::floatType
209-
p1::parType
210-
p2::parType
211-
p3::parType
212-
p4::parType
210+
p1::floatType
211+
p2::floatType
212+
p3::floatType
213+
p4::floatType
213214
ϵ::floatType
214215
stats::NLStats
215216
end
@@ -226,6 +227,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
226227
loss = get_loss(fu1)
227228
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
228229
linsolve_kwargs)
230+
u_c = zero(u)
229231

230232
loss_new = loss
231233
H = zero(J)
@@ -243,7 +245,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
243245
trustType = Float64 #typeof(alg.initial_trust_radius)
244246
max_trust_radius = convert(trustType, alg.max_trust_radius)
245247
if iszero(max_trust_radius)
246-
max_trust_radius = convert(trustType, max(norm(fu), maximum(u) - minimum(u)))
248+
max_trust_radius = convert(trustType, max(norm(fu1), maximum(u) - minimum(u)))
247249
end
248250
initial_trust_radius = convert(trustType, alg.initial_trust_radius)
249251
if iszero(initial_trust_radius)
@@ -256,30 +258,30 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
256258
expand_factor = convert(trustType, alg.expand_factor)
257259

258260
# Parameters for the Schemes
259-
parType = Float64
260-
p1 = convert(parType, 0.0)
261-
p2 = convert(parType, 0.0)
262-
p3 = convert(parType, 0.0)
263-
p4 = convert(parType, 0.0)
264-
ϵ = convert(typeof(r), 1.0e-8)
261+
floatType = typeof(r)
262+
p1 = convert(floatType, 0.0)
263+
p2 = convert(floatType, 0.0)
264+
p3 = convert(floatType, 0.0)
265+
p4 = convert(floatType, 0.0)
266+
ϵ = convert(floatType, 1.0e-8)
265267
if radius_update_scheme === RadiusUpdateSchemes.NLsolve
266-
p1 = convert(parType, 0.5)
268+
p1 = convert(floatType, 0.5)
267269
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
268270
step_threshold = convert(trustType, 0.0)
269271
shrink_threshold = convert(trustType, 0.25)
270272
expand_threshold = convert(trustType, 0.25)
271-
p1 = convert(parType, 5.0) # M
272-
p2 = convert(parType, 0.1) # β
273-
p3 = convert(parType, 0.15) # γ1
274-
p4 = convert(parType, 0.15) # γ2
273+
p1 = convert(floatType, 5.0) # M
274+
p2 = convert(floatType, 0.1) # β
275+
p3 = convert(floatType, 0.15) # γ1
276+
p4 = convert(floatType, 0.15) # γ2
275277
initial_trust_radius = convert(trustType, 1.0)
276278
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
277279
step_threshold = convert(trustType, 0.0001)
278280
shrink_threshold = convert(trustType, 0.25)
279281
expand_threshold = convert(trustType, 0.25)
280-
p1 = convert(parType, 2.0) # μ
281-
p2 = convert(parType, 1 / 6) # c5
282-
p3 = convert(parType, 6.0) # c6
282+
p1 = convert(floatType, 2.0) # μ
283+
p2 = convert(floatType, 1 / 6) # c5
284+
p3 = convert(floatType, 6.0) # c6
283285
if iip
284286
auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu1)
285287
else
@@ -294,25 +296,25 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
294296
step_threshold = convert(trustType, 0.0001)
295297
shrink_threshold = convert(trustType, 0.25)
296298
expand_threshold = convert(trustType, 0.75)
297-
p1 = convert(parType, 0.1) # μ
298-
p2 = convert(parType, 0.25) # c5
299-
p3 = convert(parType, 12.0) # c6
300-
p4 = convert(parType, 1.0e18) # M
299+
p1 = convert(floatType, 0.1) # μ
300+
p2 = convert(floatType, 0.25) # c5
301+
p3 = convert(floatType, 12.0) # c6
302+
p4 = convert(floatType, 1.0e18) # M
301303
initial_trust_radius = convert(trustType, p1 * (norm(fu)^0.99))
302304
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
303305
step_threshold = convert(trustType, 0.05)
304306
shrink_threshold = convert(trustType, 0.05)
305307
expand_threshold = convert(trustType, 0.9)
306-
p1 = convert(parType, 2.5) #alpha_1
307-
p2 = convert(parType, 0.25) # alpha_2
308+
p1 = convert(floatType, 2.5) # alpha_1
309+
p2 = convert(floatType, 0.25) # alpha_2
308310
initial_trust_radius = convert(trustType, 1.0)
309311
end
310312

311313
return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J,
312314
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob,
313315
radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold,
314316
shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new,
315-
H, g, shrink_counter, step_size, du, fu_new, make_new_J, r, p1, p2, p3, p4, ϵ,
317+
H, g, shrink_counter, step_size, du, u_c, fu_new, make_new_J, r, p1, p2, p3, p4, ϵ,
316318
NLStats(1, 0, 0, 0, 0))
317319
end
318320

@@ -321,7 +323,7 @@ isinplace(::TrustRegionCache{iip}) where {iip} = iip
321323
function perform_step!(cache::TrustRegionCache{true})
322324
@unpack make_new_J, J, fu, f, u, p, u_tmp, alg, linsolve = cache
323325
if cache.make_new_J
324-
jacobian!(J, cache)
326+
jacobian!!(J, cache)
325327
mul!(cache.H, J', J)
326328
mul!(cache.g, J', fu)
327329
cache.stats.njacs += 1
@@ -348,7 +350,7 @@ function perform_step!(cache::TrustRegionCache{false})
348350
@unpack make_new_J, fu, f, u, p = cache
349351

350352
if make_new_J
351-
J = jacobian(cache, f)
353+
J = jacobian!!(cache.J, cache)
352354
cache.H = J' * J
353355
cache.g = J' * fu
354356
cache.stats.njacs += 1
@@ -373,11 +375,11 @@ function retrospective_step!(cache::TrustRegionCache)
373375
@unpack J, fu_prev, fu, u_prev, u = cache
374376
J = jacobian!!(deepcopy(J), cache)
375377
if J isa Number
376-
cache.H = J * J
377-
cache.g = J * fu
378+
cache.H = J' * J
379+
cache.g = J' * fu
378380
else
379-
mul!(cache.H, J, J)
380-
mul!(cache.g, J, fu)
381+
mul!(cache.H, J', J)
382+
mul!(cache.g, J', fu)
381383
end
382384
cache.stats.njacs += 1
383385
@unpack H, g, step_size = cache

0 commit comments

Comments
 (0)