Skip to content

Commit 0b91de7

Browse files
committed
refactor: use functionality from NonlinearSolveBase instead of DiffEqBase
1 parent f6041af commit 0b91de7

28 files changed

+65
-122
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ BandedMatrices = "1.5"
6464
BenchmarkTools = "1.4"
6565
CUDA = "5.5"
6666
ConcreteStructs = "0.2.3"
67-
DiffEqBase = "6.155.3"
67+
DiffEqBase = "6.158.3"
6868
DifferentiationInterface = "0.6.1"
6969
Enzyme = "0.13.2"
7070
ExplicitImports = "1.5"

docs/src/basics/solve.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ solve(prob::SciMLBase.NonlinearProblem, args...; kwargs...)
2121
`real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)`.
2222
- `reltol::Number`: The relative tolerance. Defaults to
2323
`real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)`.
24-
- `termination_condition`: Termination Condition from DiffEqBase. Defaults to
24+
- `termination_condition`: Termination Condition from NonlinearSolveBase. Defaults to
2525
`AbsSafeBestTerminationMode()` for `NonlinearSolve.jl` and `AbsTerminateMode()` for
2626
`SimpleNonlinearSolve.jl`.
2727

docs/src/basics/termination_condition.md

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ cache = init(du, u, AbsSafeBestTerminationMode(); abstol = 1e-9, reltol = 1e-9)
1414
If `abstol` and `reltol` are not supplied, then we choose a default based on the element
1515
types of `du` and `u`.
1616

17-
We can query the `cache` using `DiffEqBase.get_termination_mode`, `DiffEqBase.get_abstol`
18-
and `DiffEqBase.get_reltol`.
19-
2017
To test for termination simply call the `cache`:
2118

2219
```julia
@@ -28,29 +25,15 @@ terminated = cache(du, u, uprev)
2825
```@docs
2926
AbsTerminationMode
3027
AbsNormTerminationMode
31-
AbsSafeTerminationMode
32-
AbsSafeBestTerminationMode
28+
AbsNormSafeTerminationMode
29+
AbsNormSafeBestTerminationMode
3330
```
3431

3532
### Relative Tolerance
3633

3734
```@docs
3835
RelTerminationMode
3936
RelNormTerminationMode
40-
RelSafeTerminationMode
41-
RelSafeBestTerminationMode
42-
```
43-
44-
### Both Absolute and Relative Tolerance
45-
46-
```@docs
47-
NormTerminationMode
48-
SteadyStateDiffEqTerminationMode
49-
```
50-
51-
The following was named to match an older version of SimpleNonlinearSolve. It is currently
52-
not used as a default anywhere.
53-
54-
```@docs
55-
SimpleNonlinearSolveTerminationMode
37+
RelNormSafeTerminationMode
38+
RelNormSafeBestTerminationMode
5639
```

ext/NonlinearSolveFastLevenbergMarquardtExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module NonlinearSolveFastLevenbergMarquardtExt
33
using ArrayInterface: ArrayInterface
44
using FastClosures: @closure
55
using FastLevenbergMarquardt: FastLevenbergMarquardt
6+
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
67
using NonlinearSolve: NonlinearSolve, FastLevenbergMarquardtJL
78
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode
89
using StaticArraysCore: SArray
@@ -33,8 +34,8 @@ function SciMLBase.__solve(prob::Union{NonlinearLeastSquaresProblem, NonlinearPr
3334
else
3435
@closure (du, u, p) -> fn(du, u)
3536
end
36-
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
37-
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(u))
37+
abstol = get_tolerance(abstol, eltype(u))
38+
reltol = get_tolerance(reltol, eltype(u))
3839

3940
_jac_fn = NonlinearSolve.__construct_extension_jac(
4041
prob, alg, u, resid; alg.autodiff, can_handle_oop = Val(prob.u0 isa SArray))

ext/NonlinearSolveFixedPointAccelerationExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module NonlinearSolveFixedPointAccelerationExt
22

3+
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
34
using NonlinearSolve: NonlinearSolve, FixedPointAccelerationJL
45
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
56
using FixedPointAcceleration: FixedPointAcceleration, fixed_point
@@ -13,7 +14,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL
1314

1415
f, u0, resid = NonlinearSolve.__construct_extension_f(
1516
prob; alias_u0, make_fixed_point = Val(true), force_oop = Val(true))
16-
tol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
17+
tol = get_tolerance(abstol, eltype(u0))
1718

1819
sol = fixed_point(f, u0; Algorithm = alg.algorithm, MaxIter = maxiters, MaxM = alg.m,
1920
ConvergenceMetricThreshold = tol, ExtrapolationPeriod = alg.extrapolation_period,

ext/NonlinearSolveLeastSquaresOptimExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module NonlinearSolveLeastSquaresOptimExt
22

33
using ConcreteStructs: @concrete
44
using LeastSquaresOptim: LeastSquaresOptim
5+
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
56
using NonlinearSolve: NonlinearSolve, LeastSquaresOptimJL, TraceMinimal
67
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode
78

@@ -42,8 +43,8 @@ function SciMLBase.__init(prob::Union{NonlinearLeastSquaresProblem, NonlinearPro
4243
NonlinearSolve.__test_termination_condition(termination_condition, :LeastSquaresOptim)
4344

4445
f!, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
45-
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
46-
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(u))
46+
abstol = get_tolerance(abstol, eltype(u))
47+
reltol = get_tolerance(reltol, eltype(u))
4748

4849
if prob.f.jac === nothing && alg.autodiff isa Symbol
4950
lsoprob = LSO.LeastSquaresProblem(; x = u, f!, y = resid, alg.autodiff,

ext/NonlinearSolveMINPACKExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module NonlinearSolveMINPACKExt
22

33
using MINPACK: MINPACK
4+
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
45
using NonlinearSolve: NonlinearSolve, CMINPACK
56
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode
67
using FastClosures: @closure
@@ -21,7 +22,7 @@ function SciMLBase.__solve(
2122

2223
show_trace = ShT
2324
tracing = StT
24-
tol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
25+
tol = get_tolerance(abstol, eltype(u0))
2526

2627
if alg.autodiff === missing && prob.f.jac === nothing
2728
original = MINPACK.fsolve(

ext/NonlinearSolveNLSolversExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using FiniteDiff: FiniteDiff
66
using ForwardDiff: ForwardDiff
77
using LinearAlgebra: norm
88
using NLSolvers: NLSolvers, NEqOptions, NEqProblem
9+
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
910
using NonlinearSolve: NonlinearSolve, NLSolversJL
1011
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
1112

@@ -14,8 +15,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
1415
alias_u0::Bool = false, termination_condition = nothing, kwargs...)
1516
NonlinearSolve.__test_termination_condition(termination_condition, :NLSolversJL)
1617

17-
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(prob.u0))
18-
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(prob.u0))
18+
abstol = get_tolerance(abstol, eltype(prob.u0))
19+
reltol = get_tolerance(reltol, eltype(prob.u0))
1920

2021
options = NEqOptions(; maxiter = maxiters, f_abstol = abstol, f_reltol = reltol)
2122

ext/NonlinearSolveNLsolveExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module NonlinearSolveNLsolveExt
22

33
using LineSearches: Static
4+
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
45
using NonlinearSolve: NonlinearSolve, NLsolveJL, TraceMinimal
56
using NLsolve: NLsolve, OnceDifferentiable, nlsolve
67
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
@@ -27,7 +28,7 @@ function SciMLBase.__solve(
2728
df = OnceDifferentiable(f!, jac!, vec(u0), vec(resid), J)
2829
end
2930

30-
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
31+
abstol = get_tolerance(abstol, eltype(u0))
3132
show_trace = ShT
3233
store_trace = StT
3334
extended_trace = !(trace_level isa TraceMinimal)

ext/NonlinearSolveSIAMFANLEquationsExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module NonlinearSolveSIAMFANLEquationsExt
22

33
using FastClosures: @closure
4+
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
45
using NonlinearSolve: NonlinearSolve, SIAMFANLEquationsJL
56
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
67
using SIAMFANLEquations: SIAMFANLEquations, aasol, nsol, nsoli, nsolsc, ptcsol, ptcsoli,
@@ -40,8 +41,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
4041

4142
(; method, delta, linsolve, m, beta) = alg
4243
T = eltype(prob.u0)
43-
atol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, T)
44-
rtol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, T)
44+
atol = get_tolerance(abstol, T)
45+
rtol = get_tolerance(reltol, T)
4546

4647
if prob.u0 isa Number
4748
f = @closure u -> prob.f(u, prob.p)

0 commit comments

Comments
 (0)