@@ -72,11 +72,6 @@ numerically-difficult nonlinear systems.
72
72
where `J` is the Jacobian. It is suggested by
73
73
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
74
74
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
75
-
76
- !!! warning
77
-
78
- `linsolve` and `precs` are used exclusively for the inplace version of the algorithm.
79
- Support for the OOP version is planned!
80
75
"""
81
76
@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD}
82
77
ad:: AD
@@ -102,18 +97,17 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
102
97
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
103
98
end
104
99
105
- @concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType} < :
106
- AbstractNonlinearSolveCache{iip}
100
+ @concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
107
101
f
108
102
alg
109
- u:: uType
103
+ u
110
104
fu1
111
105
fu2
112
106
du
113
107
p
114
108
uf
115
109
linsolve
116
- J:: jType
110
+ J
117
111
jac_cache
118
112
force_stop:: Bool
119
113
maxiters:: Int
@@ -122,27 +116,27 @@ end
122
116
abstol
123
117
prob
124
118
DᵀD
125
- JᵀJ:: jType
126
- λ:: λType
127
- λ_factor: :λType
128
- damping_increase_factor: :λType
129
- damping_decrease_factor: :λType
130
- h: :λType
131
- α_geodesic: :λType
132
- b_uphill: :λType
133
- min_damping_D: :λType
134
- v:: uType
135
- a:: uType
136
- tmp_vec:: uType
137
- v_old:: uType
138
- norm_v_old:: lossType
139
- δ:: uType
140
- loss_old:: lossType
119
+ JᵀJ
120
+ λ
121
+ λ_factor
122
+ damping_increase_factor
123
+ damping_decrease_factor
124
+ h
125
+ α_geodesic
126
+ b_uphill
127
+ min_damping_D
128
+ v
129
+ a
130
+ tmp_vec
131
+ v_old
132
+ norm_v_old
133
+ δ
134
+ loss_old
141
135
make_new_J:: Bool
142
136
fu_tmp
143
137
u_tmp
144
138
Jv
145
- mat_tmp:: jType
139
+ mat_tmp
146
140
stats:: NLStats
147
141
end
148
142
@@ -153,8 +147,8 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
153
147
@unpack f, u0, p = prob
154
148
u = alias_u0 ? u0 : deepcopy (u0)
155
149
fu1 = evaluate_f (prob, u)
156
- uf, linsolve, J, fu2, jac_cache, du = jacobian_caches (alg, f, u, p, Val (iip);
157
- linsolve_kwargs)
150
+ uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches (alg, f, u, p, Val (iip);
151
+ linsolve_kwargs, linsolve_with_JᵀJ = Val ( true ) )
158
152
159
153
λ = convert (eltype (u), alg. damping_initial)
160
154
λ_factor = convert (eltype (u), alg. damping_increase_factor)
@@ -174,12 +168,10 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
174
168
end
175
169
176
170
loss = internalnorm (fu1)
177
- JᵀJ = J isa Number ? zero (J) : similar (J, size (J, 2 ), size (J, 2 ))
178
- v = zero (u)
179
- a = zero (u)
180
- tmp_vec = zero (u)
181
- v_old = zero (u)
182
- δ = zero (u)
171
+ a = _mutable_zero (u)
172
+ tmp_vec = _mutable_zero (u)
173
+ v_old = _mutable_zero (u)
174
+ δ = _mutable_zero (u)
183
175
make_new_J = true
184
176
fu_tmp = zero (fu1)
185
177
mat_tmp = zero (JᵀJ)
@@ -223,7 +215,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
223
215
# The following lines do: cache.a = -cache.mat_tmp \ (J' * cache.fu_tmp)
224
216
mul! (cache. Jv, J, v)
225
217
@. cache. fu_tmp = (2 / h) * ((cache. fu_tmp - fu1) / h - cache. Jv)
226
- linres = dolinsolve (alg. precs, linsolve; A = cache. mat_tmp, b = _vec (cache. fu_tmp),
218
+ mul! (cache. u_tmp, J' , cache. fu_tmp)
219
+ linres = dolinsolve (alg. precs, linsolve; A = cache. mat_tmp, b = _vec (cache. u_tmp),
227
220
linu = _vec (cache. du), p = p, reltol = cache. abstol)
228
221
cache. linsolve = linres. cache
229
222
@. cache. a = - cache. du
@@ -279,15 +272,30 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
279
272
cache. make_new_J = false
280
273
cache. stats. njacs += 1
281
274
end
282
- @unpack u, p, λ, JᵀJ, DᵀD, J = cache
275
+ @unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache
283
276
284
277
cache. mat_tmp = JᵀJ + λ * DᵀD
285
278
# Usual Levenberg-Marquardt step ("velocity").
286
- cache. v = - cache. mat_tmp \ (J' * fu1)
279
+ if linsolve === nothing
280
+ cache. v = - cache. mat_tmp \ (J' * fu1)
281
+ else
282
+ linres = dolinsolve (alg. precs, linsolve; A = - cache. mat_tmp, b = _vec (J' * fu1),
283
+ linu = _vec (cache. v), p, reltol = cache. abstol)
284
+ cache. linsolve = linres. cache
285
+ end
287
286
288
287
@unpack v, h, α_geodesic = cache
289
288
# Geodesic acceleration (step_size = v + a / 2).
290
- cache. a = - cache. mat_tmp \ ((2 / h) .* ((f (u .+ h .* v, p) .- fu1) ./ h .- J * v))
289
+ if linsolve === nothing
290
+ cache. a = - cache. mat_tmp \
291
+ _vec (J' * ((2 / h) .* ((f (u .+ h .* v, p) .- fu1) ./ h .- J * v)))
292
+ else
293
+ linres = dolinsolve (alg. precs, linsolve; A = - cache. mat_tmp,
294
+ b = _mutable (_vec (J' *
295
+ ((2 / h) .* ((f (u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
296
+ linu = _vec (cache. a), p, reltol = cache. abstol)
297
+ cache. linsolve = linres. cache
298
+ end
291
299
cache. stats. nsolve += 1
292
300
cache. stats. nfactors += 1
293
301
0 commit comments