Skip to content

Commit 3099e72

Browse files
committed
start fixing ambiguities
1 parent f87583e commit 3099e72

File tree

4 files changed

+122
-38
lines changed

4 files changed

+122
-38
lines changed

ext/LinearSolveBandedMatricesExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import LinearSolve: defaultalg,
55
do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice
66

77
# Defaults for BandedMatrices
8-
function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions)
8+
function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions{Bool})
99
if oa.issq
1010
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
1111
elseif LinearSolve.is_underdetermined(A)
@@ -15,7 +15,7 @@ function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions)
1515
end
1616
end
1717

18-
function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions)
18+
function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions{Bool})
1919
return DefaultLinearSolver(DefaultAlgorithmChoice.CholeskyFactorization)
2020
end
2121

ext/LinearSolveFastAlmostBandedMatricesExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using FastAlmostBandedMatrices, LinearAlgebra, LinearSolve
44
import LinearSolve: defaultalg,
55
do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice
66

7-
function defaultalg(A::AlmostBandedMatrix, b, oa::OperatorAssumptions)
7+
function defaultalg(A::AlmostBandedMatrix, b, oa::OperatorAssumptions{Bool})
88
if oa.issq
99
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
1010
else

src/default.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true))
2828

2929
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
30-
assump::OperatorAssumptions)
30+
assump::OperatorAssumptions{Bool})
3131
defaultalg(A.A, b, assump)
3232
end
3333

@@ -36,41 +36,41 @@ function defaultalg(A, b, assump::OperatorAssumptions{Nothing})
3636
defaultalg(A, b, OperatorAssumptions(issq, assump.condition))
3737
end
3838

39-
function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions)
39+
function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool})
4040
if assump.issq
4141
DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization)
4242
else
4343
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
4444
end
4545
end
4646

47-
function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions)
47+
function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{Bool})
4848
DefaultLinearSolver(DefaultAlgorithmChoice.LDLtFactorization)
4949
end
50-
function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions)
50+
function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions{Bool})
5151
DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
5252
end
53-
function defaultalg(A::Factorization, b, ::OperatorAssumptions)
53+
function defaultalg(A::Factorization, b, ::OperatorAssumptions{Bool})
5454
DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
5555
end
56-
function defaultalg(A::Diagonal, b, ::OperatorAssumptions)
56+
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{Bool})
5757
DefaultLinearSolver(DefaultAlgorithmChoice.DiagonalFactorization)
5858
end
5959

60-
function defaultalg(A::Hermitian, b, ::OperatorAssumptions)
60+
function defaultalg(A::Hermitian, b, ::OperatorAssumptions{Bool})
6161
DefaultLinearSolver(DefaultAlgorithmChoice.CholeskyFactorization)
6262
end
6363

64-
function defaultalg(A::Symmetric{<:Number, <:Array}, b, ::OperatorAssumptions)
64+
function defaultalg(A::Symmetric{<:Number, <:Array}, b, ::OperatorAssumptions{Bool})
6565
DefaultLinearSolver(DefaultAlgorithmChoice.BunchKaufmanFactorization)
6666
end
6767

68-
function defaultalg(A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions)
68+
function defaultalg(A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
6969
DefaultLinearSolver(DefaultAlgorithmChoice.CHOLMODFactorization)
7070
end
7171

7272
function defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
73-
assump::OperatorAssumptions) where {Tv, Ti}
73+
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
7474
if assump.issq
7575
DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
7676
else
@@ -80,7 +80,7 @@ end
8080

8181
@static if INCLUDE_SPARSE
8282
function defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
83-
assump::OperatorAssumptions) where {Ti}
83+
assump::OperatorAssumptions{Bool}) where {Ti}
8484
if assump.issq
8585
if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
8686
DefaultLinearSolver(DefaultAlgorithmChoice.KLUFactorization)
@@ -93,7 +93,7 @@ end
9393
end
9494
end
9595

96-
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions)
96+
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions{Bool})
9797
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
9898
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
9999
else
@@ -102,7 +102,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssump
102102
end
103103

104104
# A === nothing case
105-
function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions)
105+
function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions{Bool})
106106
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
107107
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
108108
else
@@ -112,7 +112,7 @@ end
112112

113113
# Ambiguity handling
114114
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
115-
assump::OperatorAssumptions)
115+
assump::OperatorAssumptions{Bool})
116116
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
117117
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
118118
else
@@ -121,7 +121,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.Abstract
121121
end
122122

123123
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
124-
assump::OperatorAssumptions)
124+
assump::OperatorAssumptions{Bool})
125125
if has_ldiv!(A)
126126
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
127127
elseif !assump.issq
@@ -137,7 +137,7 @@ function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
137137
end
138138

139139
# Allows A === nothing as a stand-in for dense matrix
140-
function defaultalg(A, b, assump::OperatorAssumptions)
140+
function defaultalg(A, b, assump::OperatorAssumptions{Bool})
141141
alg = if assump.issq
142142
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
143143
# it makes sense according to the benchmarks, which is dependent on

src/factorization.jl

Lines changed: 103 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -402,15 +402,15 @@ function do_factorization(alg::GenericFactorization, A, b, u)
402402
return fact
403403
end
404404

405-
function init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr,
405+
function init_cacheval(alg::GenericFactorization{typeof(lu)}, A::AbstractMatrix, b, u, Pl, Pr,
406406
maxiters::Int,
407407
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
408-
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
408+
ArrayInterface.lu_instance(A)
409409
end
410-
function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr,
410+
function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::AbstractMatrix, b, u, Pl, Pr,
411411
maxiters::Int,
412412
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
413-
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
413+
ArrayInterface.lu_instance(A)
414414
end
415415

416416
function init_cacheval(alg::GenericFactorization{typeof(lu)},
@@ -445,16 +445,36 @@ function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::Tridiagonal, b
445445
assumptions::OperatorAssumptions)
446446
ArrayInterface.lu_instance(A)
447447
end
448+
function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
449+
maxiters::Int, abstol, reltol, verbose::Bool,
450+
assumptions::OperatorAssumptions) where {T, V}
451+
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
452+
end
453+
function init_cacheval(alg::GenericFactorization{typeof(lu)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
454+
maxiters::Int, abstol, reltol, verbose::Bool,
455+
assumptions::OperatorAssumptions) where {T, V}
456+
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
457+
end
448458

449-
function init_cacheval(alg::GenericFactorization{typeof(qr)}, A, b, u, Pl, Pr,
459+
function init_cacheval(alg::GenericFactorization{typeof(qr)}, A::AbstractMatrix, b, u, Pl, Pr,
450460
maxiters::Int,
451461
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
452-
ArrayInterface.qr_instance(convert(AbstractMatrix, A))
462+
ArrayInterface.qr_instance(A)
453463
end
454-
function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A, b, u, Pl, Pr,
464+
function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A::AbstractMatrix, b, u, Pl, Pr,
455465
maxiters::Int,
456466
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
457-
ArrayInterface.qr_instance(convert(AbstractMatrix, A))
467+
ArrayInterface.qr_instance(A)
468+
end
469+
function init_cacheval(alg::GenericFactorization{typeof(qr)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
470+
maxiters::Int, abstol, reltol, verbose::Bool,
471+
assumptions::OperatorAssumptions) where {T, V}
472+
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
473+
end
474+
function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
475+
maxiters::Int, abstol, reltol, verbose::Bool,
476+
assumptions::OperatorAssumptions) where {T, V}
477+
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
458478
end
459479

460480
function init_cacheval(alg::GenericFactorization{typeof(qr)},
@@ -490,15 +510,15 @@ function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A::Tridiagonal, b
490510
ArrayInterface.qr_instance(A)
491511
end
492512

493-
function init_cacheval(alg::GenericFactorization{typeof(svd)}, A, b, u, Pl, Pr,
513+
function init_cacheval(alg::GenericFactorization{typeof(svd)}, A::AbstractMatrix, b, u, Pl, Pr,
494514
maxiters::Int,
495515
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
496-
ArrayInterface.svd_instance(convert(AbstractMatrix, A))
516+
ArrayInterface.svd_instance(A)
497517
end
498-
function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A, b, u, Pl, Pr,
518+
function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A::AbstractMatrix, b, u, Pl, Pr,
499519
maxiters::Int,
500520
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
501-
ArrayInterface.svd_instance(convert(AbstractMatrix, A))
521+
ArrayInterface.svd_instance(A)
502522
end
503523

504524
function init_cacheval(alg::GenericFactorization{typeof(svd)},
@@ -534,6 +554,16 @@ function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A::Tridiagonal,
534554
assumptions::OperatorAssumptions)
535555
ArrayInterface.svd_instance(A)
536556
end
557+
function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
558+
maxiters::Int, abstol, reltol, verbose::Bool,
559+
assumptions::OperatorAssumptions) where {T, V}
560+
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
561+
end
562+
function init_cacheval(alg::GenericFactorization{typeof(svd)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
563+
maxiters::Int, abstol, reltol, verbose::Bool,
564+
assumptions::OperatorAssumptions) where {T, V}
565+
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
566+
end
537567

538568
function init_cacheval(alg::GenericFactorization, A::Diagonal, b, u, Pl, Pr, maxiters::Int,
539569
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
@@ -549,6 +579,18 @@ function init_cacheval(alg::GenericFactorization, A::SymTridiagonal{T, V}, b, u,
549579
assumptions::OperatorAssumptions) where {T, V}
550580
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
551581
end
582+
function init_cacheval(alg::GenericFactorization, A, b, u, Pl, Pr,
583+
maxiters::Int,
584+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
585+
init_cacheval(alg, convert(AbstractMatrix, A), b, u, Pl, Pr,
586+
maxiters::Int, abstol, reltol, verbose::Bool,
587+
assumptions::OperatorAssumptions)
588+
end
589+
function init_cacheval(alg::GenericFactorization, A::AbstractMatrix, b, u, Pl, Pr,
590+
maxiters::Int,
591+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
592+
do_factorization(alg, A, b, u)
593+
end
552594

553595
function init_cacheval(alg::Union{GenericFactorization{typeof(bunchkaufman!)},
554596
GenericFactorization{typeof(bunchkaufman)}},
@@ -573,15 +615,49 @@ end
573615
# Try to never use it.
574616

575617
# Cholesky needs the posdef matrix, for GenericFactorization assume structure is needed
576-
function init_cacheval(alg::Union{GenericFactorization{typeof(cholesky)},
577-
GenericFactorization{typeof(cholesky!)}}, A, b, u, Pl, Pr,
578-
maxiters::Int, abstol, reltol, verbose::Bool,
618+
function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::AbstractMatrix, b, u, Pl, Pr,
619+
maxiters::Int, abstol, reltol, verbose::Bool,
620+
assumptions::OperatorAssumptions)
621+
newA = copy(convert(AbstractMatrix, A))
622+
do_factorization(alg, newA, b, u)
623+
end
624+
function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::AbstractMatrix, b, u, Pl, Pr,
625+
maxiters::Int, abstol, reltol, verbose::Bool,
579626
assumptions::OperatorAssumptions)
580627
newA = copy(convert(AbstractMatrix, A))
581628
do_factorization(alg, newA, b, u)
582629
end
630+
function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::Diagonal, b, u, Pl, Pr, maxiters::Int,
631+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
632+
Diagonal(inv.(A.diag))
633+
end
634+
function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::Tridiagonal, b, u, Pl, Pr,
635+
maxiters::Int,
636+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
637+
ArrayInterface.lu_instance(A)
638+
end
639+
function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
640+
maxiters::Int, abstol, reltol, verbose::Bool,
641+
assumptions::OperatorAssumptions) where {T, V}
642+
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
643+
end
644+
function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::Diagonal, b, u, Pl, Pr, maxiters::Int,
645+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
646+
Diagonal(inv.(A.diag))
647+
end
648+
function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::Tridiagonal, b, u, Pl, Pr,
649+
maxiters::Int,
650+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
651+
ArrayInterface.lu_instance(A)
652+
end
653+
function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
654+
maxiters::Int, abstol, reltol, verbose::Bool,
655+
assumptions::OperatorAssumptions) where {T, V}
656+
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
657+
end
658+
583659

584-
function init_cacheval(alg::Union{GenericFactorization},
660+
function init_cacheval(alg::GenericFactorization,
585661
A::Union{Hermitian{T, <:SparseMatrixCSC},
586662
Symmetric{T, <:SparseMatrixCSC}}, b, u, Pl, Pr,
587663
maxiters::Int, abstol, reltol, verbose::Bool,
@@ -1063,21 +1139,29 @@ end
10631139
# but QRFactorization uses 16.
10641140
FastQRFactorization() = FastQRFactorization(NoPivot(), 36)
10651141

1066-
function init_cacheval(alg::FastQRFactorization{NoPivot}, A, b, u, Pl, Pr,
1142+
function init_cacheval(alg::FastQRFactorization{NoPivot}, A::AbstractMatrix, b, u, Pl, Pr,
10671143
maxiters::Int, abstol, reltol, verbose::Bool,
10681144
assumptions::OperatorAssumptions)
10691145
ws = QRWYWs(A; blocksize = alg.blocksize)
10701146
return WorkspaceAndFactors(ws,
10711147
ArrayInterface.qr_instance(convert(AbstractMatrix, A)))
10721148
end
1073-
function init_cacheval(::FastQRFactorization{ColumnNorm}, A, b, u, Pl, Pr,
1149+
function init_cacheval(::FastQRFactorization{ColumnNorm}, A::AbstractMatrix, b, u, Pl, Pr,
10741150
maxiters::Int, abstol, reltol, verbose::Bool,
10751151
assumptions::OperatorAssumptions)
10761152
ws = QRpWs(A)
10771153
return WorkspaceAndFactors(ws,
10781154
ArrayInterface.qr_instance(convert(AbstractMatrix, A)))
10791155
end
10801156

1157+
function init_cacheval(alg::FastQRFactorization, A, b, u, Pl, Pr,
1158+
maxiters::Int, abstol, reltol, verbose::Bool,
1159+
assumptions::OperatorAssumptions)
1160+
return init_cacheval(alg, convert(AbstractMatrix, A), b, u, Pl, Pr,
1161+
maxiters::Int, abstol, reltol, verbose::Bool,
1162+
assumptions::OperatorAssumptions)
1163+
end
1164+
10811165
function SciMLBase.solve!(cache::LinearCache, alg::FastQRFactorization{P};
10821166
kwargs...) where {P}
10831167
A = cache.A
@@ -1184,4 +1268,4 @@ for alg in InteractiveUtils.subtypes(AbstractFactorization)
11841268
maxiters::Int, abstol, reltol, verbose::Bool,
11851269
assumptions::OperatorAssumptions)
11861270
end
1187-
end
1271+
end

0 commit comments

Comments
 (0)