Skip to content

Commit 5b13e5d

Browse files
Add NormalCholeskyFactorization and WellConditioned defaults
Continuing on #289 to close #283
1 parent 8277c12 commit 5b13e5d

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

src/default.jl

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,52 @@ function defaultalg(A, b, assump::OperatorAssumptions{true})
155155
# whether MKL or OpenBLAS is being used
156156
if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
157157
if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
158-
ArrayInterface.can_setindex(b) && __conditioning(assump) === OperatorCondition.IllConditioned
158+
ArrayInterface.can_setindex(b) && (__conditioning(assump) === OperatorCondition.IllConditioned ||
159+
__conditioning(assump) === OperatorCondition.WellConditioned)
160+
159161
if length(b) <= 10
160-
alg = GenericLUFactorization()
162+
pivot = @static if VERSION < v"1.7beta"
163+
if __conditioning(assump) === OperatorCondition.IllConditioned
164+
Val(true)
165+
else
166+
Val(false)
167+
end
168+
else
169+
if __conditioning(assump) === OperatorCondition.IllConditioned
170+
RowMaximum()
171+
else
172+
RowNonZero()
173+
end
174+
end
175+
end
176+
alg = GenericLUFactorization(pivot)
161177
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
162178
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
163179
eltype(A) <: Union{Float32, Float64})
164-
alg = RFLUFactorization()
180+
pivot = if __conditioning(assump) === OperatorCondition.IllConditioned
181+
Val(true)
182+
else
183+
Val(false)
184+
end
185+
alg = RFLUFactorization(;pivot = pivot)
165186
#elseif A === nothing || A isa Matrix
166187
# alg = FastLUFactorization()
167188
else
168-
alg = LUFactorization()
189+
pivot = @static if VERSION < v"1.7beta"
190+
if __conditioning(assump) === OperatorCondition.IllConditioned
191+
Val(true)
192+
else
193+
Val(false)
194+
end
195+
else
196+
if __conditioning(assump) === OperatorCondition.IllConditioned
197+
RowMaximum()
198+
else
199+
RowNonZero()
200+
end
201+
end
202+
end
203+
alg = LUFactorization(pivot)
169204
end
170205
elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned
171206
alg = QRFactorization()

src/factorization.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,43 @@ function SciMLBase.solve(cache::LinearCache, alg::RFLUFactorization{P, T};
472472
SciMLBase.build_linear_solution(alg, y, nothing, cache)
473473
end
474474

475+
## NormalCholeskyFactorization
476+
477+
struct NormalCholeskyFactorization <: AbstractFactorization
478+
pivot::P
479+
perm::Bool
480+
end
481+
482+
function NormalCholeskyFactorization(; pivot = nothing, perm = nothing)
483+
if pivot === nothing
484+
@static if VERSION < v"1.7beta"
485+
Val(true)
486+
else
487+
RowMaximum()
488+
end
489+
end
490+
NormalCholeskyFactorization(pivot, perm)
491+
end
492+
493+
function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
494+
maxiters::Int, abstol, reltol, verbose::Bool,
495+
assumptions::OperatorAssumptions)
496+
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
497+
end
498+
499+
function SciMLBase.solve(cache::LinearCache, alg::NormalCholeskyFactorization;
500+
kwargs...)
501+
A = cache.A
502+
A = convert(AbstractMatrix, A)
503+
fact, ipiv = cache.cacheval
504+
if cache.isfresh
505+
fact = cholesky(Symmetric((A)' * A), alg.pivot)
506+
cache = set_cacheval(cache, fact)
507+
end
508+
y = ldiv!(cache.u, cache.cacheval, A' * cache.b)
509+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
510+
end
511+
475512
## DiagonalFactorization
476513

477514
struct DiagonalFactorization <: AbstractFactorization end

0 commit comments

Comments
 (0)