Skip to content

Commit af79d26

Browse files
Merge pull request #290 from SciML/normalcholesky
Add NormalCholeskyFactorization and WellConditioned defaults
2 parents 8277c12 + f322e55 commit af79d26

12 files changed

+197
-31
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "1.40.0"
4+
version = "1.41.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

docs/src/basics/FAQ.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ as otherwise that will need to be determined at runtime.
1717
## I found a faster algorithm that can be used than what LinearSolve.jl chose?
1818

1919
What assumptions are made as part of your method? If your method only works on well-conditioned operators, then
20-
make sure you set the `WellConditioned` assumption in the `assumptions`. See the
20+
make sure you set the `WellConditioned` assumption in the `assumptions`. See the
2121
[OperatorAssumptions page for more details](@ref assumptions). If using the right assumptions does not improve
2222
the performance to the expected state, please open an issue and we will improve the default algorithm.
2323

docs/src/basics/OperatorAssumptions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ IllConditioned
1212
VeryIllConditioned
1313
SuperIllConditioned
1414
WellConditioned
15-
```
15+
```

docs/src/basics/common_solver_opts.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The following are the options these algorithms take, along with their defaults.
1616
- `verbose`: Whether to print extra information. Defaults to `false`.
1717
- `assumptions`: Sets the assumptions of the operator in order to effect the default
1818
choice algorithm. See the [Operator Assumptions page for more details](@ref assumptions).
19+
1920
## Iterative Solver Controls
2021

2122
Error controls are not used by all algorithms. Specifically, direct solves always

ext/LinearSolveHYPRE.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ using HYPRE.LibHYPRE: HYPRE_Complex
44
using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
55
using IterativeSolvers: Identity
66
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
7-
OperatorAssumptions, default_tol, init_cacheval, __issquare, set_cacheval
7+
OperatorAssumptions, default_tol, init_cacheval, __issquare,
8+
__conditioning, set_cacheval
89
using SciMLBase: LinearProblem, SciMLBase
910
using UnPack: @unpack
1011
using Setfield: @set!
@@ -82,7 +83,8 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
8283

8384
cache = LinearCache{
8485
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
85-
typeof(Pl), typeof(Pr), typeof(reltol), __issquare(assumptions)
86+
typeof(Pl), typeof(Pr), typeof(reltol), __issquare(assumptions),
87+
__conditioning(assumptions)
8688
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
8789
maxiters,
8890
verbose, assumptions)

src/LinearSolve.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ end
106106

107107
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
108108
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
109+
NormalCholeskyFactorization, NormalBunchKaufmanFactorization,
109110
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
110111
SparspakFactorization, DiagonalFactorization
111112

@@ -119,4 +120,6 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
119120

120121
export HYPREAlgorithm
121122

123+
export OperatorAssumptions, OperatorCondition
124+
122125
end

src/common.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,17 @@ end
5353
5454
Sets the operator `A` assumptions used as part of the default algorithm
5555
"""
56-
struct OperatorAssumptions{issq,condition} end
57-
function OperatorAssumptions(issquare = nothing; condition::OperatorCondition.T = OperatorCondition.IllConditioned)
56+
struct OperatorAssumptions{issq, condition} end
57+
function OperatorAssumptions(issquare = nothing;
58+
condition::OperatorCondition.T = OperatorCondition.IllConditioned)
5859
issq = something(_unwrap_val(issquare), Nothing)
5960
condition = _unwrap_val(condition)
60-
OperatorAssumptions{issq,condition}()
61+
OperatorAssumptions{issq, condition}()
6162
end
62-
__issquare(::OperatorAssumptions{issq,condition}) where {issq,condition} = issq
63-
__conditioning(::OperatorAssumptions{issq,condition}) where {issq,condition} = condition
63+
__issquare(::OperatorAssumptions{issq, condition}) where {issq, condition} = issq
64+
__conditioning(::OperatorAssumptions{issq, condition}) where {issq, condition} = condition
6465

65-
66-
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
66+
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, condition}
6767
A::TA
6868
b::Tb
6969
u::Tu
@@ -77,7 +77,7 @@ struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
7777
reltol::Ttol
7878
maxiters::Int
7979
verbose::Bool
80-
assumptions::OperatorAssumptions{issq}
80+
assumptions::OperatorAssumptions{issq, condition}
8181
end
8282

8383
"""
@@ -143,9 +143,13 @@ default_tol(::Type{<:Rational}) = 0
143143
default_tol(::Type{<:Integer}) = 0
144144
default_tol(::Type{Any}) = 0
145145

146-
function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm, Nothing},
146+
default_alias_A(::Any, ::Any, ::Any) = false
147+
default_alias_b(::Any, ::Any, ::Any) = false
148+
149+
function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
147150
args...;
148-
alias_A = false, alias_b = false,
151+
alias_A = default_alias_A(alg, prob.A, prob.b),
152+
alias_b = default_alias_b(alg, prob.A, prob.b),
149153
abstol = default_tol(eltype(prob.A)),
150154
reltol = default_tol(eltype(prob.A)),
151155
maxiters::Int = length(prob.b),
@@ -187,7 +191,8 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
187191
typeof(Pl),
188192
typeof(Pr),
189193
typeof(reltol),
190-
__issquare(assumptions)
194+
__issquare(assumptions),
195+
__conditioning(assumptions)
191196
}(A,
192197
b,
193198
u0,

src/default.jl

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssump
7474
end
7575
end
7676

77-
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions{true,OperatorCondition.IllConditioned})
77+
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b,
78+
assump::OperatorAssumptions{true, OperatorCondition.IllConditioned})
7879
QRFactorization()
7980
end
8081

@@ -86,7 +87,8 @@ function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssump
8687
end
8788
end
8889

89-
function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions{true,OperatorCondition.IllConditioned})
90+
function defaultalg(A, b::GPUArraysCore.AbstractGPUArray,
91+
assump::OperatorAssumptions{true, OperatorCondition.IllConditioned})
9092
QRFactorization()
9193
end
9294

@@ -130,7 +132,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.Abstract
130132
end
131133

132134
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
133-
::OperatorAssumptions{true,OperatorCondition.IllConditioned})
135+
::OperatorAssumptions{true, OperatorCondition.IllConditioned})
134136
QRFactorization()
135137
end
136138

@@ -155,17 +157,50 @@ function defaultalg(A, b, assump::OperatorAssumptions{true})
155157
# whether MKL or OpenBLAS is being used
156158
if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
157159
if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
158-
ArrayInterface.can_setindex(b) && __conditioning(assump) === OperatorCondition.IllConditioned
160+
ArrayInterface.can_setindex(b) &&
161+
(__conditioning(assump) === OperatorCondition.IllConditioned ||
162+
__conditioning(assump) === OperatorCondition.WellConditioned)
159163
if length(b) <= 10
160-
alg = GenericLUFactorization()
164+
pivot = @static if VERSION < v"1.7beta"
165+
if __conditioning(assump) === OperatorCondition.IllConditioned
166+
Val(true)
167+
else
168+
Val(false)
169+
end
170+
else
171+
if __conditioning(assump) === OperatorCondition.IllConditioned
172+
RowMaximum()
173+
else
174+
RowNonZero()
175+
end
176+
end
177+
alg = GenericLUFactorization(pivot)
161178
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
162179
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
163180
eltype(A) <: Union{Float32, Float64})
164-
alg = RFLUFactorization()
181+
pivot = if __conditioning(assump) === OperatorCondition.IllConditioned
182+
Val(true)
183+
else
184+
Val(false)
185+
end
186+
alg = RFLUFactorization(; pivot = pivot)
165187
#elseif A === nothing || A isa Matrix
166188
# alg = FastLUFactorization()
167189
else
168-
alg = LUFactorization()
190+
pivot = @static if VERSION < v"1.7beta"
191+
if __conditioning(assump) === OperatorCondition.IllConditioned
192+
Val(true)
193+
else
194+
Val(false)
195+
end
196+
else
197+
if __conditioning(assump) === OperatorCondition.IllConditioned
198+
RowMaximum()
199+
else
200+
RowNonZero()
201+
end
202+
end
203+
alg = LUFactorization(pivot)
169204
end
170205
elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned
171206
alg = QRFactorization()
@@ -187,20 +222,34 @@ function defaultalg(A, b, assump::OperatorAssumptions{true})
187222
alg
188223
end
189224

190-
function defaultalg(A, b, ::OperatorAssumptions{false,OperatorCondition.IllConditioned})
225+
function defaultalg(A, b, ::OperatorAssumptions{false, OperatorCondition.WellConditioned})
226+
NormalCholeskyFactorization()
227+
end
228+
229+
function defaultalg(A, b, ::OperatorAssumptions{false, OperatorCondition.IllConditioned})
191230
QRFactorization()
192231
end
193232

194-
function defaultalg(A, b, ::OperatorAssumptions{false,OperatorCondition.VeryIllConditioned})
233+
function defaultalg(A, b,
234+
::OperatorAssumptions{false, OperatorCondition.VeryIllConditioned})
195235
QRFactorization()
196236
end
197237

198-
function defaultalg(A, b, ::OperatorAssumptions{false,OperatorCondition.SuperIllConditioned})
238+
function defaultalg(A, b,
239+
::OperatorAssumptions{false, OperatorCondition.SuperIllConditioned})
199240
SVDFactorization(false, LinearAlgebra.QRIteration())
200241
end
201242

202243
## Catch high level interface
203244

245+
function SciMLBase.init(prob::LinearProblem, alg::Nothing,
246+
args...;
247+
assumptions = OperatorAssumptions(Val(issquare(prob.A))),
248+
kwargs...)
249+
alg = defaultalg(prob.A, prob.b, assumptions)
250+
SciMLBase.init(prob, alg, args...; assumptions, kwargs...)
251+
end
252+
204253
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
205254
args...; assumptions::OperatorAssumptions = OperatorAssumptions(),
206255
kwargs...)

src/factorization.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,84 @@ function SciMLBase.solve(cache::LinearCache, alg::RFLUFactorization{P, T};
472472
SciMLBase.build_linear_solution(alg, y, nothing, cache)
473473
end
474474

475+
## NormalCholeskyFactorization
476+
477+
struct NormalCholeskyFactorization{P} <: AbstractFactorization
478+
pivot::P
479+
end
480+
481+
function NormalCholeskyFactorization(; pivot = nothing)
482+
if pivot === nothing
483+
pivot = @static if VERSION < v"1.7beta"
484+
Val(true)
485+
else
486+
RowMaximum()
487+
end
488+
end
489+
NormalCholeskyFactorization(pivot)
490+
end
491+
492+
default_alias_A(::NormalCholeskyFactorization, ::Any, ::Any) = true
493+
default_alias_b(::NormalCholeskyFactorization, ::Any, ::Any) = true
494+
495+
function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
496+
maxiters::Int, abstol, reltol, verbose::Bool,
497+
assumptions::OperatorAssumptions)
498+
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
499+
end
500+
501+
function SciMLBase.solve(cache::LinearCache, alg::NormalCholeskyFactorization;
502+
kwargs...)
503+
A = cache.A
504+
A = convert(AbstractMatrix, A)
505+
if cache.isfresh
506+
if A isa SparseMatrixCSC
507+
fact = cholesky(Symmetric((A)' * A))
508+
else
509+
fact = cholesky(Symmetric((A)' * A), alg.pivot)
510+
end
511+
cache = set_cacheval(cache, fact)
512+
end
513+
if A isa SparseMatrixCSC
514+
cache.u .= cache.cacheval \ (A' * cache.b)
515+
y = cache.u
516+
else
517+
y = ldiv!(cache.u, cache.cacheval, A' * cache.b)
518+
end
519+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
520+
end
521+
522+
## NormalBunchKaufmanFactorization
523+
524+
struct NormalBunchKaufmanFactorization <: AbstractFactorization
525+
rook::Bool
526+
end
527+
528+
function NormalBunchKaufmanFactorization(; rook = false)
529+
NormalBunchKaufmanFactorization(rook)
530+
end
531+
532+
default_alias_A(::NormalBunchKaufmanFactorization, ::Any, ::Any) = true
533+
default_alias_b(::NormalBunchKaufmanFactorization, ::Any, ::Any) = true
534+
535+
function init_cacheval(alg::NormalBunchKaufmanFactorization, A, b, u, Pl, Pr,
536+
maxiters::Int, abstol, reltol, verbose::Bool,
537+
assumptions::OperatorAssumptions)
538+
ArrayInterface.bunchkaufman_instance(convert(AbstractMatrix, A))
539+
end
540+
541+
function SciMLBase.solve(cache::LinearCache, alg::NormalBunchKaufmanFactorization;
542+
kwargs...)
543+
A = cache.A
544+
A = convert(AbstractMatrix, A)
545+
if cache.isfresh
546+
fact = bunchkaufman(Symmetric((A)' * A), alg.rook)
547+
cache = set_cacheval(cache, fact)
548+
end
549+
y = ldiv!(cache.u, cache.cacheval, A' * cache.b)
550+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
551+
end
552+
475553
## DiagonalFactorization
476554

477555
struct DiagonalFactorization <: AbstractFactorization end

src/factorization_sparse.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ function _ldiv!(x::Vector,
88
end
99

1010
function _ldiv!(x::AbstractVector,
11-
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
12-
SuiteSparse.SPQR.QRSparse,
13-
SuiteSparse.CHOLMOD.Factor}, b::AbstractVector)
11+
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
12+
SuiteSparse.SPQR.QRSparse,
13+
SuiteSparse.CHOLMOD.Factor}, b::AbstractVector)
1414
x .= A \ b
15-
end
15+
end

0 commit comments

Comments
 (0)