Skip to content

Commit 4f970f7

Browse files
committed
feat: share the termination condition code in NonlinearSolve and SimpleNonlinearSolve
1 parent f69bccd commit 4f970f7

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,17 @@ using LinearAlgebra: norm
88
using Markdown: @doc_str
99
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
1010
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
11-
AbstractNonlinearFunction, @add_kwonly, StandardNonlinearProblem,
12-
NullParameters, NonlinearProblem, isinplace
11+
NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction,
12+
@add_kwonly, StandardNonlinearProblem, NullParameters, NonlinearProblem,
13+
isinplace
1314
using StaticArraysCore: StaticArray
1415

1516
include("public.jl")
1617
include("utils.jl")
1718

19+
include("immutable_problem.jl")
1820
include("common_defaults.jl")
1921
include("termination_conditions.jl")
20-
include("immutable_problem.jl")
2122

2223
# Unexported Public API
2324
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))

lib/NonlinearSolveBase/src/termination_conditions.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,37 @@ end
245245
function check_convergence(mode::AbsNormModes, duₙ, _, __, abstol, ___)
246246
return Utils.apply_norm(mode.internalnorm, duₙ) abstol
247247
end
248+
249+
# High-Level API with defaults.
250+
## This is mostly for internal usage in NonlinearSolve and SimpleNonlinearSolve
251+
function default_termination_mode(
252+
::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:simple})
253+
return AbsNormTerminationMode(Base.Fix1(maximum, abs))
254+
end
255+
function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:simple})
256+
return AbsNormTerminationMode(Base.Fix2(norm, 2))
257+
end
258+
259+
function default_termination_mode(
260+
::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:regular})
261+
return AbsNormSafeBestTerminationMode(Base.Fix1(maximum, abs); max_stalled_steps = 32)
262+
end
263+
264+
function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:regular})
265+
return AbsNormSafeBestTerminationMode(Base.Fix2(norm, 2); max_stalled_steps = 32)
266+
end
267+
268+
function init_termination_cache(
269+
prob::AbstractNonlinearProblem, abstol, reltol, du, u, ::Nothing, callee::Val)
270+
return init_termination_cache(
271+
prob, abstol, reltol, du, u, default_termination_mode(prob, callee), callee)
272+
end
273+
274+
function init_termination_cache(::AbstractNonlinearProblem, abstol, reltol, du,
275+
u, tc::AbstractNonlinearTerminationMode, ::Val)
276+
T = promote_type(eltype(du), eltype(u))
277+
abstol = get_tolerance(abstol, T)
278+
reltol = get_tolerance(reltol, T)
279+
cache = init(du, u, tc; abstol, reltol)
280+
return abstol, reltol, cache
281+
end

0 commit comments

Comments
 (0)