Skip to content

Commit ada7d8d

Browse files
committed
cache modification
1 parent 8d009b9 commit ada7d8d

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

src/trustRegion.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
134134
trustType, suType, su2Type, tmpType}
135135
f::fType
136136
alg::algType
137+
u_prev::uType
137138
u::uType
139+
fu_prev::resType
138140
fu::resType
139141
p::pType
140142
uf::ufType
@@ -172,7 +174,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
172174
ϵ::floatType
173175
stats::NLStats
174176

175-
function TrustRegionCache{iip}(f::fType, alg::algType, u::uType, fu::resType, p::pType,
177+
function TrustRegionCache{iip}(f::fType, alg::algType, u_prev::uType, u::uType, fu_prev::resType, fu::resType, p::pType,
176178
uf::ufType, linsolve::L, J::jType, jac_config::JC,
177179
force_stop::Bool, maxiters::Int, internalnorm::INType,
178180
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
@@ -194,7 +196,7 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
194196
suType, su2Type, tmpType}
195197
new{iip, fType, algType, uType, resType, pType,
196198
INType, tolType, probType, ufType, L, jType, JC, floatType,
197-
trustType, suType, su2Type, tmpType}(f, alg, u, fu, p, uf, linsolve, J,
199+
trustType, suType, su2Type, tmpType}(f, alg, u_prev, u, fu_prev, fu, p, uf, linsolve, J,
198200
jac_config, force_stop,
199201
maxiters, internalnorm, retcode,
200202
abstol, prob, radius_update_scheme,
@@ -246,6 +248,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
246248
else
247249
u = deepcopy(prob.u0)
248250
end
251+
u_prev = deepcopy(u)
249252
f = prob.f
250253
p = prob.p
251254
if iip
@@ -254,6 +257,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
254257
else
255258
fu = f(u, p)
256259
end
260+
fu_prev = deepcopy(fu)
257261

258262
loss = get_loss(fu)
259263
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))
@@ -325,9 +329,18 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
325329
p3 = convert(eltype(u), 12) # c6
326330
p4 = convert(eltype(u), 1.0e18) # M
327331
initial_trust_radius = convert(eltype(u), p1 * (norm(fu)^0.99))
332+
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
333+
step_threshold = convert(eltype(u), 0.05)
334+
shrink_threshold = convert(eltype(u), 0.05)
335+
expand_threshold = convert(eltype(u), 0.9)
336+
p1 = convert(eltype(u), 2.5) #alpha_1
337+
p2 = convert(eltype(u), 0.25) # alpha_2
338+
p3 = convert(eltype(u), 0) # not required
339+
p4 = convert(eltype(u), 0) # not required
340+
initial_trust_radius = convert(eltype(u), 1.0)
328341
end
329342

330-
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
343+
return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu, p, uf, linsolve, J, jac_config,
331344
false, maxiters, internalnorm,
332345
ReturnCode.Default, abstol, prob, radius_update_scheme,
333346
initial_trust_radius,

0 commit comments

Comments
 (0)