Skip to content

Commit 9f578c8

Browse files
committed
Proper handling of complex numbers and failures
1 parent 1e4c3c0 commit 9f578c8

File tree

7 files changed

+46
-32
lines changed

7 files changed

+46
-32
lines changed

src/broyden.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ function perform_step!(cache::GeneralBroydenCache{true})
111111

112112
if all(cache.reset_check, du) || all(cache.reset_check, dfu)
113113
if cache.resets cache.max_resets
114-
cache.retcode = ReturnCode.Unstable
114+
cache.retcode = ReturnCode.ConvergenceFailure
115115
cache.force_stop = true
116116
return nothing
117117
end
@@ -153,7 +153,7 @@ function perform_step!(cache::GeneralBroydenCache{false})
153153
cache.dfu = cache.fu2 .- cache.fu
154154
if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)
155155
if cache.resets cache.max_resets
156-
cache.retcode = ReturnCode.Unstable
156+
cache.retcode = ReturnCode.ConvergenceFailure
157157
cache.force_stop = true
158158
return nothing
159159
end

src/default.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip},
128128
@unpack adkwargs, linsolve, precs = alg
129129

130130
algs = (
131-
# Klement(),
132-
# Broyden(),
131+
GeneralKlement(; linsolve, precs),
132+
GeneralBroyden(),
133133
NewtonRaphson(; linsolve, precs, adkwargs...),
134134
NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...),
135135
TrustRegion(; linsolve, precs, adkwargs...),
@@ -159,7 +159,7 @@ end
159159
]
160160
else
161161
[
162-
:(GeneralKlement()),
162+
:(GeneralKlement(; linsolve, precs)),
163163
:(GeneralBroyden()),
164164
:(NewtonRaphson(; linsolve, precs, adkwargs...)),
165165
:(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)),
@@ -191,7 +191,7 @@ end
191191
push!(calls,
192192
quote
193193
resids = tuple($(Tuple(resids)...))
194-
minfu, idx = findmin(DEFAULT_NORM, resids)
194+
minfu, idx = __findmin(DEFAULT_NORM, resids)
195195
end)
196196

197197
for i in 1:length(algs)
@@ -249,7 +249,7 @@ end
249249
retcode = ReturnCode.MaxIters
250250

251251
fus = tuple($(Tuple(resids)...))
252-
minfu, idx = findmin(cache.caches[1].internalnorm, fus)
252+
minfu, idx = __findmin(cache.caches[1].internalnorm, fus)
253253
stats = cache.caches[idx].stats
254254
u = cache.caches[idx].u
255255

src/klement.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ function perform_step!(cache::GeneralKlementCache{true})
118118
if singular
119119
if cache.resets == alg.max_resets
120120
cache.force_stop = true
121-
cache.retcode = ReturnCode.Unstable
121+
cache.retcode = ReturnCode.ConvergenceFailure
122122
return nothing
123123
end
124124
fact_done = false
@@ -176,7 +176,7 @@ function perform_step!(cache::GeneralKlementCache{false})
176176
if singular
177177
if cache.resets == alg.max_resets
178178
cache.force_stop = true
179-
cache.retcode = ReturnCode.Unstable
179+
cache.retcode = ReturnCode.ConvergenceFailure
180180
return nothing
181181
end
182182
fact_done = false

src/lbroyden.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
128128
if cache.iterations_since_reset > size(cache.U, 1) &&
129129
(all(cache.reset_check, du) || all(cache.reset_check, cache.dfu))
130130
if cache.resets cache.max_resets
131-
cache.retcode = ReturnCode.Unstable
131+
cache.retcode = ReturnCode.ConvergenceFailure
132132
cache.force_stop = true
133133
return nothing
134134
end
@@ -188,7 +188,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false})
188188
if cache.iterations_since_reset > size(cache.U, 1) &&
189189
(all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu))
190190
if cache.resets cache.max_resets
191-
cache.retcode = ReturnCode.Unstable
191+
cache.retcode = ReturnCode.ConvergenceFailure
192192
cache.force_stop = true
193193
return nothing
194194
end

src/raphson.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ end
8080

8181
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphson, args...;
8282
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
83-
termination_condition = nothing,
84-
internalnorm = DEFAULT_NORM,
83+
termination_condition = nothing, internalnorm = DEFAULT_NORM,
8584
linsolve_kwargs = (;), kwargs...) where {uType, iip}
8685
alg = get_concrete_algorithm(alg_, prob)
8786
@unpack f, u0, p = prob
@@ -91,9 +90,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
9190
linsolve_kwargs)
9291

9392
abstol, reltol, termination_condition = _init_termination_elements(abstol,
94-
reltol,
95-
termination_condition,
96-
eltype(u))
93+
reltol, termination_condition, eltype(u))
9794

9895
mode = DiffEqBase.get_termination_mode(termination_condition)
9996

src/trustRegion.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,6 @@ for large-scale and numerically-difficult nonlinear systems.
141141
`expand_threshold < r` (with `r` defined in `shrink_threshold`). Defaults to `2.0`.
142142
- `max_shrink_times`: the maximum number of times to shrink the trust region radius in a
143143
row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
144-
145-
!!! warning
146-
147-
`linsolve` and `precs` are used exclusively for the inplace version of the algorithm.
148-
Support for the OOP version is planned!
149144
"""
150145
@concrete struct TrustRegion{CJ, AD, MTR} <:
151146
AbstractNewtonAlgorithm{CJ, AD}
@@ -250,7 +245,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
250245
linsolve_kwargs)
251246
u_tmp = zero(u)
252247
u_cauchy = zero(u)
253-
u_gauss_newton = zero(u)
248+
u_gauss_newton = _mutable_zero(u)
254249

255250
loss_new = loss
256251
H = zero(J' * J)
@@ -338,10 +333,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
338333
initial_trust_radius = convert(trustType, 1.0)
339334
end
340335

341-
abstol, reltol, termination_condition = _init_termination_elements(abstol,
342-
reltol,
343-
termination_condition,
344-
eltype(u))
336+
abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol,
337+
termination_condition, eltype(u))
345338

346339
mode = DiffEqBase.get_termination_mode(termination_condition)
347340

@@ -368,8 +361,7 @@ function perform_step!(cache::TrustRegionCache{true})
368361
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
369362
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
370363
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu),
371-
linu = _vec(u_gauss_newton),
372-
p = p, reltol = cache.abstol)
364+
linu = _vec(u_gauss_newton), p = p, reltol = cache.abstol)
373365
cache.linsolve = linres.cache
374366
@. cache.u_gauss_newton = -1 * u_gauss_newton
375367
end
@@ -395,7 +387,12 @@ function perform_step!(cache::TrustRegionCache{false})
395387
cache.H = J' * J
396388
cache.g = _restructure(fu, J' * _vec(fu))
397389
cache.stats.njacs += 1
398-
cache.u_gauss_newton = -1 .* _restructure(cache.g, cache.H \ _vec(cache.g))
390+
391+
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
392+
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
393+
linres = dolinsolve(cache.alg.precs, cache.linsolve, A = cache.J, b = -_vec(fu),
394+
linu = _vec(cache.u_gauss_newton), p = p, reltol = cache.abstol)
395+
cache.linsolve = linres.cache
399396
end
400397

401398
# Compute the Newton step.
@@ -718,8 +715,16 @@ function jvp!(cache::TrustRegionCache{true})
718715
end
719716

720717
function not_terminated(cache::TrustRegionCache)
721-
return !cache.force_stop && cache.stats.nsteps < cache.maxiters &&
722-
cache.shrink_counter < cache.alg.max_shrink_times
718+
non_shrink_terminated = cache.force_stop || cache.stats.nsteps cache.maxiters
719+
# Terminated due to convergence or maxiters
720+
non_shrink_terminated && return false
721+
# Terminated due to too many shrink
722+
shrink_terminated = cache.shrink_counter cache.alg.max_shrink_times
723+
if shrink_terminated
724+
cache.retcode = ReturnCode.ConvergenceFailure
725+
return false
726+
end
727+
return true
723728
end
724729
get_fu(cache::TrustRegionCache) = cache.fu
725730

src/utils.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ end
1313
@inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2, u)) / length(u))
1414
@inline DEFAULT_NORM(u) = norm(u)
1515

16+
# Ignores NaN
17+
function __findmin(f, x)
18+
return findmin(x) do xᵢ
19+
fx = f(xᵢ)
20+
return isnan(fx) ? Inf : fx
21+
end
22+
end
23+
1624
"""
1725
default_adargs_to_adtype(; chunk_size = Val{0}(), autodiff = Val{true}(),
1826
standardtag = Val{true}(), diff_type = Val{:forward})
@@ -210,9 +218,13 @@ function __get_concrete_algorithm(alg, prob)
210218
return set_ad(alg, ad)
211219
end
212220

221+
__cvt_real(::Type{T}, ::Nothing) where {T} = nothing
222+
__cvt_real(::Type{T}, x) where {T} = real(T(x))
223+
213224
function _get_tolerance(η, tc_η, ::Type{T}) where {T}
214225
fallback_η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)
215-
return T(ifelse!== nothing, η, ifelse(tc_η !== nothing, tc_η, fallback_η)))
226+
return ifelse!== nothing, __cvt_real(T, η),
227+
ifelse(tc_η !== nothing, __cvt_real(T, tc_η), fallback_η))
216228
end
217229

218230
function _init_termination_elements(abstol, reltol, termination_condition,

0 commit comments

Comments
 (0)