Skip to content

Commit d63e00b

Browse files
committed
Rework the Termination Condition API to be type stable
1 parent e9316ae commit d63e00b

File tree

1 file changed

+191
-0
lines changed

1 file changed

+191
-0
lines changed

src/termination_conditions.jl

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,194 @@
1+
@enumx NonlinearSafeTerminationReturnCode begin
2+
Success
3+
Default
4+
PatienceTermination
5+
ProtectiveTermination
6+
Failure
7+
end
8+
9+
abstract type AbstractNonlinearTerminationMode end
10+
abstract type AbstractSafeNonlinearTerminationMode <: AbstractNonlinearTerminationMode end
11+
abstract type AbstractSafeBestNonlinearTerminationMode <:
12+
AbstractSafeNonlinearTerminationMode end
13+
14+
# TODO: Add a mode where the user can pass in custom termination criteria function
15+
for mode in (:SteadyStateDiffEqTerminationMode, :SimpleNonlinearSolveTerminationMode,
16+
:NormTerminationMode, :RelTerminationMode, :RelNormTerminationMode, :AbsTerminationMode,
17+
:AbsNormTerminationMode)
18+
@eval begin
19+
struct $(mode) <: AbstractNonlinearTerminationMode end
20+
end
21+
end
22+
23+
for mode in (:RelSafeTerminationMode, :AbsSafeTerminationMode)
24+
@eval begin
25+
Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeNonlinearTerminationMode
26+
protective_threshold::T1 = 1000
27+
patience_steps::Int = 30
28+
patience_objective_multiplier::T2 = 3
29+
min_max_factor::T3 = 1.3
30+
end
31+
end
32+
end
33+
34+
for mode in (:RelSafeBestTerminationMode, :AbsSafeBestTerminationMode)
35+
@eval begin
36+
Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeBestNonlinearTerminationMode
37+
protective_threshold::T1 = 1000
38+
patience_steps::Int = 30
39+
patience_objective_multiplier::T2 = 3
40+
min_max_factor::T3 = 1.3
41+
end
42+
end
43+
end
44+
45+
mutable struct NonlinearTerminationModeCache{uType, T,
46+
M <: AbstractNonlinearTerminationMode, I, OT}
47+
u::uType
48+
retcode::NonlinearSafeTerminationReturnCode.T
49+
abstol::T
50+
reltol::T
51+
best_objective_value::T
52+
mode::M
53+
initial_objective::I
54+
objectives_trace::OT
55+
nsteps::Int
56+
end
57+
58+
function __update_u!!(cache::NonlinearTerminationModeCache, u)
59+
cache.u === nothing && return
60+
if ArrayInterface.can_setindex(cache.u)
61+
copyto!(cache.u, u)
62+
else
63+
cache.u = u
64+
end
65+
end
66+
67+
__cvt_real(::Type{T}, ::Nothing) where {T} = nothing
68+
__cvt_real(::Type{T}, x) where {T} = real(T(x))
69+
70+
_get_tolerance(η, ::Type{T}) where {T} = __cvt_real(T, η)
71+
function _get_tolerance(::Nothing, ::Type{T}) where {T}
72+
η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)
73+
return _get_tolerance(η, T)
74+
end
75+
76+
function SciMLBase.init(u::AbstractArray{T}, mode::AbstractNonlinearTerminationMode;
77+
abstol = nothing, reltol = nothing, kwargs...) where {T}
78+
abstol = _get_tolerance(abstol, T)
79+
reltol = _get_tolerance(reltol, T)
80+
best_value = __cvt_real(T, Inf)
81+
TT = typeof(abstol)
82+
u_ = mode isa AbstractSafeBestNonlinearTerminationMode ?
83+
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
84+
if mode isa AbstractSafeNonlinearTerminationMode
85+
initial_objective = TT(0)
86+
objectives_trace = Vector{TT}(undef, mode.patience_steps)
87+
else
88+
initial_objective = nothing
89+
objectives_trace = nothing
90+
end
91+
return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode),
92+
typeof(initial_objective), typeof(objectives_trace)}(u_,
93+
NonlinearSafeTerminationReturnCode.Default, abstol, reltol, best_value, mode,
94+
initial_objective, objectives_trace, 0)
95+
end
96+
97+
# This dispatch is needed based on how Terminating Callback works!
98+
# This intentially drops the `abstol` and `reltol` arguments
99+
function (cache::NonlinearTerminationModeCache)(integrator, _, _, min_t)
100+
return cache(cache.mode, get_du(integrator), integrator.u, integrator.uprev)
101+
end
102+
(cache::NonlinearTerminationModeCache)(du, u, uprev) = cache(cache.mode, du, u, uprev)
103+
104+
function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminationMode, du,
105+
u, uprev)
106+
return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol)
107+
end
108+
109+
function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode,
110+
du, u, uprev)
111+
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
112+
objective = NLSOLVE_DEFAULT_NORM(du)
113+
criteria = cache.abstol
114+
else
115+
objective = NLSOLVE_DEFAULT_NORM(du) /
116+
(NLSOLVE_DEFAULT_NORM(du .+ u) + eps(cache.abstol))
117+
criteria = cache.reltol
118+
end
119+
120+
# Check if best solution
121+
if mode isa AbstractSafeBestNonlinearTerminationMode &&
122+
objective < cache.best_objective_value
123+
cache.best_objective_value = objective
124+
__update_u!!(cache, u)
125+
end
126+
127+
# Main Termination Condition
128+
if objective criteria
129+
cache.retcode = NonlinearSafeTerminationReturnCode.Success
130+
return true
131+
end
132+
133+
# Terminate if there has been no improvement for the last `patience_steps`
134+
cache.nsteps += 1
135+
cache.nsteps == 1 && (cache.initial_objective = objective)
136+
cache.objectives_trace[mod1(cache.nsteps, length(cache.objectives_trace))] = objective
137+
138+
if objective cache.mode.patience_objective_multiplier * criteria
139+
if cache.nsteps cache.mode.patience_steps
140+
if cache.nsteps < length(cache.objectives_trace)
141+
min_obj, max_obj = extrema(@view(cache.objectives_trace[1:cache.nsteps]))
142+
else
143+
min_obj, max_obj = extrema(cache.objectives_trace)
144+
end
145+
if min_obj < cache.mode.min_max_factor * max_obj
146+
cache.retcode = NonlinearSafeTerminationReturnCode.PatienceTermination
147+
return true
148+
end
149+
end
150+
end
151+
152+
# Protective Break
153+
if objective cache.initial_objective * cache.mode.protective_threshold * length(du)
154+
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
155+
return true
156+
end
157+
158+
cache.retcode = NonlinearSafeTerminationReturnCode.Failure
159+
return false
160+
end
161+
162+
function check_convergence(::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol,
163+
reltol)
164+
return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ)))
165+
end
166+
function check_convergence(::SimpleNonlinearSolveTerminationMode, duₙ, uₙ, uₙ₋₁, abstol,
167+
reltol)
168+
return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) ||
169+
isapprox(uₙ, uₙ₋₁; atol = abstol, rtol = reltol)
170+
end
171+
function check_convergence(::NormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol)
172+
du_norm = NLSOLVE_DEFAULT_NORM(duₙ)
173+
return du_norm abstol || du_norm reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ)
174+
end
175+
function check_convergence(::RelNormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol)
176+
return all(abs.(duₙ) .≤ reltol .* abs.(uₙ))
177+
end
178+
function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode,
179+
RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol)
180+
return NLSOLVE_DEFAULT_NORM(duₙ) reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ)
181+
end
182+
function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol)
183+
return all(abs.(duₙ) .≤ abstol)
184+
end
185+
function check_convergence(::Union{AbsNormTerminationMode, AbsSafeTerminationMode,
186+
AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol)
187+
return NLSOLVE_DEFAULT_NORM(duₙ) abstol
188+
end
189+
190+
# NOTE: Deprecate the following API eventually. This API leads to quite a bit of type
191+
# instability
1192
@enumx NLSolveSafeTerminationReturnCode begin
2193
Success
3194
PatienceTermination

0 commit comments

Comments
 (0)