@@ -74,7 +74,8 @@ numerically-difficult nonlinear systems.
74
74
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
75
75
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
76
76
"""
77
- @concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD}
77
+ @concrete struct LevenbergMarquardt{CJ, AD, T, TC <: NLSolveTerminationCondition } < :
78
+ AbstractNewtonAlgorithm{CJ, AD, TC}
78
79
ad:: AD
79
80
linsolve
80
81
precs
@@ -85,6 +86,7 @@ numerically-difficult nonlinear systems.
85
86
α_geodesic:: T
86
87
b_uphill:: T
87
88
min_damping_D:: T
89
+ termination_condition:: TC
88
90
end
89
91
90
92
function set_ad (alg:: LevenbergMarquardt{CJ} , ad) where {CJ}
@@ -97,17 +99,22 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
97
99
precs = DEFAULT_PRECS, damping_initial:: Real = 1.0 , damping_increase_factor:: Real = 2.0 ,
98
100
damping_decrease_factor:: Real = 3.0 , finite_diff_step_geodesic:: Real = 0.1 ,
99
101
α_geodesic:: Real = 0.75 , b_uphill:: Real = 1.0 , min_damping_D:: AbstractFloat = 1e-8 ,
102
+ termination_condition = NLSolveTerminationCondition (NLSolveTerminationMode. NLSolveDefault;
103
+ abstol = nothing ,
104
+ reltol = nothing ),
100
105
adkwargs... )
101
106
ad = default_adargs_to_adtype (; adkwargs... )
102
107
return LevenbergMarquardt {_unwrap_val(concrete_jac)} (ad, linsolve, precs,
103
108
damping_initial, damping_increase_factor, damping_decrease_factor,
104
- finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
109
+ finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D,
110
+ termination_condition)
105
111
end
106
112
107
113
@concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
108
114
f
109
115
alg
110
116
u
117
+ u_prev
111
118
fu1
112
119
fu2
113
120
du
121
128
internalnorm
122
129
retcode:: ReturnCode.T
123
130
abstol
131
+ reltol
124
132
prob
125
133
DᵀD
126
134
JᵀJ
@@ -145,11 +153,13 @@ end
145
153
Jv
146
154
mat_tmp
147
155
stats:: NLStats
156
+ tc_storage
148
157
end
149
158
150
159
function SciMLBase. __init (prob:: Union {NonlinearProblem{uType, iip},
151
- NonlinearLeastSquaresProblem{uType, iip}}, alg_:: LevenbergMarquardt ,
152
- args... ; alias_u0 = false , maxiters = 1000 , abstol = 1e-6 , internalnorm = DEFAULT_NORM,
160
+ NonlinearLeastSquaresProblem{uType, iip}}, alg_:: LevenbergMarquardt ,
161
+ args... ; alias_u0 = false , maxiters = 1000 , abstol = nothing , reltol = nothing ,
162
+ internalnorm = DEFAULT_NORM,
153
163
linsolve_kwargs = (;), kwargs... ) where {uType, iip}
154
164
alg = get_concrete_algorithm (alg_, prob)
155
165
@unpack f, u0, p = prob
@@ -184,15 +194,30 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
184
194
fu_tmp = zero (fu1)
185
195
mat_tmp = zero (JᵀJ)
186
196
187
- return LevenbergMarquardtCache {iip} (f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
188
- jac_cache, false , maxiters, internalnorm, ReturnCode. Default, abstol, prob, DᵀD,
197
+ tc = alg. termination_condition
198
+ mode = DiffEqBase. get_termination_mode (tc)
199
+
200
+ atol = _get_tolerance (abstol, tc. abstol, eltype (u))
201
+ rtol = _get_tolerance (reltol, tc. reltol, eltype (u))
202
+
203
+ storage = mode ∈ DiffEqBase. SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult () :
204
+ nothing
205
+
206
+ return LevenbergMarquardtCache {iip} (f, alg, u, copy (u), fu1, fu2, du, p, uf, linsolve,
207
+ J,
208
+ jac_cache, false , maxiters, internalnorm, ReturnCode. Default, atol, rtol, prob, DᵀD,
189
209
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
190
210
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
191
- zero (u), zero (fu1), mat_tmp, NLStats (1 , 0 , 0 , 0 , 0 ))
211
+ zero (u), zero (fu1), mat_tmp, NLStats (1 , 0 , 0 , 0 , 0 ), storage)
212
+
192
213
end
193
214
194
215
function perform_step! (cache:: LevenbergMarquardtCache{true} )
195
216
@unpack fu1, f, make_new_J = cache
217
+
218
+ tc_storage = cache. tc_storage
219
+ termination_condition = cache. alg. termination_condition (tc_storage)
220
+
196
221
if iszero (fu1)
197
222
cache. force_stop = true
198
223
return nothing
@@ -205,7 +230,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
205
230
cache. make_new_J = false
206
231
cache. stats. njacs += 1
207
232
end
208
- @unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
233
+ @unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
209
234
210
235
# Usual Levenberg-Marquardt step ("velocity").
211
236
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
@@ -246,7 +271,11 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
246
271
if (1 - β)^ b_uphill * loss ≤ loss_old
247
272
# Accept step.
248
273
cache. u .+ = δ
249
- if loss < cache. abstol
274
+ if termination_condition (cache. fu_tmp,
275
+ cache. u,
276
+ u_prev,
277
+ cache. abstol,
278
+ cache. reltol)
250
279
cache. force_stop = true
251
280
return nothing
252
281
end
@@ -258,13 +287,18 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
258
287
cache. make_new_J = true
259
288
end
260
289
end
290
+ @. u_prev = u
261
291
cache. λ *= cache. λ_factor
262
292
cache. λ_factor = cache. damping_increase_factor
263
293
return nothing
264
294
end
265
295
266
296
function perform_step! (cache:: LevenbergMarquardtCache{false} )
267
297
@unpack fu1, f, make_new_J = cache
298
+
299
+ tc_storage = cache. tc_storage
300
+ termination_condition = cache. alg. termination_condition (tc_storage)
301
+
268
302
if iszero (fu1)
269
303
cache. force_stop = true
270
304
return nothing
@@ -281,7 +315,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
281
315
cache. make_new_J = false
282
316
cache. stats. njacs += 1
283
317
end
284
- @unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache
318
+
319
+ @unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache
285
320
286
321
cache. mat_tmp = JᵀJ + λ * DᵀD
287
322
# Usual Levenberg-Marquardt step ("velocity").
@@ -322,7 +357,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
322
357
if (1 - β)^ b_uphill * loss ≤ loss_old
323
358
# Accept step.
324
359
cache. u += δ
325
- if loss < cache. abstol
360
+ if termination_condition (fu_new, cache . u, u_prev, cache. abstol, cache . reltol)
326
361
cache. force_stop = true
327
362
return nothing
328
363
end
@@ -334,6 +369,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
334
369
cache. make_new_J = true
335
370
end
336
371
end
372
+ cache. u_prev = @. cache. u
337
373
cache. λ *= cache. λ_factor
338
374
cache. λ_factor = cache. damping_increase_factor
339
375
return nothing
0 commit comments