Skip to content

Commit b1f34dd

Browse files
committed
type fixes
1 parent 17abaf0 commit b1f34dd

File tree

1 file changed

+37
-44
lines changed

1 file changed

+37
-44
lines changed

src/trustRegion.jl

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,6 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
170170
p3::floatType
171171
p4::floatType
172172
ϵ::floatType
173-
# p5::floatType
174-
# p6::floatType
175-
# p7::floatType
176173

177174
function TrustRegionCache{iip}(f::fType, alg::algType, u::uType, fu::resType, p::pType,
178175
uf::ufType, linsolve::L, J::jType,
@@ -283,28 +280,28 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
283280
r = loss
284281

285282
# Parameters for the Schemes
286-
p1 = 0
287-
p2 = 0
288-
p3 = 0
289-
p4 = 0
290-
ϵ = 1e-8
283+
p1 = convert(eltype(u), 0.0)
284+
p2 = convert(eltype(u), 0.0)
285+
p3 = convert(eltype(u), 0.0)
286+
p4 = convert(eltype(u), 0.0)
287+
ϵ = convert(eltype(u), 1.0e-8)
291288
if radius_update_scheme === RadiusUpdateSchemes.Hei
292-
step_threshold = 0
293-
shrink_threshold = 0.25
294-
expand_threshold = 0.25
295-
p1 = 5.0 # M
296-
p2 = 0.1 # β
297-
p3 = 0.15 # γ1
298-
p4 = 0.15 # γ2
289+
step_threshold = convert(eltype(u), 0.0)
290+
shrink_threshold = convert(eltype(u), 0.25)
291+
expand_threshold = convert(eltype(u), 0.25)
292+
p1 = convert(eltype(u), 5.0) # M
293+
p2 = convert(eltype(u), 0.1) # β
294+
p3 = convert(eltype(u), 0.15) # γ1
295+
p4 = convert(eltype(u), 0.15) # γ2
299296
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
300-
step_threshold = 0.0001
301-
shrink_threshold = 0.25
302-
expand_threshold = 0.25
303-
p1 = 2.0 # μ
304-
p2 = 1/6 # c5
305-
p3 = 6 # c6
306-
p4 = 0
307-
end
297+
step_threshold = convert(eltype(u), 0.0001)
298+
shrink_threshold = convert(eltype(u), 0.25)
299+
expand_threshold = convert(eltype(u), 0.25)
300+
p1 = convert(eltype(u), 2.0) # μ
301+
p2 = convert(eltype(u), 1/6) # c5
302+
p3 = convert(eltype(u), 6.0) # c6
303+
p4 = convert(eltype(u), 0.0)
304+
end
308305

309306
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
310307
1, false, maxiters, internalnorm,
@@ -333,8 +330,6 @@ function perform_step!(cache::TrustRegionCache{true})
333330
# Compute the potentially new u
334331
cache.u_tmp .= u .+ cache.step_size
335332
f(cache.fu_new, cache.u_tmp, p)
336-
337-
@unpack radius_update_scheme = cache
338333
trust_region_step!(cache)
339334
return nothing
340335
end
@@ -356,8 +351,6 @@ function perform_step!(cache::TrustRegionCache{false})
356351
# Compute the potentially new u
357352
cache.u_tmp = u .+ cache.step_size
358353
cache.fu_new = f(cache.u_tmp, p)
359-
360-
@unpack radius_update_scheme = cache
361354
trust_region_step!(cache)
362355
return nothing
363356
end
@@ -398,48 +391,48 @@ function trust_region_step!(cache::TrustRegionCache)
398391
end
399392

400393
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
401-
if r > cache.step_threshold # parameters to be defined
394+
if r > cache.step_threshold
402395
take_step!(cache)
403396
cache.loss = cache.loss_new
404397
cache.make_new_J = true
405398
else
406399
cache.make_new_J = false
407400
end
408401
# Hei's radius update scheme
409-
@unpack shrink_threshold, p1, p2, p3, p4, ϵ = cache
402+
@unpack shrink_threshold, p1, p2, p3, p4 = cache
403+
if rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size) < cache.trust_r
404+
cache.shrink_counter += 1
405+
end
410406
cache.trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(step_size) # parameters to be defined
411407

412-
if iszero(fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < ϵ # parameters to be defined
408+
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined
413409
cache.force_stop = true
414410
end
415411

416412

417413
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
414+
if r < cache.shrink_threshold
415+
cache.p1 = cache.p2 * cache.p1
416+
cache.shrink_counter += 1
417+
elseif r >= cache.expand_threshold && cache.internalnorm(step_size) > cache.trust_r / 2
418+
cache.p1 = cache.p3 * cache.p1
419+
end
420+
@unpack p1, fu, f, J = cache
421+
#cache.trust_r = p1 * cache.internalnorm(jacobian!(J, cache) * fu) # we need the gradient at the new (k+1)th point WILL THIS BECOME ALLOCATING?
422+
418423
if r > cache.step_threshold
419424
take_step!(cache)
420425
cache.loss = cache.loss_new
421426
cache.make_new_J = true
422427
else
423428
cache.make_new_J = false
424429
end
425-
if r < cache.shrink_threshold
426-
cache.p1 = p2 * cache.p1
427-
elseif r >= cache.shrink_threshold && cache.internalnorm(step_size) > cache.trust_r / 2
428-
cache.p1 = p3 * cache.p1
429-
end
430-
@unpack p1 = cache.p1
431-
432-
# yuan's scheme
433-
@unpack fu = cache
434-
cache.trust_r = p1 * cache.internalnorm(jacobian(cache, f) * fu) # we need the gradient at the new (k+1)th point WILL THIS BECOME ALLOCATING?
435-
436-
if iszero(fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < ϵ # parameters to be defined
430+
431+
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || cache.internalnorm(g) < cache.ϵ # parameters to be defined
437432
cache.force_stop = true
438433
end
439434

440435
#elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
441-
442-
443436
end
444437
end
445438

0 commit comments

Comments
 (0)