Skip to content

Commit 52ba7e4

Browse files
committed
Add termination conditions for NonlinearProblem and SSProblem
1 parent f34fcd7 commit 52ba7e4

File tree

3 files changed

+242
-1
lines changed

3 files changed

+242
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.73.2"
4+
version = "6.73.3"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -10,6 +10,7 @@ DEDataArrays = "754358af-613d-5f8d-9788-280bf1605d4c"
1010
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1111
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1212
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
13+
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1314
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1415
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1516
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
@@ -41,6 +42,7 @@ DEDataArrays = "0.2"
4142
DataStructures = "0.18"
4243
Distributions = "0.25"
4344
DocStringExtensions = "0.8"
45+
EnumX = "1"
4446
FastBroadcast = "0.1.4"
4547
ForwardDiff = "0.10"
4648
FunctionWrappers = "1.0"

src/DiffEqBase.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ using Setfield
3838

3939
using ForwardDiff
4040

41+
using EnumX
42+
4143
@reexport using SciMLBase
4244

4345
using SciMLBase: @def, DEIntegrator, DEProblem, AbstractDiffEqOperator,
@@ -126,6 +128,8 @@ include("init.jl")
126128
include("forwarddiff.jl")
127129
include("chainrules.jl")
128130

131+
include("termination_conditions.jl")
132+
129133
include("precompile.jl")
130134

131135
"""
@@ -153,4 +157,6 @@ export NLNewton, NLFunctional, NLAnderson
153157

154158
export SensitivityADPassThrough
155159

160+
export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition
161+
156162
end # module

src/termination_conditions.jl

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
@enumx NLSolveSafeTerminationReturnCode begin
2+
Success
3+
PatienceTermination
4+
ProtectiveTermination
5+
Failure
6+
end
7+
8+
# SteadyStateDefault and NLSolveDefault are needed to be compatible with the existing
9+
# termination conditions in NonlinearSolve and SteadyStateDiffEq
10+
@enumx NLSolveTerminationMode begin
11+
SteadyStateDefault
12+
NLSolveDefault
13+
Norm
14+
Rel
15+
RelNorm
16+
Abs
17+
AbsNorm
18+
RelSafe
19+
RelSafeBest
20+
AbsSafe
21+
AbsSafeBest
22+
end
23+
24+
struct NLSolveSafeTerminationOptions{T1,T2,T3}
25+
protective_threshold::T1
26+
patience_steps::Int
27+
patience_objective_multiplier::T2
28+
min_max_factor::T3
29+
end
30+
31+
const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault,
32+
NLSolveTerminationMode.NLSolveDefault,
33+
NLSolveTerminationMode.Norm, NLSolveTerminationMode.Rel,
34+
NLSolveTerminationMode.RelNorm,
35+
NLSolveTerminationMode.Abs, NLSolveTerminationMode.AbsNorm)
36+
37+
const SAFE_TERMINATION_MODES = (NLSolveTerminationMode.RelSafe,
38+
NLSolveTerminationMode.RelSafeBest,
39+
NLSolveTerminationMode.AbsSafe,
40+
NLSolveTerminationMode.AbsSafeBest)
41+
42+
const SAFE_BEST_TERMINATION_MODES = (NLSolveTerminationMode.RelSafeBest,
43+
NLSolveTerminationMode.AbsSafeBest)
44+
45+
@doc doc"""
46+
NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
47+
protective_threshold = 1e3, patience_steps::Int = 30,
48+
patience_objective_multiplier = 3, min_max_factor = 1.3)
49+
50+
Define the termination criteria for the NonlinearProblem or SteadyStateProblem.
51+
52+
## Termination Conditions
53+
54+
#### Termination on Absolute Tolerance
55+
56+
* `SteadyStateTerminationMode.Abs`: Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)``
57+
* `SteadyStateTerminationMode.AbsNorm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol``
58+
* `SteadyStateTerminationMode.AbsSafe`: Essentially `abs_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges)
59+
* `SteadyStateTerminationMode.AbsSafeBest`: Same as `SteadyStateTerminationMode.AbsSafe` but uses the best solution found so far, i.e. deviates only if the solution has not converged
60+
61+
#### Termination on Relative Tolerance
62+
63+
* `SteadyStateTerminationMode.Rel`: Terminates if ``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)``
64+
* `SteadyStateTerminationMode.RelNorm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|``
65+
* `SteadyStateTerminationMode.RelSafe`: Essentially `rel_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges)
66+
* `SteadyStateTerminationMode.RelSafeBest`: Same as `SteadyStateTerminationMode.RelSafe` but uses the best solution found so far, i.e. deviates only if the solution has not converged
67+
68+
#### Termination using both Absolute and Relative Tolerances
69+
70+
* `SteadyStateTerminationMode.Norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` or ``\| \frac{\partial u}{\partial t} \| \leq abstol``
71+
* `SteadyStateTerminationMode.SteadyStateDefault`: Check if all values of the derivative is close to zero wrt both relative and absolute tolerance. This is usable for small problems but doesn't scale well for neural networks.
72+
* `SteadyStateTerminationMode.NLSolveDefault`: Check if all values of the derivative is close to zero wrt both relative and absolute tolerance. Or check that the value of the current and previous state is within the specified tolerances. This is usable for small problems but doesn't scale well for neural networks.
73+
74+
## General Arguments
75+
76+
* `abstol`: Absolute Tolerance
77+
* `reltol`: Relative Tolerance
78+
79+
## Arguments specific to `*Safe*` modes
80+
81+
* `protective_threshold`: If the objective value increased by this factor wrt initial objective terminate immediately.
82+
* `patience_steps`: If objective is within `patience_objective_multiplier` factor of the criteria and no improvement within `min_max_factor` has happened then terminate.
83+
84+
"""
85+
struct NLSolveTerminationCondition{mode, T,
86+
S <: Union{<:NLSolveSafeTerminationOptions, Nothing}}
87+
abstol::T
88+
reltol::T
89+
safe_termination_options::S
90+
end
91+
92+
function Base.show(io::IO, s::NLSolveTerminationCondition{mode}) where {mode}
93+
print(io, "NLSolveTerminationCondition(mode = $(mode), abstol = $(s.abstol), reltol = $(s.reltol)")
94+
if mode SAFE_TERMINATION_MODES
95+
print(io, ", safe_termination_options = ", s.safe_termination_options, ")")
96+
else
97+
print(io, ")")
98+
end
99+
end
100+
101+
# Don't specify `mode` since the defaults would depend on the package
102+
function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
103+
protective_threshold = 1e3, patience_steps::Int = 30,
104+
patience_objective_multiplier = 3,
105+
min_max_factor = 1.3) where {T}
106+
@assert mode instances(NLSolveTerminationMode.T)
107+
options = if mode SAFE_TERMINATION_MODES
108+
NLSolveSafeTerminationOptions(protective_threshold, patience_steps,
109+
patience_objective_multiplier, min_max_factor)
110+
else
111+
nothing
112+
end
113+
return NLSolveTerminationCriteria{mode, T, typeof(options)}(abstol, reltol, options)
114+
end
115+
116+
function get_termination_condition(cond::NLSolveTerminationCondition{mode},
117+
storage::Union{<:AbstractDict, Nothing}) where {mode}
118+
# We need both the dispatches to support solvers that don't use the integrator
119+
# interface like SimpleNonlinearSolve
120+
if mode in BASIC_TERMINATION_MODES
121+
function _termination_condition_closure_basic(integrator, abstol, reltol, min_t)
122+
return _termination_condition_closure_basic(get_du(integrator), integrator.u,
123+
integrator.uprev, abstol, reltol)
124+
end
125+
function _termination_condition_closure_basic(du, u, uprev, abstol = cond.abstol,
126+
reltol = cond.reltol)
127+
return _has_converged(du, u, uprev, cond, abstol, reltol)
128+
end
129+
return _termination_condition_closure_basic
130+
else
131+
mode SAFE_BEST_TERMINATION_MODES && @assert storage !== nothing
132+
133+
function _termination_condition_closure_safe(integrator, abstol, reltol, min_t)
134+
return _termination_condition_closure_safe(get_du(integrator), integrator.u,
135+
integrator.uprev, abstol, reltol)
136+
end
137+
@inbounds function _termination_condition_closure_safe(du, u, uprev,
138+
abstol = cond.abstol,
139+
reltol = cond.reltol)
140+
aType = typeof(cond.abstol)
141+
nstep = 0
142+
protective_threshold = aType(cond.safe_termination_options.protective_threshold)
143+
objective_values = aType[]
144+
patience_objective_multiplier = cond.safe_termination_options.patience_objective_multiplier
145+
146+
if mode SAFE_BEST_TERMINATION_MODES
147+
storage[:best_objective_value] = oftype(Inf)
148+
storage[:best_objective_value_iteration] = 0
149+
end
150+
151+
if mode SAFE_BEST_TERMINATION_MODES
152+
objective = norm(du)
153+
criteria = abstol
154+
else
155+
objective = norm(du) / (norm(du .+ u) + eps(aType))
156+
criteria = reltol
157+
end
158+
159+
if mode SAFE_BEST_TERMINATION_MODES
160+
if objective < storage[:best_objective_value]
161+
storage[:best_objective_value] = objective
162+
storage[:best_objective_value_iteration] = nstep + 1
163+
end
164+
end
165+
166+
# Main Termination Criteria
167+
if objective <= criteria
168+
storage[:return_code] = NLSolveSafeTerminationReturnCode.Success
169+
return true
170+
end
171+
172+
# Terminate if there has been no improvement for the last `patience_steps`
173+
nstep += 1
174+
push!(objective_values, objective)
175+
176+
if objective <= typeof(criteria)(patience_objective_multiplier) * criteria
177+
if nstep >= cond.safe_termination_options.patience_steps
178+
last_k_values = objective_values[max(1,
179+
length(objective_values) -
180+
cond.safe_termination_options.patience_steps):end]
181+
if maximum(last_k_values) <
182+
typeof(criteria)(cond.safe_termination_options.min_max_factor) *
183+
minimum(last_k_values)
184+
storage[:return_code] = NLSolveSafeTerminationReturnCode.PatienceTermination
185+
return true
186+
end
187+
end
188+
end
189+
190+
# Protective break
191+
if objective >= objective_values[1] * protective_threshold * length(du)
192+
storage[:return_code] = NLSolveSafeTerminationReturnCode.ProtectiveTermination
193+
return true
194+
end
195+
196+
storage[:return_code] = NLSolveSafeTerminationReturnCode.Failure
197+
return false
198+
end
199+
return _termination_condition_closure_safe
200+
end
201+
end
202+
203+
204+
# Convergence Criterions
205+
@inline function _has_converged(du, u, uprev, cond::NLSolveTerminationCriteria{mode},
206+
abstol = cond.abstol, reltol = cond.reltol) where {mode}
207+
return _has_converged(du, u, uprev, mode, abstol, reltol)
208+
end
209+
210+
@inline @inbounds function _has_converged(du, u, uprev, mode, abstol, reltol)
211+
if mode == NLSolveTerminationMode.Norm
212+
du_norm = norm(du)
213+
return du_norm <= abstol || du_norm <= reltol * norm(du + u)
214+
elseif mode == NLSolveTerminationMode.Rel
215+
return all(abs.(du) .<= reltol .* abs.(u))
216+
elseif mode (NLSolveTerminationMode.RelNorm, NLSolveTerminationMode.RelSafe,
217+
NLSolveTerminationMode.RelSafeBest)
218+
return norm(du) <= reltol * norm(du .+ u)
219+
elseif mode == NLSolveTerminationMode.Abs
220+
return all(abs.(du) .<= abstol)
221+
elseif mode (NLSolveTerminationMode.AbsNorm, NLSolveTerminationMode.AbsSafe,
222+
NLSolveTerminationMode.AbsSafeBest)
223+
return norm(du) <= abstol
224+
elseif mode == NLSolveTerminationMode.SteadyStateDefault
225+
return all((abs.(du) .<= abstol) .|| (abs.(du) .<= reltol .* abs.(u)))
226+
elseif mode == NLSolveTerminationMode.NLSolveDefault
227+
atol, rtol = abstol, reltol
228+
return all((abs.(du) .<= abstol) .|| (abs.(du) .<= reltol .* abs.(u))) ||
229+
isapprox(u, uprev; atol, rtol)
230+
else
231+
throw(ArgumentError("Unknown termination mode: $mode"))
232+
end
233+
end

0 commit comments

Comments
 (0)