Skip to content

Commit 6bff8ba

Browse files
Finish and test well-conditioned form
1 parent 10f769f commit 6bff8ba

File tree

6 files changed

+97
-14
lines changed

6 files changed

+97
-14
lines changed

src/LinearSolve.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ end
106106

107107
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
108108
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
109+
NormalCholeskyFactorization, NormalBunchKaufmanFactorization,
109110
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
110111
SparspakFactorization, DiagonalFactorization
111112

@@ -119,4 +120,6 @@ export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
119120

120121
export HYPREAlgorithm
121122

123+
export OperatorAssumptions, OperatorCondition
124+
122125
end

src/common.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ __issquare(::OperatorAssumptions{issq,condition}) where {issq,condition} = issq
6363
__conditioning(::OperatorAssumptions{issq,condition}) where {issq,condition} = condition
6464

6565

66-
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
66+
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, condition}
6767
A::TA
6868
b::Tb
6969
u::Tu
@@ -77,7 +77,7 @@ struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
7777
reltol::Ttol
7878
maxiters::Int
7979
verbose::Bool
80-
assumptions::OperatorAssumptions{issq}
80+
assumptions::OperatorAssumptions{issq,condition}
8181
end
8282

8383
"""
@@ -143,9 +143,12 @@ default_tol(::Type{<:Rational}) = 0
143143
default_tol(::Type{<:Integer}) = 0
144144
default_tol(::Type{Any}) = 0
145145

146-
function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm, Nothing},
146+
default_alias_A(::Any,::Any,::Any) = false
147+
default_alias_b(::Any,::Any,::Any) = false
148+
149+
function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
147150
args...;
148-
alias_A = false, alias_b = false,
151+
alias_A = default_alias_A(alg, prob.A, prob.b), alias_b = default_alias_b(alg, prob.A, prob.b),
149152
abstol = default_tol(eltype(prob.A)),
150153
reltol = default_tol(eltype(prob.A)),
151154
maxiters::Int = length(prob.b),
@@ -187,7 +190,8 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
187190
typeof(Pl),
188191
typeof(Pr),
189192
typeof(reltol),
190-
__issquare(assumptions)
193+
__issquare(assumptions),
194+
__conditioning(assumptions)
191195
}(A,
192196
b,
193197
u0,

src/default.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@ function defaultalg(A, b, assump::OperatorAssumptions{true})
220220
alg
221221
end
222222

223+
function defaultalg(A, b, ::OperatorAssumptions{false,OperatorCondition.WellConditioned})
224+
NormalCholeskyFactorization()
225+
end
226+
223227
function defaultalg(A, b, ::OperatorAssumptions{false,OperatorCondition.IllConditioned})
224228
QRFactorization()
225229
end
@@ -234,6 +238,14 @@ end
234238

235239
## Catch high level interface
236240

241+
function SciMLBase.init(prob::LinearProblem, alg::Nothing,
242+
args...;
243+
assumptions = OperatorAssumptions(Val(issquare(prob.A))),
244+
kwargs...)
245+
alg = defaultalg(prob.A, prob.b, assumptions)
246+
SciMLBase.init(prob, alg, args...; assumptions, kwargs...)
247+
end
248+
237249
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
238250
args...; assumptions::OperatorAssumptions = OperatorAssumptions(),
239251
kwargs...)

src/factorization.jl

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,7 @@ function init_cacheval(alg::GenericFactorization,
323323
end
324324

325325
# Cholesky needs the posdef matrix, for GenericFactorization assume structure is needed
326-
function init_cacheval(alg::Union{GenericFactorization,
327-
GenericFactorization{typeof(cholesky)},
326+
function init_cacheval(alg::Union{GenericFactorization{typeof(cholesky)},
328327
GenericFactorization{typeof(cholesky!)}}, A, b, u, Pl, Pr,
329328
maxiters::Int, abstol, reltol, verbose::Bool,
330329
assumptions::OperatorAssumptions)
@@ -476,20 +475,22 @@ end
476475

477476
struct NormalCholeskyFactorization{P} <: AbstractFactorization
478477
pivot::P
479-
perm::Bool
480478
end
481479

482-
function NormalCholeskyFactorization(; pivot = nothing, perm = nothing)
480+
function NormalCholeskyFactorization(; pivot = nothing)
483481
if pivot === nothing
484-
@static if VERSION < v"1.7beta"
482+
pivot = @static if VERSION < v"1.7beta"
485483
Val(true)
486484
else
487485
RowMaximum()
488486
end
489487
end
490-
NormalCholeskyFactorization(pivot, perm)
488+
NormalCholeskyFactorization(pivot)
491489
end
492490

491+
default_alias_A(::NormalCholeskyFactorization,::Any,::Any) = true
492+
default_alias_b(::NormalCholeskyFactorization,::Any,::Any) = true
493+
493494
function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
494495
maxiters::Int, abstol, reltol, verbose::Bool,
495496
assumptions::OperatorAssumptions)
@@ -500,9 +501,48 @@ function SciMLBase.solve(cache::LinearCache, alg::NormalCholeskyFactorization;
500501
kwargs...)
501502
A = cache.A
502503
A = convert(AbstractMatrix, A)
503-
fact, ipiv = cache.cacheval
504504
if cache.isfresh
505-
fact = cholesky(Symmetric((A)' * A), alg.pivot)
505+
if A isa SparseMatrixCSC
506+
fact = cholesky(Symmetric((A)' * A))
507+
else
508+
fact = cholesky(Symmetric((A)' * A), alg.pivot)
509+
end
510+
cache = set_cacheval(cache, fact)
511+
end
512+
if A isa SparseMatrixCSC
513+
cache.u .= cache.cacheval \ (A' * cache.b)
514+
y = cache.u
515+
else
516+
y = ldiv!(cache.u, cache.cacheval, A' * cache.b)
517+
end
518+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
519+
end
520+
521+
## NormalBunchKaufmanFactorization
522+
523+
struct NormalBunchKaufmanFactorization <: AbstractFactorization
524+
rook::Bool
525+
end
526+
527+
function NormalBunchKaufmanFactorization(; rook = false)
528+
NormalBunchKaufmanFactorization(rook)
529+
end
530+
531+
default_alias_A(::NormalBunchKaufmanFactorization,::Any,::Any) = true
532+
default_alias_b(::NormalBunchKaufmanFactorization,::Any,::Any) = true
533+
534+
function init_cacheval(alg::NormalBunchKaufmanFactorization, A, b, u, Pl, Pr,
535+
maxiters::Int, abstol, reltol, verbose::Bool,
536+
assumptions::OperatorAssumptions)
537+
ArrayInterface.bunchkaufman_instance(convert(AbstractMatrix, A))
538+
end
539+
540+
function SciMLBase.solve(cache::LinearCache, alg::NormalBunchKaufmanFactorization;
541+
kwargs...)
542+
A = cache.A
543+
A = convert(AbstractMatrix, A)
544+
if cache.isfresh
545+
fact = bunchkaufman(Symmetric((A)' * A), alg.rook)
506546
cache = set_cacheval(cache, fact)
507547
end
508548
y = ldiv!(cache.u, cache.cacheval, A' * cache.b)

src/iterative_wrappers.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ function KrylovJL(args...; KrylovAlg = Krylov.gmres!,
1515
args, kwargs)
1616
end
1717

18+
default_alias_A(::KrylovJL,::Any,::Any) = true
19+
default_alias_b(::KrylovJL,::Any,::Any) = true
20+
1821
function KrylovJL_CG(args...; kwargs...)
1922
KrylovJL(args...; KrylovAlg = Krylov.cg!, kwargs...)
2023
end
@@ -205,6 +208,9 @@ function IterativeSolversJL(args...;
205208
args, kwargs)
206209
end
207210

211+
default_alias_A(::IterativeSolversJL,::Any,::Any) = true
212+
default_alias_b(::IterativeSolversJL,::Any,::Any) = true
213+
208214
function IterativeSolversJL_CG(args...; kwargs...)
209215
IterativeSolversJL(args...;
210216
generate_iterator = IterativeSolvers.cg_iterator!,
@@ -312,6 +318,9 @@ function KrylovKitJL(args...;
312318
return KrylovKitJL(KrylovAlg, gmres_restart, args, kwargs)
313319
end
314320

321+
default_alias_A(::KrylovKitJL,::Any,::Any) = true
322+
default_alias_b(::KrylovKitJL,::Any,::Any) = true
323+
315324
function KrylovKitJL_CG(args...; kwargs...)
316325
KrylovKitJL(args...; KrylovAlg = KrylovKit.CG, kwargs..., isposdef = true)
317326
end

test/nonsquare.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,19 @@ A = sprandn(1000, 100, 0.1)
2929
b = randn(1001)
3030
prob = LinearProblem(A, view(b, 1:1000))
3131
linsolve = init(prob, QRFactorization())
32-
solve(linsolve)
32+
solve(linsolve)
33+
34+
A = randn(1000, 100)
35+
b = randn(1000)
36+
@test isapprox(solve(LinearProblem(A, b)).u, Symmetric(A' * A) \ (A' * b))
37+
solve(LinearProblem(A, b)).u;
38+
solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
39+
solve(LinearProblem(A, b), (LinearSolve.NormalBunchKaufmanFactorization())).u;
40+
solve(LinearProblem(A, b), assumptions = (OperatorAssumptions(false; condition = OperatorCondition.WellConditioned))).u;
41+
42+
A = sprandn(5000, 100, 0.1)
43+
b = randn(5000)
44+
@test isapprox(solve(LinearProblem(A, b)).u, ldlt(A' * A) \ (A' * b))
45+
solve(LinearProblem(A, b)).u;
46+
solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
47+
solve(LinearProblem(A, b), assumptions = (OperatorAssumptions(false; condition = OperatorCondition.WellConditioned))).u;

0 commit comments

Comments
 (0)