Skip to content

Commit b2b5d89

Browse files
committed
add NLsolve-like trust region initialization
1 parent 5f298e3 commit b2b5d89

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

src/trustRegion.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,19 +238,26 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
238238
make_new_J = true
239239
r = loss
240240

241+
floatType = typeof(r)
242+
241243
# set trust region update scheme
242244
radius_update_scheme = alg.radius_update_scheme
243245

244246
# set default type for all trust region parameters
245-
trustType = Float64 #typeof(alg.initial_trust_radius)
246-
max_trust_radius = convert(trustType, alg.max_trust_radius)
247-
if iszero(max_trust_radius)
248-
max_trust_radius = convert(trustType, max(norm(fu1), maximum(u) - minimum(u)))
247+
trustType = floatType
248+
if radius_update_scheme == RadiusUpdateSchemes.NLsolve
249+
max_trust_radius = convert(trustType, Inf)
250+
initial_trust_radius = norm(u0) > 0 ? convert(trustType, norm(u0)) : one(trustType)
251+
else
252+
max_trust_radius = convert(trustType, alg.max_trust_radius)
253+
if iszero(max_trust_radius)
254+
max_trust_radius = convert(trustType, max(norm(fu1), maximum(u) - minimum(u)))
255+
end
256+
initial_trust_radius = convert(trustType, alg.initial_trust_radius)
257+
if iszero(initial_trust_radius)
258+
initial_trust_radius = convert(trustType, max_trust_radius / 11)
259+
end
249260
end
250-
initial_trust_radius = convert(trustType, alg.initial_trust_radius)
251-
if iszero(initial_trust_radius)
252-
initial_trust_radius = convert(trustType, max_trust_radius / 11)
253-
end
254261
step_threshold = convert(trustType, alg.step_threshold)
255262
shrink_threshold = convert(trustType, alg.shrink_threshold)
256263
expand_threshold = convert(trustType, alg.expand_threshold)

0 commit comments

Comments
 (0)