120
120
fu
121
121
fu_cache
122
122
fu_cache_2
123
- du
124
- du_cache
125
123
J
126
124
JᵀJ
127
125
Jv
@@ -197,9 +195,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
197
195
198
196
loss = internalnorm (fu)
199
197
200
- @bb a = similar (du)
201
- @bb v_old = copy (v)
202
- @bb δ = similar (du)
198
+ a = du # `du` is not used anywhere, use it to store `a`
203
199
204
200
make_new_J = true
205
201
@@ -215,8 +211,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
215
211
trace = init_nonlinearsolve_trace (alg, u, fu, ApplyArray (__zero, J), du; kwargs... )
216
212
217
213
if ! fastls
218
- @bb mat_tmp = similar (JᵀJ)
219
- @bb mat_tmp .*= T (0 )
214
+ @bb mat_tmp = zero (JᵀJ)
220
215
rhs_tmp = nothing
221
216
else
222
217
mat_tmp = _vcat (J, DᵀD)
@@ -229,15 +224,14 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
229
224
@bb u_cache = copy (u)
230
225
@bb u_cache_2 = similar (u)
231
226
@bb fu_cache_2 = similar (fu)
232
- @bb du_cache = similar (du)
233
227
Jv = J * v
234
- @bb v_cache = similar (v)
228
+ @bb v_cache = zero (v)
235
229
236
230
return LevenbergMarquardtCache {iip, fastls} (f, alg, u, u_cache, u_cache_2, fu, fu_cache,
237
- fu_cache_2, du, du_cache, J, JᵀJ, Jv, DᵀD, v, v_cache, a, mat_tmp, rhs_tmp, p, uf,
231
+ fu_cache_2, J, JᵀJ, Jv, DᵀD, v, v_cache, a, mat_tmp, rhs_tmp, p, uf,
238
232
linsolve, jac_cache, false , maxiters, internalnorm, ReturnCode. Default, abstol,
239
233
reltol, prob, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h,
240
- α_geodesic, b_uphill, min_damping_D, internalnorm (v_cache) , loss, make_new_J,
234
+ α_geodesic, b_uphill, min_damping_D, loss , loss, make_new_J,
241
235
NLStats (1 , 0 , 0 , 0 , 0 ), tc_cache_1, tc_cache_2, trace)
242
236
end
243
237
@@ -271,11 +265,12 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
271
265
end
272
266
linres = dolinsolve (alg. precs, linsolve; A = cache. mat_tmp,
273
267
b = cache. rhs_tmp, linu = _vec (cache. v), cache. p, reltol = cache. abstol)
268
+ cache. linsolve = linres. cache
274
269
@bb @. cache. v = - linres. u
275
270
else
276
271
@bb cache. u_cache_2 = transpose (cache. J) × cache. fu
277
272
@bb @. cache. mat_tmp = cache. JᵀJ + cache. λ * cache. DᵀD
278
- linres = dolinsolve (alg. precs, linsolve; A = cache. mat_tmp,
273
+ linres = dolinsolve (alg. precs, linsolve; A = __maybe_symmetric ( cache. mat_tmp) ,
279
274
b = _vec (cache. u_cache_2), linu = _vec (cache. v), cache. p, reltol = cache. abstol)
280
275
cache. linsolve = linres. cache
281
276
@bb @. cache. v = - linres. u
@@ -289,7 +284,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
289
284
evaluate_f (cache, cache. u_cache_2, cache. p, Val (:fu_cache_2 ))
290
285
291
286
# The following lines do: cache.a = -cache.mat_tmp \ cache.fu_tmp
292
- # NOTE: Don't pass `A` in again, since we want to reuse the previous solve
287
+ # NOTE: Don't pass `A` in again, since we want to reuse the previous solve
293
288
@bb cache. Jv = cache. J × cache. v
294
289
@bb @. cache. fu_cache_2 = (2 / cache. h) *
295
290
((cache. fu_cache_2 - cache. fu) / cache. h - cache. Jv)
@@ -301,13 +296,14 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
301
296
end
302
297
linres = dolinsolve (alg. precs, linsolve; b = cache. rhs_tmp, linu = _vec (cache. a),
303
298
cache. p, reltol = cache. abstol)
299
+ cache. linsolve = linres. cache
304
300
@bb @. cache. a = - linres. u
305
301
else
306
- @bb cache. u_cache_2 = transpose (J) × cache. fu_cache_2
302
+ @bb cache. u_cache_2 = transpose (cache . J) × cache. fu_cache_2
307
303
linres = dolinsolve (alg. precs, linsolve; b = _vec (cache. u_cache_2),
308
304
linu = _vec (cache. a), cache. p, reltol = cache. abstol)
309
305
cache. linsolve = linres. cache
310
- @bb @. cache. a = - linres. du
306
+ @bb @. cache. a = - linres. u
311
307
end
312
308
313
309
cache. stats. nsolve += 2
@@ -316,8 +312,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
316
312
# Require acceptable steps to satisfy the following condition.
317
313
norm_v = cache. internalnorm (cache. v)
318
314
if 2 * cache. internalnorm (cache. a) ≤ cache. α_geodesic * norm_v
319
- @bb @. cache. du_cache = cache. v + cache. a / 2
320
- @bb @. cache. u_cache_2 = cache. u + cache. du_cache
315
+ @bb @. cache. u_cache_2 = cache. u + cache. v + cache. a / 2
321
316
evaluate_f (cache, cache. u_cache_2, cache. p, Val (:fu_cache_2 ))
322
317
loss = cache. internalnorm (cache. fu_cache_2)
323
318
@@ -326,7 +321,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
326
321
if (1 - β)^ cache. b_uphill * loss ≤ cache. loss_old
327
322
# Accept step.
328
323
@bb copyto! (cache. u, cache. u_cache_2)
329
- check_and_update! (cache. tc_cache_1, cache, cache. fu_cache , cache. u,
324
+ check_and_update! (cache. tc_cache_1, cache, cache. fu_cache_2 , cache. u,
330
325
cache. u_cache)
331
326
if ! cache. force_stop && cache. tc_cache_2 != = nothing # For NLLS Problems
332
327
@bb @. cache. fu = cache. fu_cache_2 - cache. fu
0 commit comments