Skip to content

Commit 439415b

Browse files
committed
parameter types should not be converted to eltype(u). For now, default to Float64.
1 parent 0e99655 commit 439415b

File tree

1 file changed

+59
-59
lines changed

1 file changed

+59
-59
lines changed

src/trustRegion.jl

Lines changed: 59 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,10 @@ end
206206
fu_new
207207
make_new_J::Bool
208208
r::floatType
209-
p1::floatType
210-
p2::floatType
211-
p3::floatType
212-
p4::floatType
209+
p1::parType
210+
p2::parType
211+
p3::parType
212+
p4::parType
213213
ϵ::floatType
214214
stats::NLStats
215215
end
@@ -227,23 +227,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
227227
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
228228
linsolve_kwargs)
229229

230-
radius_update_scheme = alg.radius_update_scheme
231-
max_trust_radius = convert(eltype(u), alg.max_trust_radius)
232-
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
233-
step_threshold = convert(eltype(u), alg.step_threshold)
234-
shrink_threshold = convert(eltype(u), alg.shrink_threshold)
235-
expand_threshold = convert(eltype(u), alg.expand_threshold)
236-
shrink_factor = convert(eltype(u), alg.shrink_factor)
237-
expand_factor = convert(eltype(u), alg.expand_factor)
238-
239-
# Set default trust region radius if not specified
240-
if iszero(max_trust_radius)
241-
max_trust_radius = convert(eltype(u), max(norm(fu1), maximum(u) - minimum(u)))
242-
end
243-
if iszero(initial_trust_radius)
244-
initial_trust_radius = convert(eltype(u), max_trust_radius / 11)
245-
end
246-
247230
loss_new = loss
248231
H = zero(J)
249232
g = _mutable_zero(fu1)
@@ -253,31 +236,50 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
253236
make_new_J = true
254237
r = loss
255238

239+
# set trust region update scheme
240+
radius_update_scheme = alg.radius_update_scheme
241+
242+
# set default type for all trust region parameters
243+
trustType = Float64 #typeof(alg.initial_trust_radius)
244+
max_trust_radius = convert(trustType, alg.max_trust_radius)
245+
if iszero(max_trust_radius)
246+
max_trust_radius = convert(trustType, max(norm(fu), maximum(u) - minimum(u)))
247+
end
248+
initial_trust_radius = convert(trustType, alg.initial_trust_radius)
249+
if iszero(initial_trust_radius)
250+
initial_trust_radius = convert(trustType, max_trust_radius / 11)
251+
end
252+
step_threshold = convert(trustType, alg.step_threshold)
253+
shrink_threshold = convert(trustType, alg.shrink_threshold)
254+
expand_threshold = convert(trustType, alg.expand_threshold)
255+
shrink_factor = convert(trustType, alg.shrink_factor)
256+
expand_factor = convert(trustType, alg.expand_factor)
257+
256258
# Parameters for the Schemes
257-
p1 = convert(eltype(u), 0.0)
258-
p2 = convert(eltype(u), 0.0)
259-
p3 = convert(eltype(u), 0.0)
260-
p4 = convert(eltype(u), 0.0)
261-
ϵ = convert(eltype(u), 1.0e-8)
259+
parType = Float64
260+
p1 = convert(parType, 0.0)
261+
p2 = convert(parType, 0.0)
262+
p3 = convert(parType, 0.0)
263+
p4 = convert(parType, 0.0)
264+
ϵ = convert(typeof(r), 1.0e-8)
262265
if radius_update_scheme === RadiusUpdateSchemes.NLsolve
263-
p1 = convert(eltype(u), 0.5)
266+
p1 = convert(parType, 0.5)
264267
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
265-
step_threshold = convert(eltype(u), 0.0)
266-
shrink_threshold = convert(eltype(u), 0.25)
267-
expand_threshold = convert(eltype(u), 0.25)
268-
p1 = convert(eltype(u), 5.0) # M
269-
p2 = convert(eltype(u), 0.1) # β
270-
p3 = convert(eltype(u), 0.15) # γ1
271-
p4 = convert(eltype(u), 0.15) # γ2
272-
initial_trust_radius = convert(eltype(u), 1.0)
268+
step_threshold = convert(trustType, 0.0)
269+
shrink_threshold = convert(trustType, 0.25)
270+
expand_threshold = convert(trustType, 0.25)
271+
p1 = convert(parType, 5.0) # M
272+
p2 = convert(parType, 0.1) # β
273+
p3 = convert(parType, 0.15) # γ1
274+
p4 = convert(parType, 0.15) # γ2
275+
initial_trust_radius = convert(trustType, 1.0)
273276
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
274-
step_threshold = convert(eltype(u), 0.0001)
275-
shrink_threshold = convert(eltype(u), 0.25)
276-
expand_threshold = convert(eltype(u), 0.25)
277-
p1 = convert(eltype(u), 2.0) # μ
278-
p2 = convert(eltype(u), 1 / 6) # c5
279-
p3 = convert(eltype(u), 6.0) # c6
280-
p4 = convert(eltype(u), 0.0)
277+
step_threshold = convert(trustType, 0.0001)
278+
shrink_threshold = convert(trustType, 0.25)
279+
expand_threshold = convert(trustType, 0.25)
280+
p1 = convert(parType, 2.0) # μ
281+
p2 = convert(parType, 1 / 6) # c5
282+
p3 = convert(parType, 6.0) # c6
281283
if iip
282284
auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu1)
283285
else
@@ -287,25 +289,23 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
287289
g = auto_jacvec(x -> f(x, p), u, fu1)
288290
end
289291
end
290-
initial_trust_radius = convert(eltype(u), p1 * norm(g))
292+
initial_trust_radius = convert(trustType, p1 * norm(g))
291293
elseif radius_update_scheme === RadiusUpdateSchemes.Fan
292-
step_threshold = convert(eltype(u), 0.0001)
293-
shrink_threshold = convert(eltype(u), 0.25)
294-
expand_threshold = convert(eltype(u), 0.75)
295-
p1 = convert(eltype(u), 0.1) # μ
296-
p2 = convert(eltype(u), 1 / 4) # c5
297-
p3 = convert(eltype(u), 12) # c6
298-
p4 = convert(eltype(u), 1.0e18) # M
299-
initial_trust_radius = convert(eltype(u), p1 * (norm(fu1)^0.99))
294+
step_threshold = convert(trustType, 0.0001)
295+
shrink_threshold = convert(trustType, 0.25)
296+
expand_threshold = convert(trustType, 0.75)
297+
p1 = convert(parType, 0.1) # μ
298+
p2 = convert(parType, 0.25) # c5
299+
p3 = convert(parType, 12.0) # c6
300+
p4 = convert(parType, 1.0e18) # M
301+
initial_trust_radius = convert(trustType, p1 * (norm(fu)^0.99))
300302
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
301-
step_threshold = convert(eltype(u), 0.05)
302-
shrink_threshold = convert(eltype(u), 0.05)
303-
expand_threshold = convert(eltype(u), 0.9)
304-
p1 = convert(eltype(u), 2.5) #alpha_1
305-
p2 = convert(eltype(u), 0.25) # alpha_2
306-
p3 = convert(eltype(u), 0) # not required
307-
p4 = convert(eltype(u), 0) # not required
308-
initial_trust_radius = convert(eltype(u), 1.0)
303+
step_threshold = convert(trustType, 0.05)
304+
shrink_threshold = convert(trustType, 0.05)
305+
expand_threshold = convert(trustType, 0.9)
306+
p1 = convert(parType, 2.5) #alpha_1
307+
p2 = convert(parType, 0.25) # alpha_2
308+
initial_trust_radius = convert(trustType, 1.0)
309309
end
310310

311311
return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J,

0 commit comments

Comments
 (0)