Skip to content

Commit f4f6940

Browse files
Merge pull request #393 from avik-pal/ap/banded
Proper handling for BandedMatrices
2 parents 0843b53 + 7d766d8 commit f4f6940

File tree

6 files changed

+84
-5
lines changed

6 files changed

+84
-5
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ jobs:
1919
- {user: SciML, repo: OrdinaryDiffEq.jl, group: InterfaceII}
2020
- {user: SciML, repo: ModelingToolkit.jl, group: All}
2121
- {user: SciML, repo: SciMLSensitivity.jl, group: Core1}
22+
- {user: SciML, repo: BoundaryValueDiffEq.jl, group: All}
2223

2324
steps:
2425
- uses: actions/checkout@v4

Project.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.10.0"
4+
version = "2.11.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -31,17 +31,20 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
3131
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3232

3333
[weakdeps]
34+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3435
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
35-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3636
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
37+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3738
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
3839
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
3940
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
4041
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
4142
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4243
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
44+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4345

4446
[extensions]
47+
LinearSolveBandedMatricesExt = "BandedMatrices"
4548
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
4649
LinearSolveCUDAExt = "CUDA"
4750
LinearSolveEnzymeExt = "Enzyme"
@@ -51,9 +54,11 @@ LinearSolveKernelAbstractionsExt = "KernelAbstractions"
5154
LinearSolveKrylovKitExt = "KrylovKit"
5255
LinearSolveMetalExt = "Metal"
5356
LinearSolvePardisoExt = "Pardiso"
57+
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
5458

5559
[compat]
5660
ArrayInterface = "7.4.11"
61+
BandedMatrices = "1"
5762
BlockDiagonals = "0.1"
5863
ConcreteStructs = "0.2"
5964
DocStringExtensions = "0.8, 0.9"
@@ -69,6 +74,7 @@ Krylov = "0.9"
6974
KrylovKit = "0.5, 0.6"
7075
PrecompileTools = "1"
7176
Preferences = "1"
77+
RecursiveArrayTools = "2"
7278
RecursiveFactorization = "0.2.8"
7379
Reexport = "1"
7480
Requires = "1"
@@ -80,6 +86,7 @@ UnPack = "1"
8086
julia = "1.6"
8187

8288
[extras]
89+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
8390
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
8491
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
8592
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
@@ -95,6 +102,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
95102
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
96103
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
97104
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
105+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
98106
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
99107
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
100108

ext/LinearSolveBandedMatricesExt.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
module LinearSolveBandedMatricesExt
2+
3+
using BandedMatrices, LinearAlgebra, LinearSolve
4+
import LinearSolve: defaultalg,
5+
do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice
6+
7+
# Defaults for BandedMatrices
8+
function defaultalg(A::BandedMatrix, b, ::OperatorAssumptions)
9+
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
10+
end
11+
12+
function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions)
13+
return DefaultLinearSolver(DefaultAlgorithmChoice.CholeskyFactorization)
14+
end
15+
16+
# BandedMatrices `qr` doesn't allow other args without causing an ambiguity
17+
do_factorization(alg::QRFactorization, A::BandedMatrix, b, u) = alg.inplace ? qr!(A) : qr(A)
18+
19+
function do_factorization(alg::LUFactorization, A::BandedMatrix, b, u)
20+
_pivot = alg.pivot isa NoPivot ? Val(false) : Val(true)
21+
return lu!(A, _pivot; check = false)
22+
end
23+
24+
# For BandedMatrix
25+
for alg in (:SVDFactorization, :MKLLUFactorization, :DiagonalFactorization,
26+
:SparspakFactorization, :KLUFactorization, :UMFPACKFactorization,
27+
:GenericLUFactorization, :RFLUFactorization, :BunchKaufmanFactorization,
28+
:CHOLMODFactorization, :NormalCholeskyFactorization, :LDLtFactorization,
29+
:AppleAccelerateLUFactorization, :CholeskyFactorization)
30+
@eval begin
31+
function init_cacheval(::$(alg), ::BandedMatrix, b, u, Pl, Pr, maxiters::Int,
32+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
33+
return nothing
34+
end
35+
end
36+
end
37+
38+
function init_cacheval(::LUFactorization, A::BandedMatrix, b, u, Pl, Pr, maxiters::Int,
39+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
40+
return lu(similar(A, 0, 0))
41+
end
42+
43+
# For Symmetric BandedMatrix
44+
for alg in (:SVDFactorization, :MKLLUFactorization, :DiagonalFactorization,
45+
:SparspakFactorization, :KLUFactorization, :UMFPACKFactorization,
46+
:GenericLUFactorization, :RFLUFactorization, :BunchKaufmanFactorization,
47+
:CHOLMODFactorization, :NormalCholeskyFactorization,
48+
:AppleAccelerateLUFactorization, :QRFactorization, :LUFactorization)
49+
@eval begin
50+
function init_cacheval(::$(alg), ::Symmetric{<:Number, <:BandedMatrix}, b, u, Pl,
51+
Pr, maxiters::Int, abstol, reltol, verbose::Bool,
52+
assumptions::OperatorAssumptions)
53+
return nothing
54+
end
55+
end
56+
end
57+
58+
end
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module LinearSolveRecursiveArrayToolsExt
2+
3+
using LinearSolve, RecursiveArrayTools
4+
import LinearSolve: init_cacheval
5+
6+
# Krylov.jl tries to init with `ArrayPartition(undef, ...)`. Avoid hitting that!
7+
function init_cacheval(alg::LinearSolve.KrylovJL, A, b::ArrayPartition, u, Pl, Pr,
8+
maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
9+
return nothing
10+
end
11+
12+
end

src/common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ However, in practice this computation is very expensive and thus not possible fo
1414
Therefore, OperatorCondition lets one share to LinearSolve the expected conditioning. The higher the
1515
expected condition number, the safer the algorithm needs to be and thus there is a trade-off between
1616
numerical performance and stability. By default the method assumes the operator may be ill-conditioned
17-
for the standard linear solvers to converge (such as LU-factorization), though more extreme
17+
for the standard linear solvers to converge (such as LU-factorization), though more extreme
1818
ill-conditioning or well-conditioning could be the case and specified through this assumption.
1919
"""
2020
EnumX.@enumx OperatorCondition begin

src/default.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ end
9393
end
9494

9595
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions)
96-
if assump.condition === OperatorConodition.IllConditioned || !assump.issq
96+
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
9797
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
9898
else
9999
@static if VERSION >= v"1.8-"
@@ -163,7 +163,7 @@ function defaultalg(A, b, assump::OperatorAssumptions)
163163
DefaultAlgorithmChoice.GenericLUFactorization
164164
elseif VERSION >= v"1.8" && appleaccelerate_isavailable()
165165
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
166-
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
166+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
167167
(usemkl && length(b) <= 200)) &&
168168
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
169169
eltype(A) <: Union{Float32, Float64})

0 commit comments

Comments
 (0)