Skip to content

Commit 2700fce

Browse files
committed
Use needs_square_A from LinearSolve
1 parent 06186c0 commit 2700fce

File tree

6 files changed

+21
-52
lines changed

6 files changed

+21
-52
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ FiniteDiff = "2"
4444
ForwardDiff = "0.10.3"
4545
LeastSquaresOptim = "0.8"
4646
LineSearches = "7"
47-
LinearSolve = "2"
47+
LinearSolve = "2.12"
4848
NonlinearProblemLibrary = "0.1"
4949
PrecompileTools = "1"
5050
RecursiveArrayTools = "2"

src/gaussnewton.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
8282
alg = get_concrete_algorithm(alg_, prob)
8383
@unpack f, u0, p = prob
8484

85-
if !needs_square_A(alg.linsolve) && !(u0 isa Number) && !(u0 isa StaticArray)
86-
linsolve_with_JᵀJ = Val(false)
87-
else
88-
linsolve_with_JᵀJ = Val(true)
89-
end
85+
linsolve_with_JᵀJ = Val(_needs_square_A(alg, u0))
9086

9187
u = alias_u0 ? u0 : deepcopy(u0)
9288
if iip

src/levenberg.jl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
109109
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
110110
end
111111

112-
@concrete mutable struct LevenbergMarquardtCache{iip, fastqr} <:
112+
@concrete mutable struct LevenbergMarquardtCache{iip, fastls} <:
113113
AbstractNonlinearSolveCache{iip}
114114
f
115115
alg
@@ -164,11 +164,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
164164
u = alias_u0 ? u0 : deepcopy(u0)
165165
fu1 = evaluate_f(prob, u)
166166

167-
if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray)
168-
linsolve_with_JᵀJ = Val(false)
169-
else
170-
linsolve_with_JᵀJ = Val(true)
171-
end
167+
linsolve_with_JᵀJ = Val(_needs_square_A(alg, u0))
172168

173169
if _unwrap_val(linsolve_with_JᵀJ)
174170
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p,
@@ -227,7 +223,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
227223
zero(u), zero(fu1), mat_tmp, rhs_tmp, J², NLStats(1, 0, 0, 0, 0))
228224
end
229225

230-
function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fastqr}
226+
function perform_step!(cache::LevenbergMarquardtCache{true, fastls}) where {fastls}
231227
@unpack fu1, f, make_new_J = cache
232228
if iszero(fu1)
233229
cache.force_stop = true
@@ -236,7 +232,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast
236232

237233
if make_new_J
238234
jacobian!!(cache.J, cache)
239-
if fastqr
235+
if fastls
240236
cache.J² .= cache.J .^ 2
241237
sum!(cache.JᵀJ', cache.J²)
242238
cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ)
@@ -251,7 +247,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast
251247

252248
# Usual Levenberg-Marquardt step ("velocity").
253249
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
254-
if fastqr
250+
if fastls
255251
cache.mat_tmp[1:length(fu1), :] .= cache.J
256252
cache.mat_tmp[(length(fu1) + 1):end, :] .= λ .* cache.DᵀD
257253
cache.rhs_tmp[1:length(fu1)] .= _vec(fu1)
@@ -276,7 +272,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast
276272
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
277273
mul!(_vec(cache.Jv), J, _vec(v))
278274
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
279-
if fastqr
275+
if fastls
280276
cache.rhs_tmp[1:length(fu1)] .= _vec(cache.fu_tmp)
281277
linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.du),
282278
p = p, reltol = cache.abstol)
@@ -321,7 +317,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast
321317
return nothing
322318
end
323319

324-
function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fastqr}
320+
function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fastls}
325321
@unpack fu1, f, make_new_J = cache
326322
if iszero(fu1)
327323
cache.force_stop = true
@@ -330,7 +326,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fas
330326

331327
if make_new_J
332328
cache.J = jacobian!!(cache.J, cache)
333-
if fastqr
329+
if fastls
334330
cache.JᵀJ = _vec(sum(cache.J .^ 2; dims = 1))
335331
cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ)
336332
else
@@ -347,7 +343,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fas
347343
@unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache
348344

349345
# Usual Levenberg-Marquardt step ("velocity").
350-
if fastqr
346+
if fastls
351347
cache.mat_tmp = vcat(J, λ .* cache.DᵀD)
352348
cache.rhs_tmp[1:length(fu1)] .= -_vec(fu1)
353349
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
@@ -367,7 +363,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fas
367363
# Geodesic acceleration (step_size = v + a / 2).
368364
rhs_term = _vec(((2 / h) .* ((_vec(f(u .+ h .* _restructure(u, v), p)) .-
369365
_vec(fu1)) ./ h .- J * _vec(v))))
370-
if fastqr
366+
if fastls
371367
cache.rhs_tmp[1:length(fu1)] .= -_vec(rhs_term)
372368
linres = dolinsolve(alg.precs, linsolve;
373369
b = cache.rhs_tmp, linu = _vec(cache.a), p = p, reltol = cache.abstol)

src/utils.jl

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -265,30 +265,6 @@ _reshape(x::Number, args...) = x
265265
return :(@. y += α * x)
266266
end
267267

268-
# Needs Square Matrix
269-
# FIXME: Remove once https://github.com/SciML/LinearSolve.jl/pull/400 is merged and tagged
270-
"""
271-
needs_square_A(alg)
272-
273-
Returns `true` if the algorithm requires a square matrix.
274-
"""
275-
needs_square_A(::Nothing) = false
276-
function needs_square_A(alg)
277-
try
278-
A = [1.0 2.0;
279-
3.0 4.0;
280-
5.0 6.0]
281-
b = ones(Float64, 3)
282-
solve(LinearProblem(A, b), alg)
283-
return false
284-
catch err
285-
return true
286-
end
287-
end
288-
for alg in (:QRFactorization, :FastQRFactorization, NormalCholeskyFactorization,
289-
NormalBunchKaufmanFactorization)
290-
@eval needs_square_A(::$(alg)) = false
291-
end
292-
for kralg in (LinearSolve.Krylov.lsmr!, LinearSolve.Krylov.craigmr!)
293-
@eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false
294-
end
268+
_needs_square_A(_, ::Number) = true
269+
_needs_square_A(_, ::StaticArray) = true
270+
_needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)

test/nonlinear_least_squares.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,12 @@ nlls_problems = [prob_oop, prob_iip]
3030
solvers = [
3131
GaussNewton(),
3232
GaussNewton(; linsolve = LUFactorization()),
33+
LevenbergMarquardt(),
34+
LevenbergMarquardt(; linsolve = LUFactorization()),
3335
LeastSquaresOptimJL(:lm),
3436
LeastSquaresOptimJL(:dogleg),
3537
]
3638

37-
# Compile time on v"1.9" is too high!
38-
VERSION v"1.10-" && append!(solvers,
39-
[LevenbergMarquardt(), LevenbergMarquardt(; linsolve = LUFactorization())])
40-
4139
for prob in nlls_problems, solver in solvers
4240
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
4341
@test SciMLBase.successful_retcode(sol)

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ end
1717
@time @safetestset "Sparsity Tests" include("sparse.jl")
1818
@time @safetestset "Polyalgs" include("polyalgs.jl")
1919
@time @safetestset "Matrix Resizing" include("matrix_resizing.jl")
20-
@time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
20+
if VERSION v"1.10-"
21+
# Takes too long to compile on older versions
22+
@time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
23+
end
2124
end
2225

2326
if GROUP == "All" || GROUP == "23TestProblems"

0 commit comments

Comments
 (0)