Skip to content

Commit a4c228d

Browse files
committed
Add a function to check if square A is needed
1 parent 62ed82d commit a4c228d

File tree

8 files changed

+45
-21
lines changed

8 files changed

+45
-21
lines changed

src/NonlinearSolve.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,8 @@ import PrecompileTools
8888
for T in (Float32, Float64)
8989
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
9090

91-
# precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
92-
# PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
93-
# DON'T MERGE
94-
precompile_algs = ()
91+
precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
92+
PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
9593

9694
for alg in precompile_algs
9795
solve(prob, alg, abstol = T(1e-2))

src/default.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ end
159159
]
160160
else
161161
[
162-
:(GeneralBroyden()),
163162
:(GeneralKlement()),
163+
:(GeneralBroyden()),
164164
:(NewtonRaphson(; linsolve, precs, adkwargs...)),
165165
:(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)),
166166
:(TrustRegion(; linsolve, precs, adkwargs...)),

src/gaussnewton.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
8282
alg = get_concrete_algorithm(alg_, prob)
8383
@unpack f, u0, p = prob
8484

85-
# Use QR if the user did not specify a linear solver
86-
if alg.linsolve === nothing || alg.linsolve isa QRFactorization ||
87-
alg.linsolve isa FastQRFactorization
85+
if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray)
8886
linsolve_with_JᵀJ = Val(false)
8987
else
9088
linsolve_with_JᵀJ = Val(true)

src/klement.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme
7171
u = alias_u0 ? u0 : deepcopy(u0)
7272
fu = evaluate_f(prob, u)
7373
J = __init_identity_jacobian(u, fu)
74+
du = _mutable_zero(u)
7475

7576
if u isa Number
7677
linsolve = nothing
@@ -80,10 +81,10 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme
8081
linsolve_alg = alg_.linsolve === nothing && u isa Array ? LUFactorization() :
8182
nothing
8283
alg = set_linsolve(alg_, linsolve_alg)
83-
linsolve = __setup_linsolve(J, _vec(fu), _vec(u), p, alg)
84+
linsolve = __setup_linsolve(J, _vec(fu), _vec(du), p, alg)
8485
end
8586

86-
return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), _mutable_zero(u), p, linsolve,
87+
return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), du, p, linsolve,
8788
J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false,
8889
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
8990
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))

src/levenberg.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
164164
u = alias_u0 ? u0 : deepcopy(u0)
165165
fu1 = evaluate_f(prob, u)
166166

167-
# Use QR if the user did not specify a linear solver
168-
if (alg.linsolve === nothing || alg.linsolve isa QRFactorization ||
169-
alg.linsolve isa FastQRFactorization) && !(u isa Number)
167+
if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray)
170168
linsolve_with_JᵀJ = Val(false)
171169
else
172170
linsolve_with_JᵀJ = Val(true)

src/utils.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,30 @@ function _try_factorize_and_check_singular!(linsolve, X)
256256
return _issingular(X), false
257257
end
258258
_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false
259+
260+
# Needs Square Matrix
"""
    needs_square_A(alg)

Return `true` if the linear solver `alg` requires a square matrix, `false` otherwise.

`nothing` and the solvers covered by the dedicated methods below are answered directly
with `false`. Any other `alg` is probed: a small 3×2 `LinearProblem` is solved with it,
and `alg` is reported as needing a square matrix exactly when that solve throws.
"""
function needs_square_A(alg)
    # Rectangular (3×2) probe system: a solver that handles it does not need a square A.
    A = [1.0 2.0;
         3.0 4.0;
         5.0 6.0]
    b = ones(Float64, 3)
    try
        solve(LinearProblem(A, b), alg)
        return false
    catch err
        # A user interrupt says nothing about the solver, so it must not be swallowed.
        err isa InterruptException && rethrow()
        return true
    end
end

# No solver specified: treated as not requiring a square matrix.
needs_square_A(::Nothing) = false

# Solvers answered by dispatch, without running the probe solve.
function needs_square_A(::Union{QRFactorization, FastQRFactorization,
        NormalCholeskyFactorization, NormalBunchKaufmanFactorization})
    return false
end

# Krylov methods answered by dispatch, keyed on the type of the Krylov.jl routine.
function needs_square_A(::Union{KrylovJL{typeof(LinearSolve.Krylov.lsmr!)},
        KrylovJL{typeof(LinearSolve.Krylov.craigmr!)}})
    return false
end

test/23_test_problems.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,14 @@ end
5959
end
6060

6161
@testset "LevenbergMarquardt 23 Test Problems" begin
62-
alg_ops = (LevenbergMarquardt(; linsolve = NormalCholeskyFactorization()),
63-
LevenbergMarquardt(; α_geodesic = 0.1, linsolve = NormalCholeskyFactorization()))
62+
alg_ops = (LevenbergMarquardt(), LevenbergMarquardt(; α_geodesic = 0.1),
63+
LevenbergMarquardt(; linsolve = CholeskyFactorization()))
6464

6565
# dictionary with indices of test problems where method does not converge to small residual
6666
broken_tests = Dict(alg => Int[] for alg in alg_ops)
67-
broken_tests[alg_ops[1]] = [3, 6, 11, 21]
68-
broken_tests[alg_ops[2]] = [3, 6, 11, 21]
67+
broken_tests[alg_ops[1]] = [3, 6, 17, 21]
68+
broken_tests[alg_ops[2]] = [3, 6, 17, 21]
69+
broken_tests[alg_ops[3]] = [6, 11, 21]
6970

7071
test_on_library(problems, dicts, alg_ops, broken_tests)
7172
end

test/basictests.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,8 @@ end
352352
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
353353
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
354354
probN = NonlinearProblem(quadratic_f, u0, 2.0)
355-
@test all(solve(probN, LevenbergMarquardt(; autodiff)).u .≈ sqrt(2.0))
355+
@test all(solve(probN, LevenbergMarquardt(; autodiff); abstol = 1e-9,
356+
reltol = 1e-9).u .≈ sqrt(2.0))
356357
end
357358

358359
# Test that `LevenbergMarquardt` passes a test that `NewtonRaphson` fails on.
@@ -368,7 +369,7 @@ end
368369
@testset "Keyword Arguments" begin
369370
damping_initial = [0.5, 2.0, 5.0]
370371
damping_increase_factor = [1.5, 3.0, 10.0]
371-
damping_decrease_factor = Float64[2, 5, 10]
372+
damping_decrease_factor = Float64[2, 5, 12]
372373
finite_diff_step_geodesic = [0.02, 0.2, 0.3]
373374
α_geodesic = [0.6, 0.8, 0.9]
374375
b_uphill = Float64[0, 1, 2]
@@ -379,14 +380,14 @@ end
379380
min_damping_D)
380381
for options in list_of_options
381382
local probN, sol, alg
382-
alg = LevenbergMarquardt(damping_initial = options[1],
383+
alg = LevenbergMarquardt(; damping_initial = options[1],
383384
damping_increase_factor = options[2],
384385
damping_decrease_factor = options[3],
385386
finite_diff_step_geodesic = options[4], α_geodesic = options[5],
386387
b_uphill = options[6], min_damping_D = options[7])
387388

388389
probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0)
389-
sol = solve(probN, alg, abstol = 1e-10)
390+
sol = solve(probN, alg, abstol = 1e-12)
390391
@test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10)
391392
end
392393
end

0 commit comments

Comments
 (0)