Skip to content

Commit ee5e5b8

Browse files
committed
Reuse LU Factorization to check for singular matrix
1 parent 00852f0 commit ee5e5b8

File tree

8 files changed

+83
-29
lines changed

8 files changed

+83
-29
lines changed

docs/src/api/nonlinearsolve.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ These are the native solvers of NonlinearSolve.jl.
88
NewtonRaphson
99
TrustRegion
1010
PseudoTransient
11+
DFSane
12+
GeneralBroyden
13+
GeneralKlement
1114
```
1215

1316
## Polyalgorithms

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import ArrayInterface: restructure
99
import ForwardDiff
1010

1111
import ADTypes: AbstractFiniteDifferencesMode
12-
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable
12+
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable, issingular
1313
import ConcreteStructs: @concrete
1414
import EnumX: @enumx
1515
import ForwardDiff: Dual

src/dfsane.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ See also the implementation in [SimpleNonlinearSolve.jl](https://github.com/SciM
4343
- `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the
4444
algorithm. Defaults to `1000`.
4545
"""
46-
4746
struct DFSane{T, F} <: AbstractNonlinearSolveAlgorithm
4847
σ_min::T
4948
σ_max::T

src/klement.jl

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,36 @@
1+
"""
2+
GeneralKlement(; max_resets = 5, linsolve = nothing,
3+
linesearch = LineSearch(), precs = DEFAULT_PRECS)
4+
5+
An implementation of `Klement` with line search, preconditioning and customizable linear
6+
solves.
7+
8+
## Keyword Arguments
9+
10+
- `max_resets`: the maximum number of resets to perform. Defaults to `5`.
11+
- `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) algorithm used for the
12+
linear solves within the Klement method. Defaults to `nothing`, which means it uses the
13+
LinearSolve.jl default algorithm choice (LU factorization when `u` is an `Array`). For more information on available algorithm
14+
choices, see the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
15+
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
16+
preconditioners. For more information on specifying preconditioners for LinearSolve
17+
algorithms, consult the
18+
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
19+
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
20+
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
21+
used here directly, and they will be converted to the correct `LineSearch`.
22+
"""
123
@concrete struct GeneralKlement <: AbstractNewtonAlgorithm{false, Nothing}
224
max_resets::Int
325
linsolve
426
precs
527
linesearch
6-
singular_tolerance
728
end
829

930
function GeneralKlement(; max_resets::Int = 5, linsolve = nothing,
10-
linesearch = LineSearch(), precs = DEFAULT_PRECS, singular_tolerance = nothing)
31+
linesearch = LineSearch(), precs = DEFAULT_PRECS)
1132
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
12-
return GeneralKlement(max_resets, linsolve, precs, linesearch, singular_tolerance)
33+
return GeneralKlement(max_resets, linsolve, precs, linesearch)
1334
end
1435

1536
@concrete mutable struct GeneralKlementCache{iip} <: AbstractNonlinearSolveCache{iip}
@@ -27,7 +48,6 @@ end
2748
Jᵀ²du
2849
Jdu
2950
resets
30-
singular_tolerance
3151
force_stop
3252
maxiters::Int
3353
internalnorm
@@ -51,20 +71,20 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralKlemen
5171
if u isa Number
5272
linsolve = nothing
5373
else
74+
# For plain Julia `Array`s, default to LU factorization when no `linsolve` is specified
75+
linsolve_alg = alg.linsolve === nothing && u isa Array ? LUFactorization() :
76+
nothing
5477
weight = similar(u)
5578
recursivefill!(weight, true)
5679
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
5780
nothing)..., weight)
5881
linprob = LinearProblem(J, _vec(fu); u0 = _vec(fu))
59-
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
82+
linsolve = init(linprob, linsolve_alg; alias_A = true, alias_b = true, Pl, Pr,
6083
linsolve_kwargs...)
6184
end
6285

63-
singular_tolerance = alg.singular_tolerance === nothing ? inv(sqrt(eps(eltype(u)))) :
64-
eltype(u)(alg.singular_tolerance)
65-
6686
return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), _mutable_zero(u), p, linsolve,
67-
J, zero(J), zero(J), zero(fu), zero(fu), 0, singular_tolerance, false,
87+
J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false,
6888
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
6989
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
7090
end
@@ -73,21 +93,23 @@ function perform_step!(cache::GeneralKlementCache{true})
7393
@unpack u, fu, f, p, alg, J, linsolve, du = cache
7494
T = eltype(J)
7595

76-
# FIXME: How can we do this faster?
77-
if cond(J) > cache.singular_tolerance
96+
singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)
97+
98+
if singular
7899
if cache.resets == alg.max_resets
79100
cache.force_stop = true
80101
cache.retcode = ReturnCode.Unstable
81102
return nothing
82103
end
104+
fact_done = false
83105
fill!(J, zero(T))
84106
J[diagind(J)] .= T(1)
85107
cache.resets += 1
86108
end
87109

88110
# u = u - J \ fu
89-
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu), linu = _vec(du),
90-
p, reltol = cache.abstol)
111+
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),
112+
b = -_vec(fu), linu = _vec(du), p, reltol = cache.abstol)
91113
cache.linsolve = linres.cache
92114

93115
# Line Search
@@ -108,7 +130,8 @@ function perform_step!(cache::GeneralKlementCache{true})
108130
mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu)
109131
mul!(cache.Jdu, J, _vec(du))
110132
cache.fu .= cache.fu2 .- cache.fu
111-
cache.fu .= (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
133+
cache.fu .= _restructure(cache.fu,
134+
(_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(T)))
112135
mul!(cache.J_cache, _vec(cache.fu), _vec(du)')
113136
cache.J_cache .*= J
114137
mul!(cache.J_cache2, cache.J_cache, J)
@@ -123,23 +146,25 @@ function perform_step!(cache::GeneralKlementCache{false})
123146
@unpack fu, f, p, alg, J, linsolve = cache
124147
T = eltype(J)
125148

126-
# FIXME: How can we do this faster?
127-
if cond(J) > cache.singular_tolerance
149+
singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)
150+
151+
if singular
128152
if cache.resets == alg.max_resets
129153
cache.force_stop = true
130154
cache.retcode = ReturnCode.Unstable
131155
return nothing
132156
end
133-
cache.J = __init_identity_jacobian(u, fu)
157+
fact_done = false
158+
cache.J = __init_identity_jacobian(cache.u, fu)
134159
cache.resets += 1
135160
end
136161

137162
# u = u - J \ fu
138163
if linsolve === nothing
139164
cache.du = -fu / cache.J
140165
else
141-
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu),
142-
linu = _vec(cache.du), p, reltol = cache.abstol)
166+
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),
167+
b = -_vec(fu), linu = _vec(cache.du), p, reltol = cache.abstol)
143168
cache.linsolve = linres.cache
144169
end
145170

@@ -161,7 +186,8 @@ function perform_step!(cache::GeneralKlementCache{false})
161186
cache.Jᵀ²du = cache.J_cache * cache.Jdu
162187
cache.Jdu = J * _vec(cache.du)
163188
cache.fu = cache.fu2 .- cache.fu
164-
cache.fu = (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
189+
cache.fu = _restructure(cache.fu,
190+
(_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(T)))
165191
cache.J_cache = ((_vec(cache.fu) * _vec(cache.du)') .* J) * J
166192
cache.J = J .+ cache.J_cache
167193

src/trustRegion.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
8181
end
8282

8383
"""
84-
```julia
8584
TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
8685
radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple,
8786
max_trust_radius::Real = 0 // 1, initial_trust_radius::Real = 0 // 1,

src/utils.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,28 @@ function __init_identity_jacobian(u::StaticArray, fu)
221221
return convert(MArray{Tuple{length(fu), length(u)}},
222222
Matrix{eltype(u)}(I, length(fu), length(u)))
223223
end
224+
225+
# Check Singular Matrix
226+
_issingular(x::Number) = iszero(x)
227+
@generated function _issingular(x::T) where {T}
228+
hasmethod(issingular, Tuple{T}) && return :(issingular(x))
229+
return :(__issingular(x))
230+
end
231+
__issingular(x::AbstractMatrix{T}) where {T} = cond(x) > inv(sqrt(eps(T)))
232+
__issingular(x) = false ## If SciMLOperator and such
233+
234+
# If the linsolve cache holds an LU factorization, refactorize `X` and store the result in the cache;
235+
# otherwise fall back to checking directly whether the matrix is singular
236+
function _try_factorize_and_check_singular!(linsolve, X)
237+
if linsolve.cacheval isa LU
238+
# LU Factorization was used
239+
linsolve.A = X
240+
linsolve.cacheval = LinearSolve.do_factorization(linsolve.alg, X, linsolve.b,
241+
linsolve.u)
242+
linsolve.isfresh = false
243+
244+
return !issuccess(linsolve.cacheval), true
245+
end
246+
return _issingular(X), false
247+
end
248+
_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false

test/23_test_problems.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,21 @@ end
8787

8888
broken_tests = Dict(alg => Int[] for alg in alg_ops)
8989
broken_tests[alg_ops[1]] = [1, 3, 4, 5, 6, 8, 11, 12, 13, 14, 21]
90-
broken_tests[alg_ops[2]] = [1, 2, 3, 4, 5, 6, 9, 11, 13, 22]
90+
broken_tests[alg_ops[2]] = [1, 2, 3, 4, 5, 6, 9, 11, 13, 16, 21, 22]
9191
broken_tests[alg_ops[3]] = [1, 2, 4, 5, 6, 8, 11, 12, 13, 14, 21]
9292

9393
test_on_library(problems, dicts, alg_ops, broken_tests)
9494
end
9595

9696
@testset "GeneralKlement 23 Test Problems" begin
9797
alg_ops = (GeneralKlement(),
98-
GeneralKlement(; linesearch = BackTracking()))
98+
GeneralKlement(; linesearch = BackTracking()),
99+
GeneralKlement(; linesearch = HagerZhang()))
99100

100101
broken_tests = Dict(alg => Int[] for alg in alg_ops)
101-
broken_tests[alg_ops[1]] = [1, 2, 3, 4, 5, 6, 7, 13, 22]
102-
broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 12, 22]
102+
broken_tests[alg_ops[1]] = [1, 2, 4, 5, 6, 7, 11, 13, 22]
103+
broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 13, 22]
104+
broken_tests[alg_ops[3]] = [1, 2, 5, 6, 11, 12, 13, 22]
103105

104106
test_on_library(problems, dicts, alg_ops, broken_tests)
105107
end

test/matrix_resizing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ vecprob = NonlinearProblem(ff, vec(u0), p)
77
prob = NonlinearProblem(ff, u0, p)
88

99
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
10-
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden())
10+
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden(), GeneralKlement())
1111
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
1212
end
1313

@@ -18,6 +18,6 @@ vecprob = NonlinearProblem(fiip, vec(u0), p)
1818
prob = NonlinearProblem(fiip, u0, p)
1919

2020
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
21-
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden())
21+
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden(), GeneralKlement())
2222
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
2323
end

0 commit comments

Comments
 (0)