Skip to content

Commit ae71cf2

Browse files
Merge pull request #422 from avik-pal/ap/banded
Add dispatches for Transpose and Adjoint for Banded Matrices
2 parents e63b793 + 9e5a85a commit ae71cf2

File tree

3 files changed

+60
-19
lines changed

3 files changed

+60
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "7.4.11"
3+
version = "7.5.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/ArrayInterfaceBandedMatricesExt.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
11
module ArrayInterfaceBandedMatricesExt
22

3-
43
if isdefined(Base, :get_extension)
54
using ArrayInterface
65
using ArrayInterface: BandedMatrixIndex
76
using BandedMatrices
7+
using LinearAlgebra
88
else
99
using ..ArrayInterface
1010
using ..ArrayInterface: BandedMatrixIndex
1111
using ..BandedMatrices
12+
using ..LinearAlgebra
1213
end
1314

15+
const TransOrAdjBandedMatrix = Union{
16+
Adjoint{T, <:BandedMatrix{T}},
17+
Transpose{T, <:BandedMatrix{T}},
18+
} where {T}
19+
20+
const AllBandedMatrix = Union{
21+
BandedMatrix{T},
22+
TransOrAdjBandedMatrix{T},
23+
} where {T}
1424

1525
Base.firstindex(i::BandedMatrixIndex) = 1
1626
Base.lastindex(i::BandedMatrixIndex) = i.count
@@ -45,24 +55,24 @@ end
4555

4656
function BandedMatrixIndex(rowsize, colsize, lowerbandwidth, upperbandwidth, isrow)
4757
upperbandwidth > -lowerbandwidth || throw(ErrorException("Invalid Bandwidths"))
48-
bandinds = upperbandwidth:-1:-lowerbandwidth
58+
bandinds = upperbandwidth:-1:(-lowerbandwidth)
4959
bandsizes = [_bandsize(band, rowsize, colsize) for band in bandinds]
5060
BandedMatrixIndex(sum(bandsizes), rowsize, colsize, bandinds, bandsizes, isrow)
5161
end
5262

53-
function ArrayInterface.findstructralnz(x::BandedMatrices.BandedMatrix)
63+
function ArrayInterface.findstructralnz(x::AllBandedMatrix)
5464
l, u = BandedMatrices.bandwidths(x)
5565
rowsize, colsize = Base.size(x)
5666
rowind = BandedMatrixIndex(rowsize, colsize, l, u, true)
5767
colind = BandedMatrixIndex(rowsize, colsize, l, u, false)
5868
return (rowind, colind)
5969
end
6070

61-
ArrayInterface.has_sparsestruct(::Type{<:BandedMatrices.BandedMatrix}) = true
62-
ArrayInterface.isstructured(::Type{<:BandedMatrices.BandedMatrix}) = true
63-
ArrayInterface.fast_matrix_colors(::Type{<:BandedMatrices.BandedMatrix}) = true
71+
ArrayInterface.has_sparsestruct(::Type{<:AllBandedMatrix}) = true
72+
ArrayInterface.isstructured(::Type{<:AllBandedMatrix}) = true
73+
ArrayInterface.fast_matrix_colors(::Type{<:AllBandedMatrix}) = true
6474

65-
function ArrayInterface.matrix_colors(A::BandedMatrices.BandedMatrix)
75+
function ArrayInterface.matrix_colors(A::AllBandedMatrix)
6676
l, u = BandedMatrices.bandwidths(A)
6777
width = u + l + 1
6878
return ArrayInterface._cycle(1:width, Base.size(A, 2))

test/bandedmatrices.jl

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,50 @@
1-
21
using ArrayInterface
32
using BandedMatrices
43
using Test
54

6-
B=BandedMatrix(Ones(5,5), (-1,2))
7-
B[band(1)].=[1,2,3,4]
8-
B[band(2)].=[5,6,7]
5+
function checkequal(idx1::ArrayInterface.BandedMatrixIndex,
6+
idx2::ArrayInterface.BandedMatrixIndex)
7+
return idx1.rowsize == idx2.rowsize && idx1.colsize == idx2.colsize &&
8+
idx1.bandinds == idx2.bandinds && idx1.bandsizes == idx2.bandsizes &&
9+
idx1.isrow == idx2.isrow && idx1.count == idx2.count
10+
end
11+
12+
B = BandedMatrix(Ones(5, 5), (-1, 2))
13+
B[band(1)] .= [1, 2, 3, 4]
14+
B[band(2)] .= [5, 6, 7]
915
@test ArrayInterface.has_sparsestruct(B)
10-
rowind,colind=ArrayInterface.findstructralnz(B)
11-
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,1,2,3,4]
12-
B=BandedMatrix(Ones(4,6), (-1,2))
13-
B[band(1)].=[1,2,3,4]
14-
B[band(2)].=[5,6,7,8]
15-
rowind,colind=ArrayInterface.findstructralnz(B)
16-
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,8,1,2,3,4]
16+
rowind, colind = ArrayInterface.findstructralnz(B)
17+
@test [B[rowind[i], colind[i]] for i in 1:length(rowind)] == [5, 6, 7, 1, 2, 3, 4]
18+
B = BandedMatrix(Ones(4, 6), (-1, 2))
19+
B[band(1)] .= [1, 2, 3, 4]
20+
B[band(2)] .= [5, 6, 7, 8]
21+
rowind, colind = ArrayInterface.findstructralnz(B)
22+
@test [B[rowind[i], colind[i]] for i in 1:length(rowind)] == [5, 6, 7, 8, 1, 2, 3, 4]
1723
@test ArrayInterface.isstructured(typeof(B))
1824
@test ArrayInterface.fast_matrix_colors(typeof(B))
1925

26+
for op in (adjoint, transpose)
27+
B = BandedMatrix(Ones(5, 5), (-1, 2))
28+
B[band(1)] .= [1, 2, 3, 4]
29+
B[band(2)] .= [5, 6, 7]
30+
B′ = op(B)
31+
@test ArrayInterface.has_sparsestruct(B′)
32+
rowind′, colind′ = ArrayInterface.findstructralnz(B′)
33+
rowind′′, colind′′ = ArrayInterface.findstructralnz(BandedMatrix(B′))
34+
@test checkequal(rowind′, rowind′′)
35+
@test checkequal(colind′, colind′′)
36+
37+
B = BandedMatrix(Ones(4, 6), (-1, 2))
38+
B[band(1)] .= [1, 2, 3, 4]
39+
B[band(2)] .= [5, 6, 7, 8]
40+
B′ = op(B)
41+
rowind′, colind′ = ArrayInterface.findstructralnz(B′)
42+
rowind′′, colind′′ = ArrayInterface.findstructralnz(BandedMatrix(B′))
43+
@test checkequal(rowind′, rowind′′)
44+
@test checkequal(colind′, colind′′)
45+
46+
@test ArrayInterface.isstructured(typeof(B′))
47+
@test ArrayInterface.fast_matrix_colors(typeof(B′))
48+
49+
@test ArrayInterface.matrix_colors(B′) == ArrayInterface.matrix_colors(BandedMatrix(B′))
50+
end

0 commit comments

Comments
 (0)