@@ -179,7 +179,8 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
179
179
else
180
180
uf, linsolve, J, fu_cache, jac_cache, du = jacobian_caches (alg, f, u, p,
181
181
Val (iip); linsolve_kwargs, linsolve_with_JᵀJ = Val (false ))
182
- @bb JᵀJ = similar (u)
182
+ u_ = _vec (u)
183
+ @bb JᵀJ = similar (u_)
183
184
@bb v = similar (du)
184
185
end
185
186
@@ -241,91 +242,103 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
241
242
end
242
243
243
244
function perform_step! (cache:: LevenbergMarquardtCache{iip, fastls} ) where {iip, fastls}
245
+ @unpack alg, linsolve = cache
246
+
244
247
if cache. make_new_J
245
248
cache. J = jacobian!! (cache. J, cache)
246
249
if fastls
247
250
cache. JᵀJ = __sum_JᵀJ!! (cache. JᵀJ, cache. J)
248
- # cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ)
249
251
else
250
252
@bb cache. JᵀJ = transpose (cache. J) × cache. J
251
- # cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
252
253
end
254
+ cache. DᵀD = __update_LM_diagonal!! (cache. DᵀD, cache. JᵀJ)
253
255
cache. make_new_J = false
254
256
end
255
257
256
- # @unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
257
-
258
258
# Usual Levenberg-Marquardt step ("velocity").
259
259
# The following lines do: cache.v = -cache.mat_tmp \ rhs, with rhs = cache.rhs_tmp (fastls) or cache.u_cache_2
260
- # if fastls
261
- # copyto!(@view(cache.mat_tmp[1:length(fu1), :]), cache.J)
262
- # cache.mat_tmp[(length(fu1) + 1):end, :] .= λ .* cache.DᵀD
263
- # cache.rhs_tmp[1:length(fu1)] .= _vec(fu1)
264
- # linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
265
- # b = cache.rhs_tmp, linu = _vec(cache.du), p = p, reltol = cache.abstol)
266
- # _vec(cache.v) .= -_vec(cache.du)
267
- # else
268
- # mul!(_vec(cache.u_tmp), J', _vec(fu1))
269
- # @. cache.mat_tmp = JᵀJ + λ * DᵀD
270
- # linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
271
- # b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
272
- # cache.linsolve = linres.cache
273
- # _vec(cache.v) .= -_vec(cache.du)
274
- # end
275
-
276
- # update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J,
277
- # cache.v)
278
-
279
- # # Geodesic acceleration (step_size = v + a / 2).
280
- # @unpack v, α_geodesic, h = cache
281
- # cache.u_tmp .= _restructure(cache.u_tmp, _vec(u) .+ h .* _vec(v))
282
- # f(cache.fu_tmp, cache.u_tmp, p)
283
-
284
- # # The following lines do: cache.a = -J \ cache.fu_tmp
285
- # # NOTE: Don't pass `A` in again, since we want to reuse the previous solve
286
- # mul!(_vec(cache.Jv), J, _vec(v))
287
- # @. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
288
- # if fastls
289
- # cache.rhs_tmp[1:length(fu1)] .= _vec(cache.fu_tmp)
290
- # linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.du),
291
- # p = p, reltol = cache.abstol)
292
- # else
293
- # mul!(_vec(cache.u_tmp), J', _vec(cache.fu_tmp))
294
- # linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_tmp),
295
- # linu = _vec(cache.du), p = p, reltol = cache.abstol)
296
- # cache.linsolve = linres.cache
297
- # @. cache.a = -cache.du
298
- # end
260
+ if fastls
261
+ if setindex_trait (cache. mat_tmp) === CanSetindex ()
262
+ copyto! (@view (cache. mat_tmp[1 : length (cache. fu), :]), cache. J)
263
+ cache. mat_tmp[(length (cache. fu) + 1 ): end , :] .= cache. λ .* cache. DᵀD
264
+ else
265
+ cache. mat_tmp = _vcat (cache. J, cache. λ .* cache. DᵀD)
266
+ end
267
+ if setindex_trait (cache. rhs_tmp) === CanSetindex ()
268
+ cache. rhs_tmp[1 : length (cache. fu)] .= _vec (cache. fu)
269
+ else
270
+ cache. rhs_tmp = _vcat (_vec (cache. fu), zero (_vec (cache. u)))
271
+ end
272
+ linres = dolinsolve (alg. precs, linsolve; A = cache. mat_tmp,
273
+ b = cache. rhs_tmp, linu = _vec (cache. v), cache. p, reltol = cache. abstol)
274
+ @bb @. cache. v = - linres. u
275
+ else
276
+ @bb cache. u_cache_2 = transpose (J) × cache. fu
277
+ @bb @. cache. mat_tmp = cache. JᵀJ + cache. λ * cache. DᵀD
278
+ linres = dolinsolve (alg. precs, linsolve; A = cache. mat_tmp,
279
+ b = _vec (cache. u_cache_2), linu = _vec (cache. v), cache. p, reltol = cache. abstol)
280
+ cache. linsolve = linres. cache
281
+ @bb @. cache. v = - linres. u
282
+ end
283
+
284
+ update_trace! (cache. trace, cache. stats. nsteps + 1 , get_u (cache), get_fu (cache), cache. J,
285
+ cache. v)
286
+
287
+ # Geodesic acceleration (step_size = v + a / 2).
288
+ @bb @. cache. u_cache_2 = cache. u + cache. h * cache. v
289
+ evaluate_f (cache, cache. u_cache_2, cache. p, Val (:fu_cache_2 ))
290
+
291
+ # The following lines do: cache.a = -J \ cache.fu_cache_2
292
+ # NOTE: Don't pass `A` in again, since we want to reuse the previous solve
293
+ @bb cache. Jv = cache. J × cache. v
294
+ @bb @. cache. fu_cache_2 = (2 / cache. h) *
295
+ ((cache. fu_cache_2 - cache. fu) / cache. h - cache. Jv)
296
+ if fastls
297
+ if setindex_trait (cache. rhs_tmp) === CanSetindex ()
298
+ cache. rhs_tmp[1 : length (cache. fu)] .= _vec (cache. fu_cache_2)
299
+ else
300
+ cache. rhs_tmp = _vcat (_vec (cache. fu_cache_2), zero (_vec (cache. u)))
301
+ end
302
+ linres = dolinsolve (alg. precs, linsolve; b = cache. rhs_tmp, linu = _vec (cache. a),
303
+ cache. p, reltol = cache. abstol)
304
+ @bb @. cache. a = - linres. u
305
+ else
306
+ @bb cache. u_cache_2 = transpose (J) × cache. fu_cache_2
307
+ linres = dolinsolve (alg. precs, linsolve; b = _vec (cache. u_cache_2),
308
+ linu = _vec (cache. a), cache. p, reltol = cache. abstol)
309
+ cache. linsolve = linres. cache
310
+ @bb @. cache. a = - linres. du
311
+ end
299
312
300
313
cache. stats. nsolve += 2
301
314
cache. stats. nfactors += 2
302
315
303
316
# Require acceptable steps to satisfy the following condition.
304
- norm_v = cache. internalnorm (v)
305
- if 2 * cache. internalnorm (cache. a) ≤ α_geodesic * norm_v
306
- # _vec( cache.δ) .= _vec(v) .+ _vec( cache.a) . / 2
307
- # @unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
308
- # f (cache.fu_tmp, u .+ δ, p )
309
- # loss = cache.internalnorm(cache.fu_tmp )
310
-
311
- # # Condition to accept uphill steps (evaluates to `loss ≤ loss_old` in iteration 1).
312
- # β = dot(v, v_old ) / (norm_v * norm_v_old)
313
- # if (1 - β)^b_uphill * loss ≤ loss_old
314
- # # Accept step.
315
- # cache.u .+= δ
316
- # check_and_update!(cache.tc_cache_1, cache, cache.fu_tmp , cache.u, cache.u_prev)
317
- # if ! cache.force_stop && cache.tc_cache_2 !== nothing
318
- # # For NLLS Problems
319
- # cache.fu1 . = cache.fu_tmp . - cache.fu1
320
- # check_and_update!(cache.tc_cache_2, cache, cache.fu1 , cache.u, cache.u_prev )
321
- # end
322
- # cache.fu1 .= cache.fu_tmp
323
- # _vec (cache.v_old) .= _vec( v)
324
- # cache.norm_v_old = norm_v
325
- # cache.loss_old = loss
326
- # cache.λ_factor = 1 / cache.damping_decrease_factor
327
- # cache.make_new_J = true
328
- # end
317
+ norm_v = cache. internalnorm (cache . v)
318
+ if 2 * cache. internalnorm (cache. a) ≤ cache . α_geodesic * norm_v
319
+ @bb @. cache. du_cache = cache . v + cache. a / 2
320
+ @bb @. cache . u_cache_2 = cache . u + cache. du_cache
321
+ evaluate_f (cache, cache . u_cache_2, cache . p, Val ( :fu_cache_2 ) )
322
+ loss = cache. internalnorm (cache. fu_cache_2 )
323
+
324
+ # Condition to accept uphill steps (evaluates to `loss ≤ loss_old` in iteration 1).
325
+ β = dot (cache . v, cache . v_cache ) / (norm_v * cache . norm_v_old)
326
+ if (1 - β)^ cache . b_uphill * loss ≤ cache . loss_old
327
+ # Accept step.
328
+ @bb copyto! ( cache. u, cache . u_cache_2)
329
+ check_and_update! (cache. tc_cache_1, cache, cache. fu_cache , cache. u,
330
+ cache. u_cache)
331
+ if ! cache . force_stop && cache . tc_cache_2 != = nothing # For NLLS Problems
332
+ @bb @. cache. fu = cache. fu_cache_2 - cache. fu
333
+ check_and_update! (cache. tc_cache_2, cache, cache. fu , cache. u, cache. u_cache )
334
+ end
335
+ @bb copyto! ( cache. fu_cache, cache. fu_cache_2)
336
+ @bb copyto! (cache. v_cache, cache . v)
337
+ cache. norm_v_old = norm_v
338
+ cache. loss_old = loss
339
+ cache. λ_factor = 1 / cache. damping_decrease_factor
340
+ cache. make_new_J = true
341
+ end
329
342
end
330
343
331
344
@bb copyto! (cache. u_cache, cache. u)
0 commit comments