@@ -109,7 +109,7 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
109
109
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
110
110
end
111
111
112
- @concrete mutable struct LevenbergMarquardtCache{iip, fastqr } < :
112
+ @concrete mutable struct LevenbergMarquardtCache{iip, fastls } < :
113
113
AbstractNonlinearSolveCache{iip}
114
114
f
115
115
alg
@@ -164,11 +164,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
164
164
u = alias_u0 ? u0 : deepcopy (u0)
165
165
fu1 = evaluate_f (prob, u)
166
166
167
- if ! needs_square_A (alg. linsolve) && ! (u isa Number) && ! (u isa StaticArray)
168
- linsolve_with_JᵀJ = Val (false )
169
- else
170
- linsolve_with_JᵀJ = Val (true )
171
- end
167
+ linsolve_with_JᵀJ = Val (_needs_square_A (alg, u0))
172
168
173
169
if _unwrap_val (linsolve_with_JᵀJ)
174
170
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches (alg, f, u, p,
@@ -227,7 +223,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
227
223
zero (u), zero (fu1), mat_tmp, rhs_tmp, J², NLStats (1 , 0 , 0 , 0 , 0 ))
228
224
end
229
225
230
- function perform_step! (cache:: LevenbergMarquardtCache{true, fastqr } ) where {fastqr }
226
+ function perform_step! (cache:: LevenbergMarquardtCache{true, fastls } ) where {fastls }
231
227
@unpack fu1, f, make_new_J = cache
232
228
if iszero (fu1)
233
229
cache. force_stop = true
@@ -236,7 +232,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast
236
232
237
233
if make_new_J
238
234
jacobian!! (cache. J, cache)
239
- if fastqr
235
+ if fastls
240
236
cache. J² .= cache. J .^ 2
241
237
sum! (cache. JᵀJ' , cache. J²)
242
238
cache. DᵀD. diag .= max .(cache. DᵀD. diag, cache. JᵀJ)
@@ -251,7 +247,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast
251
247
252
248
# Usual Levenberg-Marquardt step ("velocity").
253
249
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
254
- if fastqr
250
+ if fastls
255
251
cache. mat_tmp[1 : length (fu1), :] .= cache. J
256
252
cache. mat_tmp[(length (fu1) + 1 ): end , :] .= λ .* cache. DᵀD
257
253
cache. rhs_tmp[1 : length (fu1)] .= _vec (fu1)
@@ -276,7 +272,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast
276
272
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
277
273
mul! (_vec (cache. Jv), J, _vec (v))
278
274
@. cache. fu_tmp = (2 / h) * ((cache. fu_tmp - fu1) / h - cache. Jv)
279
- if fastqr
275
+ if fastls
280
276
cache. rhs_tmp[1 : length (fu1)] .= _vec (cache. fu_tmp)
281
277
linres = dolinsolve (alg. precs, linsolve; b = cache. rhs_tmp, linu = _vec (cache. du),
282
278
p = p, reltol = cache. abstol)
@@ -321,7 +317,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast
321
317
return nothing
322
318
end
323
319
324
- function perform_step! (cache:: LevenbergMarquardtCache{false, fastqr } ) where {fastqr }
320
+ function perform_step! (cache:: LevenbergMarquardtCache{false, fastls } ) where {fastls }
325
321
@unpack fu1, f, make_new_J = cache
326
322
if iszero (fu1)
327
323
cache. force_stop = true
@@ -330,7 +326,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fas
330
326
331
327
if make_new_J
332
328
cache. J = jacobian!! (cache. J, cache)
333
- if fastqr
329
+ if fastls
334
330
cache. JᵀJ = _vec (sum (cache. J .^ 2 ; dims = 1 ))
335
331
cache. DᵀD. diag .= max .(cache. DᵀD. diag, cache. JᵀJ)
336
332
else
@@ -347,7 +343,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fas
347
343
@unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache
348
344
349
345
# Usual Levenberg-Marquardt step ("velocity").
350
- if fastqr
346
+ if fastls
351
347
cache. mat_tmp = vcat (J, λ .* cache. DᵀD)
352
348
cache. rhs_tmp[1 : length (fu1)] .= - _vec (fu1)
353
349
linres = dolinsolve (alg. precs, linsolve; A = cache. mat_tmp,
@@ -367,7 +363,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fas
367
363
# Geodesic acceleration (step_size = v + a / 2).
368
364
rhs_term = _vec (((2 / h) .* ((_vec (f (u .+ h .* _restructure (u, v), p)) .-
369
365
_vec (fu1)) ./ h .- J * _vec (v))))
370
- if fastqr
366
+ if fastls
371
367
cache. rhs_tmp[1 : length (fu1)] .= - _vec (rhs_term)
372
368
linres = dolinsolve (alg. precs, linsolve;
373
369
b = cache. rhs_tmp, linu = _vec (cache. a), p = p, reltol = cache. abstol)
0 commit comments