Skip to content

Commit 7c8a0a7

Browse files
committed
Improve termination conditions
1 parent e9b7146 commit 7c8a0a7

17 files changed

+272
-409
lines changed

Manifest.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
164164

165165
[[deps.DiffEqBase]]
166166
deps = ["ArrayInterface", "ChainRulesCore", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces", "ZygoteRules"]
167-
git-tree-sha1 = "e5049e32074cd22f86d74036caf6663637623003"
167+
git-tree-sha1 = "4e661d0beddac31da05e71b79afd769232622de8"
168168
repo-rev = "ap/tstable_termination"
169169
repo-url = "https://github.com/SciML/DiffEqBase.jl"
170170
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
@@ -689,9 +689,9 @@ version = "0.1.0"
689689

690690
[[deps.SLEEFPirates]]
691691
deps = ["IfElse", "Static", "VectorizationBase"]
692-
git-tree-sha1 = "897b39ec056c0619ea87adc7eeadba0bec0cf931"
692+
git-tree-sha1 = "f5c896d781486f1d67c8492f0e0ead2c3517208c"
693693
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
694-
version = "0.6.40"
694+
version = "0.6.41"
695695

696696
[[deps.SciMLBase]]
697697
deps = ["ADTypes", "ArrayInterface", "ChainRulesCore", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces", "ZygoteRules"]

src/NonlinearSolve.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ PrecompileTools.@recompile_invalidations begin
3030
end
3131

3232
@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
33-
import DiffEqBase: AbstractNonlinearTerminationMode
33+
import DiffEqBase: AbstractNonlinearTerminationMode,
34+
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
35+
NonlinearSafeTerminationReturnCode, get_termination_mode
3436

3537
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
3638
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
@@ -54,6 +56,8 @@ function not_terminated(cache::AbstractNonlinearSolveCache)
5456
return !cache.force_stop && cache.stats.nsteps < cache.maxiters
5557
end
5658
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1
59+
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu1 = fu)
60+
get_u(cache::AbstractNonlinearSolveCache) = cache.u
5761

5862
function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
5963
while not_terminated(cache)
@@ -70,7 +74,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
7074
end
7175
end
7276

73-
return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, get_fu(cache);
77+
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
7478
cache.retcode, cache.stats)
7579
end
7680

@@ -114,4 +118,10 @@ export RobustMultiNewton, FastShortcutNonlinearPolyalg
114118

115119
export LineSearch, LiFukushimaLineSearch
116120

121+
# Export the termination conditions from DiffEqBase
122+
export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode,
123+
NormTerminationMode, RelTerminationMode, RelNormTerminationMode, AbsTerminationMode,
124+
AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode,
125+
RelSafeBestTerminationMode, AbsSafeBestTerminationMode
126+
117127
end # module

src/broyden.jl

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ end
5353
prob
5454
stats::NLStats
5555
ls_cache
56-
termination_condition
57-
tc_storage
56+
tc_cache
5857
end
5958

6059
get_fu(cache::GeneralBroydenCache) = cache.fu
60+
set_fu!(cache::GeneralBroydenCache, fu) = (cache.fu = fu)
6161

6262
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
6363
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
@@ -71,34 +71,26 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
7171
alg.reset_tolerance
7272
reset_check = x -> abs(x) reset_tolerance
7373

74-
abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol,
75-
termination_condition, eltype(u))
76-
77-
mode = DiffEqBase.get_termination_mode(termination_condition)
74+
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
75+
termination_condition)
7876

79-
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
80-
nothing
8177
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
8278
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
8379
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
8480
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
85-
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
86-
storage)
81+
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache)
8782
end
8883

8984
function perform_step!(cache::GeneralBroydenCache{true})
90-
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂, tc_storage = cache
91-
92-
termination_condition = cache.termination_condition(tc_storage)
85+
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂ = cache
9386
T = eltype(u)
9487

9588
mul!(_vec(du), J⁻¹, _vec(fu))
9689
α = perform_linesearch!(cache.ls_cache, u, du)
9790
_axpy!(-α, du, u)
9891
f(fu2, u, p)
9992

100-
termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) &&
101-
(cache.force_stop = true)
93+
check_and_update!(cache, fu2, u, u_prev)
10294
cache.stats.nf += 1
10395

10496
cache.force_stop && return nothing
@@ -130,9 +122,7 @@ function perform_step!(cache::GeneralBroydenCache{true})
130122
end
131123

132124
function perform_step!(cache::GeneralBroydenCache{false})
133-
@unpack f, p, tc_storage = cache
134-
135-
termination_condition = cache.termination_condition(tc_storage)
125+
@unpack f, p = cache
136126

137127
T = eltype(cache.u)
138128

@@ -141,8 +131,7 @@ function perform_step!(cache::GeneralBroydenCache{false})
141131
cache.u = cache.u .- α * cache.du
142132
cache.fu2 = f(cache.u, p)
143133

144-
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
145-
(cache.force_stop = true)
134+
check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
146135
cache.stats.nf += 1
147136

148137
cache.force_stop && return nothing
@@ -172,9 +161,8 @@ function perform_step!(cache::GeneralBroydenCache{false})
172161
end
173162

174163
function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
175-
abstol = cache.abstol, reltol = cache.reltol,
176-
termination_condition = cache.termination_condition,
177-
maxiters = cache.maxiters) where {iip}
164+
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
165+
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
178166
cache.p = p
179167
if iip
180168
recursivecopy!(cache.u, u0)
@@ -185,12 +173,12 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
185173
cache.fu = cache.f(cache.u, p)
186174
end
187175

188-
termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
176+
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
189177
termination_condition)
190178

191179
cache.abstol = abstol
192180
cache.reltol = reltol
193-
cache.termination_condition = termination_condition
181+
cache.tc_cache = tc_cache
194182
cache.maxiters = maxiters
195183
cache.stats.nf = 1
196184
cache.stats.nsteps = 1

src/dfsane.jl

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ function DFSane(; σ_min = 1e-10, σ_max = 1e+10, σ_1 = 1.0, M = 10, γ = 1e-4,
6363
n_exp, η_strategy, max_inner_iterations)
6464
end
6565

66-
@concrete mutable struct DFSaneCache{iip}
66+
# FIXME: Someone please make this code conform to the style of the remaining solvers
67+
@concrete mutable struct DFSaneCache{iip} <: AbstractNonlinearSolveCache{iip}
6768
alg
6869
uₙ
6970
uₙ₋₁
@@ -91,10 +92,13 @@ end
9192
reltol
9293
prob
9394
stats::NLStats
94-
termination_condition
95-
tc_storage
95+
tc_cache
9696
end
9797

98+
get_fu(cache::DFSaneCache) = cache.fuₙ
99+
set_fu!(cache::DFSaneCache, fu) = (cache.fuₙ = fu)
100+
get_u(cache::DFSaneCache) = cache.uₙ
101+
98102
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
99103
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
100104
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
@@ -124,24 +128,18 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
124128

125129
= fill(f₍ₙₒᵣₘ₎ₙ₋₁, M)
126130

127-
abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol,
128-
termination_condition, T)
129-
130-
mode = DiffEqBase.get_termination_mode(termination_condition)
131-
132-
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
133-
nothing
131+
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fuₙ₋₁, uₙ₋₁,
132+
termination_condition)
134133

135134
return DFSaneCache{iip}(alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
136135
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, p, false, maxiters,
137136
internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0),
138-
termination_condition, storage)
137+
tc_cache)
139138
end
140139

141140
function perform_step!(cache::DFSaneCache{true})
142-
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache
141+
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
143142

144-
termination_condition = cache.termination_condition(tc_storage)
145143
f = (dx, x) -> cache.prob.f(dx, x, cache.p)
146144

147145
T = eltype(cache.uₙ)
@@ -184,9 +182,7 @@ function perform_step!(cache::DFSaneCache{true})
184182
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
185183
end
186184

187-
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
188-
cache.force_stop = true
189-
end
185+
check_and_update!(cache, cache.fuₙ, cache.uₙ, cache.uₙ₋₁)
190186

191187
# Update spectral parameter
192188
@. cache.uₙ₋₁ = cache.uₙ - cache.uₙ₋₁
@@ -215,9 +211,8 @@ function perform_step!(cache::DFSaneCache{true})
215211
end
216212

217213
function perform_step!(cache::DFSaneCache{false})
218-
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache
214+
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
219215

220-
termination_condition = cache.termination_condition(tc_storage)
221216
f = x -> cache.prob.f(x, cache.p)
222217

223218
T = eltype(cache.uₙ)
@@ -260,9 +255,7 @@ function perform_step!(cache::DFSaneCache{false})
260255
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
261256
end
262257

263-
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
264-
cache.force_stop = true
265-
end
258+
check_and_update!(cache, cache.fuₙ, cache.uₙ, cache.uₙ₋₁)
266259

267260
# Update spectral parameter
268261
cache.uₙ₋₁ = @. cache.uₙ - cache.uₙ₋₁
@@ -290,26 +283,9 @@ function perform_step!(cache::DFSaneCache{false})
290283
return nothing
291284
end
292285

293-
function SciMLBase.solve!(cache::DFSaneCache)
294-
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
295-
cache.stats.nsteps += 1
296-
perform_step!(cache)
297-
end
298-
299-
if cache.stats.nsteps == cache.maxiters
300-
cache.retcode = ReturnCode.MaxIters
301-
else
302-
cache.retcode = ReturnCode.Success
303-
end
304-
305-
return SciMLBase.build_solution(cache.prob, cache.alg, cache.uₙ, cache.fuₙ;
306-
retcode = cache.retcode, stats = cache.stats)
307-
end
308-
309286
function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p,
310-
abstol = cache.abstol, reltol = cache.reltol,
311-
termination_condition = cache.termination_condition,
312-
maxiters = cache.maxiters) where {iip}
287+
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
288+
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
313289
cache.p = p
314290
if iip
315291
recursivecopy!(cache.uₙ, u0)
@@ -330,12 +306,12 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p
330306
T = eltype(cache.uₙ)
331307
cache.σₙ = T(cache.alg.σ_1)
332308

333-
termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
309+
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fuₙ, cache.uₙ,
334310
termination_condition)
335311

336312
cache.abstol = abstol
337313
cache.reltol = reltol
338-
cache.termination_condition = termination_condition
314+
cache.tc_cache = tc_cache
339315
cache.maxiters = maxiters
340316
cache.stats.nf = 1
341317
cache.stats.nsteps = 1

0 commit comments

Comments
 (0)