Skip to content

Commit efb8c5e

Browse files
committed
Update all algorithms to use termination condition
1 parent 37b7015 commit efb8c5e

File tree

5 files changed

+170
-40
lines changed

5 files changed

+170
-40
lines changed

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
2626
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
2727

2828
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
29-
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end
29+
abstract type AbstractNewtonAlgorithm{CJ, AD, TC} <: AbstractNonlinearSolveAlgorithm end
3030

3131
abstract type AbstractNonlinearSolveCache{iip} end
3232

src/levenberg.jl

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ numerically-difficult nonlinear systems.
7474
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
7575
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
7676
"""
77-
@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD}
77+
@concrete struct LevenbergMarquardt{CJ, AD, T, TC <: NLSolveTerminationCondition} <:
78+
AbstractNewtonAlgorithm{CJ, AD, TC}
7879
ad::AD
7980
linsolve
8081
precs
@@ -85,6 +86,7 @@ numerically-difficult nonlinear systems.
8586
α_geodesic::T
8687
b_uphill::T
8788
min_damping_D::T
89+
termination_condition::TC
8890
end
8991

9092
function set_ad(alg::LevenbergMarquardt{CJ}, ad) where {CJ}
@@ -97,17 +99,22 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
9799
precs = DEFAULT_PRECS, damping_initial::Real = 1.0, damping_increase_factor::Real = 2.0,
98100
damping_decrease_factor::Real = 3.0, finite_diff_step_geodesic::Real = 0.1,
99101
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
102+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
103+
abstol = nothing,
104+
reltol = nothing),
100105
adkwargs...)
101106
ad = default_adargs_to_adtype(; adkwargs...)
102107
return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs,
103108
damping_initial, damping_increase_factor, damping_decrease_factor,
104-
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
109+
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D,
110+
termination_condition)
105111
end
106112

107113
@concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
108114
f
109115
alg
110116
u
117+
u_prev
111118
fu1
112119
fu2
113120
du
@@ -121,6 +128,7 @@ end
121128
internalnorm
122129
retcode::ReturnCode.T
123130
abstol
131+
reltol
124132
prob
125133
DᵀD
126134
JᵀJ
@@ -145,11 +153,13 @@ end
145153
Jv
146154
mat_tmp
147155
stats::NLStats
156+
tc_storage
148157
end
149158

150159
function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
151-
NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt,
152-
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
160+
NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt,
161+
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
162+
internalnorm = DEFAULT_NORM,
153163
linsolve_kwargs = (;), kwargs...) where {uType, iip}
154164
alg = get_concrete_algorithm(alg_, prob)
155165
@unpack f, u0, p = prob
@@ -184,15 +194,30 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
184194
fu_tmp = zero(fu1)
185195
mat_tmp = zero(JᵀJ)
186196

187-
return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
188-
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD,
197+
tc = alg.termination_condition
198+
mode = DiffEqBase.get_termination_mode(tc)
199+
200+
atol = _get_tolerance(abstol, tc.abstol, eltype(u))
201+
rtol = _get_tolerance(reltol, tc.reltol, eltype(u))
202+
203+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
204+
nothing
205+
206+
return LevenbergMarquardtCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve,
207+
J,
208+
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob, DᵀD,
189209
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
190210
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
191-
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0))
211+
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0), storage)
212+
192213
end
193214

194215
function perform_step!(cache::LevenbergMarquardtCache{true})
195216
@unpack fu1, f, make_new_J = cache
217+
218+
tc_storage = cache.tc_storage
219+
termination_condition = cache.alg.termination_condition(tc_storage)
220+
196221
if iszero(fu1)
197222
cache.force_stop = true
198223
return nothing
@@ -205,7 +230,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
205230
cache.make_new_J = false
206231
cache.stats.njacs += 1
207232
end
208-
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
233+
@unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
209234

210235
# Usual Levenberg-Marquardt step ("velocity").
211236
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
@@ -246,7 +271,11 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
246271
if (1 - β)^b_uphill * loss loss_old
247272
# Accept step.
248273
cache.u .+= δ
249-
if loss < cache.abstol
274+
if termination_condition(cache.fu_tmp,
275+
cache.u,
276+
u_prev,
277+
cache.abstol,
278+
cache.reltol)
250279
cache.force_stop = true
251280
return nothing
252281
end
@@ -258,13 +287,18 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
258287
cache.make_new_J = true
259288
end
260289
end
290+
@. u_prev = u
261291
cache.λ *= cache.λ_factor
262292
cache.λ_factor = cache.damping_increase_factor
263293
return nothing
264294
end
265295

266296
function perform_step!(cache::LevenbergMarquardtCache{false})
267297
@unpack fu1, f, make_new_J = cache
298+
299+
tc_storage = cache.tc_storage
300+
termination_condition = cache.alg.termination_condition(tc_storage)
301+
268302
if iszero(fu1)
269303
cache.force_stop = true
270304
return nothing
@@ -281,7 +315,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
281315
cache.make_new_J = false
282316
cache.stats.njacs += 1
283317
end
284-
@unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache
318+
319+
@unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache
285320

286321
cache.mat_tmp = JᵀJ + λ * DᵀD
287322
# Usual Levenberg-Marquardt step ("velocity").
@@ -322,7 +357,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
322357
if (1 - β)^b_uphill * loss loss_old
323358
# Accept step.
324359
cache.u += δ
325-
if loss < cache.abstol
360+
if termination_condition(fu_new, cache.u, u_prev, cache.abstol, cache.reltol)
326361
cache.force_stop = true
327362
return nothing
328363
end
@@ -334,6 +369,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
334369
cache.make_new_J = true
335370
end
336371
end
372+
cache.u_prev = @. cache.u
337373
cache.λ *= cache.λ_factor
338374
cache.λ_factor = cache.damping_increase_factor
339375
return nothing

src/raphson.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ for large-scale and numerically-difficult nonlinear systems.
3030
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
3131
used here directly, and they will be converted to the correct `LineSearch`.
3232
"""
33-
@concrete struct NewtonRaphson{CJ, AD, TC <: NLSolveTerminationCondition} <: AbstractNewtonAlgorithm{CJ, AD}
33+
@concrete struct NewtonRaphson{CJ, AD, TC <: NLSolveTerminationCondition} <:
34+
AbstractNewtonAlgorithm{CJ, AD, TC}
3435
ad::AD
3536
linsolve
3637
precs
@@ -43,19 +44,24 @@ function set_ad(alg::NewtonRaphson{CJ}, ad) where {CJ}
4344
end
4445

4546
function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
46-
linesearch = LineSearch(), precs = DEFAULT_PRECS, termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
47-
abstol = nothing,
48-
reltol = nothing), adkwargs...)
47+
linesearch = LineSearch(), precs = DEFAULT_PRECS,
48+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
49+
abstol = nothing,
50+
reltol = nothing), adkwargs...)
4951
ad = default_adargs_to_adtype(; adkwargs...)
5052
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
51-
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch, termination_condition)
53+
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad,
54+
linsolve,
55+
precs,
56+
linesearch,
57+
termination_condition)
5258
end
5359

5460
@concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip}
5561
f
5662
alg
5763
u
58-
uprev
64+
u_prev
5965
fu1
6066
fu2
6167
du
@@ -77,7 +83,8 @@ end
7783
end
7884

7985
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphson, args...;
80-
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
86+
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
87+
internalnorm = DEFAULT_NORM,
8188
linsolve_kwargs = (;), kwargs...) where {uType, iip}
8289
alg = get_concrete_algorithm(alg_, prob)
8390
@unpack f, u0, p = prob
@@ -86,7 +93,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
8693
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
8794
linsolve_kwargs)
8895

89-
9096
tc = alg.termination_condition
9197
mode = DiffEqBase.get_termination_mode(tc)
9298

@@ -98,11 +104,12 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
98104

99105
return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J,
100106
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob,
101-
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)), storage)
107+
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)),
108+
storage)
102109
end
103110

104111
function perform_step!(cache::NewtonRaphsonCache{true})
105-
@unpack u, uprev, fu1, f, p, alg, J, linsolve, du = cache
112+
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du = cache
106113
jacobian!!(J, cache)
107114

108115
tc_storage = cache.tc_storage
@@ -118,9 +125,10 @@ function perform_step!(cache::NewtonRaphsonCache{true})
118125
@. u = u - α * du
119126
f(cache.fu1, u, p)
120127

121-
termination_condition(cache.fu1, u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true)
128+
termination_condition(cache.fu1, u, u_prev, cache.abstol, cache.reltol) &&
129+
(cache.force_stop = true)
122130

123-
@. uprev = u
131+
@. u_prev = u
124132
cache.stats.nf += 1
125133
cache.stats.njacs += 1
126134
cache.stats.nsolve += 1
@@ -129,12 +137,11 @@ function perform_step!(cache::NewtonRaphsonCache{true})
129137
end
130138

131139
function perform_step!(cache::NewtonRaphsonCache{false})
132-
@unpack u, uprev, fu1, f, p, alg, linsolve = cache
140+
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache
133141

134142
tc_storage = cache.tc_storage
135143
termination_condition = cache.alg.termination_condition(tc_storage)
136144

137-
138145
cache.J = jacobian!!(cache.J, cache)
139146
# u = u - J \ fu
140147
if linsolve === nothing
@@ -150,9 +157,10 @@ function perform_step!(cache::NewtonRaphsonCache{false})
150157
cache.u = @. u - α * cache.du # `u` might not support mutation
151158
cache.fu1 = f(cache.u, p)
152159

153-
termination_condition(cache.fu1, cache.u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true)
160+
termination_condition(cache.fu1, cache.u, u_prev, cache.abstol, cache.reltol) &&
161+
(cache.force_stop = true)
154162

155-
cache.uprev = cache.u
163+
cache.u_prev = @. cache.u
156164
cache.stats.nf += 1
157165
cache.stats.njacs += 1
158166
cache.stats.nsolve += 1

0 commit comments

Comments
 (0)