Skip to content

Commit eddcfde

Browse files
committed
Complete moving termination condition to all the algorithms
1 parent 1c5bb59 commit eddcfde

File tree

8 files changed

+99
-73
lines changed

8 files changed

+99
-73
lines changed

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import EnumX: @enumx
1414
import ForwardDiff: Dual
1515
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
1616
import RecursiveArrayTools: ArrayPartition,
17-
AbstractVectorOfArray, recursivecopy!, recursivefill!
17+
AbstractVectorOfArray, recursivecopy!, recursivefill!, recursive_unitless_bottom_eltype
1818
import Reexport: @reexport
1919
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
2020
import StaticArraysCore: StaticArray, SVector, SArray, MArray

src/dfsane.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,16 @@ end
8989
internalnorm
9090
retcode::SciMLBase.ReturnCode.T
9191
abstol
92+
reltol
9293
prob
9394
stats::NLStats
95+
termination_condition
96+
tc_storage
9497
end
9598

9699
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
97-
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
100+
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
101+
termination_condition = nothing, internalnorm = DEFAULT_NORM,
98102
kwargs...) where {uType, iip}
99103
uₙ = alias_u0 ? prob.u0 : deepcopy(prob.u0)
100104

@@ -123,14 +127,27 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
123127
f₍ₙₒᵣₘ₎₀ = f₍ₙₒᵣₘ₎ₙ₋₁
124128

125129
ℋ = fill(f₍ₙₒᵣₘ₎ₙ₋₁, M)
130+
131+
abstol, reltol, termination_condition = _init_termination_elements(abstol,
132+
reltol,
133+
termination_condition,
134+
T)
135+
136+
mode = DiffEqBase.get_termination_mode(termination_condition)
137+
138+
storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
139+
nothing
140+
126141
return DFSaneCache{iip}(alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
127142
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, p, false, maxiters,
128-
internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0))
143+
internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0),
144+
termination_condition, storage)
129145
end
130146

131147
function perform_step!(cache::DFSaneCache{true})
132-
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
148+
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache
133149

150+
termination_condition = cache.termination_condition(tc_storage)
134151
f = (dx, x) -> cache.prob.f(dx, x, cache.p)
135152

136153
T = eltype(cache.uₙ)
@@ -175,7 +192,7 @@ function perform_step!(cache::DFSaneCache{true})
175192
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
176193
end
177194

178-
if cache.internalnorm(cache.fuₙ) < cache.abstol
195+
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
179196
cache.force_stop = true
180197
end
181198

@@ -206,8 +223,9 @@ function perform_step!(cache::DFSaneCache{true})
206223
end
207224

208225
function perform_step!(cache::DFSaneCache{false})
209-
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
226+
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache
210227

228+
termination_condition = cache.termination_condition(tc_storage)
211229
f = x -> cache.prob.f(x, cache.p)
212230

213231
T = eltype(cache.uₙ)
@@ -250,7 +268,7 @@ function perform_step!(cache::DFSaneCache{false})
250268
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
251269
end
252270

253-
if cache.internalnorm(cache.fuₙ) < cache.abstol
271+
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
254272
cache.force_stop = true
255273
end
256274

src/gaussnewton.jl

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,27 +36,22 @@ for large-scale and numerically-difficult nonlinear least squares problems.
3636
Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
3737
construction. This will be fixed in the near future.
3838
"""
39-
@concrete struct GaussNewton{CJ, AD, TC} <: AbstractNewtonAlgorithm{CJ, AD}
39+
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
4040
ad::AD
4141
linsolve
4242
precs
43-
termination_condition::TC
4443
end
4544

4645
function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
4746
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
4847
end
4948

5049
function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
51-
precs = DEFAULT_PRECS,
52-
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.AbsNorm;
53-
abstol = nothing,
54-
reltol = nothing), adkwargs...)
50+
precs = DEFAULT_PRECS, adkwargs...)
5551
ad = default_adargs_to_adtype(; adkwargs...)
5652
return GaussNewton{_unwrap_val(concrete_jac)}(ad,
5753
linsolve,
58-
precs,
59-
termination_condition)
54+
precs)
6055
end
6156

6257
@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
@@ -84,10 +79,12 @@ end
8479
prob
8580
stats::NLStats
8681
tc_storage
82+
termination_condition
8783
end
8884

8985
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
9086
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
87+
termination_condition = nothing,
9188
internalnorm = DEFAULT_NORM,
9289
kwargs...) where {uType, iip}
9390
alg = get_concrete_algorithm(alg_, prob)
@@ -102,29 +99,30 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
10299
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);
103100
linsolve_with_JᵀJ = Val(true))
104101

105-
tc = alg.termination_condition
106-
mode = DiffEqBase.get_termination_mode(tc)
102+
abstol, reltol, termination_condition = _init_termination_elements(abstol,
103+
reltol,
104+
termination_condition,
105+
eltype(u); mode = NLSolveTerminationMode.AbsNorm)
107106

108-
atol = _get_tolerance(abstol, tc.abstol, eltype(u))
109-
rtol = _get_tolerance(reltol, tc.reltol, eltype(u))
107+
mode = DiffEqBase.get_termination_mode(termination_condition)
110108

111109
storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
112110
nothing
113111

114112
return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
115113
linsolve, J,
116-
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol,
117-
prob, NLStats(1, 0, 0, 0, 0), storage)
114+
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
115+
reltol,
116+
prob, NLStats(1, 0, 0, 0, 0), storage, termination_condition)
118117
end
119118

120119
function perform_step!(cache::GaussNewtonCache{true})
121-
@unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
120+
@unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du, tc_storage = cache
122121
jacobian!!(J, cache)
123122
__matmul!(JᵀJ, J', J)
124123
__matmul!(Jᵀf, J', fu1)
125124

126-
tc_storage = cache.tc_storage
127-
termination_condition = cache.alg.termination_condition(tc_storage)
125+
termination_condition = cache.termination_condition(tc_storage)
128126

129127
# u = u - J \ fu
130128
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
@@ -151,10 +149,9 @@ function perform_step!(cache::GaussNewtonCache{true})
151149
end
152150

153151
function perform_step!(cache::GaussNewtonCache{false})
154-
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache
152+
@unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache
155153

156-
tc_storage = cache.tc_storage
157-
termination_condition = cache.alg.termination_condition(tc_storage)
154+
termination_condition = cache.termination_condition(tc_storage)
158155

159156
cache.J = jacobian!!(cache.J, cache)
160157

src/levenberg.jl

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ numerically-difficult nonlinear systems.
7474
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
7575
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
7676
"""
77-
@concrete struct LevenbergMarquardt{CJ, AD, T, TC <: NLSolveTerminationCondition} <:
77+
@concrete struct LevenbergMarquardt{CJ, AD, T} <:
7878
AbstractNewtonAlgorithm{CJ, AD}
7979
ad::AD
8080
linsolve
@@ -86,7 +86,6 @@ numerically-difficult nonlinear systems.
8686
α_geodesic::T
8787
b_uphill::T
8888
min_damping_D::T
89-
termination_condition::TC
9089
end
9190

9291
function set_ad(alg::LevenbergMarquardt{CJ}, ad) where {CJ}
@@ -99,15 +98,11 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
9998
precs = DEFAULT_PRECS, damping_initial::Real = 1.0, damping_increase_factor::Real = 2.0,
10099
damping_decrease_factor::Real = 3.0, finite_diff_step_geodesic::Real = 0.1,
101100
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
102-
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.AbsNorm;
103-
abstol = nothing,
104-
reltol = nothing),
105101
adkwargs...)
106102
ad = default_adargs_to_adtype(; adkwargs...)
107103
return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs,
108104
damping_initial, damping_increase_factor, damping_decrease_factor,
109-
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D,
110-
termination_condition)
105+
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
111106
end
112107

113108
@concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
@@ -153,12 +148,14 @@ end
153148
Jv
154149
mat_tmp
155150
stats::NLStats
151+
termination_condition
156152
tc_storage
157153
end
158154

159155
function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
160156
NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt,
161157
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
158+
termination_condition = nothing,
162159
internalnorm = DEFAULT_NORM,
163160
linsolve_kwargs = (;), kwargs...) where {uType, iip}
164161
alg = get_concrete_algorithm(alg_, prob)
@@ -168,6 +165,11 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
168165
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p, Val(iip);
169166
linsolve_kwargs, linsolve_with_JᵀJ = Val(true))
170167

168+
abstol, reltol, termination_condition = _init_termination_elements(abstol,
169+
reltol,
170+
termination_condition,
171+
eltype(u); mode = NLSolveTerminationMode.AbsNorm)
172+
171173
λ = convert(eltype(u), alg.damping_initial)
172174
λ_factor = convert(eltype(u), alg.damping_increase_factor)
173175
damping_increase_factor = convert(eltype(u), alg.damping_increase_factor)
@@ -194,28 +196,24 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
194196
fu_tmp = zero(fu1)
195197
mat_tmp = zero(JᵀJ)
196198

197-
tc = alg.termination_condition
198-
mode = DiffEqBase.get_termination_mode(tc)
199-
200-
atol = _get_tolerance(abstol, tc.abstol, eltype(u))
201-
rtol = _get_tolerance(reltol, tc.reltol, eltype(u))
199+
mode = DiffEqBase.get_termination_mode(termination_condition)
202200

203201
storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
204202
nothing
205203

206204
return LevenbergMarquardtCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve,
207205
J,
208-
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob, DᵀD,
206+
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
207+
DᵀD,
209208
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
210209
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
211-
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0), storage)
210+
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0), termination_condition, storage)
212211
end
213212

214213
function perform_step!(cache::LevenbergMarquardtCache{true})
215-
@unpack fu1, f, make_new_J = cache
214+
@unpack fu1, f, make_new_J, tc_storage = cache
216215

217-
tc_storage = cache.tc_storage
218-
termination_condition = cache.alg.termination_condition(tc_storage)
216+
termination_condition = cache.termination_condition(tc_storage)
219217

220218
if iszero(fu1)
221219
cache.force_stop = true
@@ -270,7 +268,11 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
270268
if (1 - β)^b_uphill * loss ≤ loss_old
271269
# Accept step.
272270
cache.u .+= δ
273-
if loss < cache.abstol
271+
if termination_condition(cache.fu_tmp,
272+
cache.u,
273+
u_prev,
274+
cache.abstol,
275+
cache.reltol)
274276
cache.force_stop = true
275277
return nothing
276278
end
@@ -289,10 +291,9 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
289291
end
290292

291293
function perform_step!(cache::LevenbergMarquardtCache{false})
292-
@unpack fu1, f, make_new_J = cache
294+
@unpack fu1, f, make_new_J, tc_storage = cache
293295

294-
tc_storage = cache.tc_storage
295-
termination_condition = cache.alg.termination_condition(tc_storage)
296+
termination_condition = cache.termination_condition(tc_storage)
296297

297298
if iszero(fu1)
298299
cache.force_stop = true

src/raphson.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
107107
end
108108

109109
function perform_step!(cache::NewtonRaphsonCache{true})
110-
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du = cache
110+
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du, tc_storage = cache
111111
jacobian!!(J, cache)
112112

113-
tc_storage = cache.tc_storage
114113
termination_condition = cache.termination_condition(tc_storage)
115114

116115
# u = u - J \ fu

0 commit comments

Comments
 (0)