Skip to content

Commit 31b84bd

Browse files
committed
add findstructralnz support for BandedMatrix
1 parent c4cee12 commit 31b84bd

File tree

2 files changed

+58
-8
lines changed

2 files changed

+58
-8
lines changed

src/ArrayInterface.jl

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,26 @@ struct BidiagonalIndex <: MatrixIndex
104104
end
105105

106106
struct TridiagonalIndex <: MatrixIndex
107+
count::Int#count==nsize+nsize-1+nsize-1
108+
nsize::Int
109+
isrow::Bool
110+
end
111+
112+
struct BandedMatrixIndex <: MatrixIndex
107113
count::Int
108114
nsize::Int
115+
bandinds::Array{Int}
116+
bandsizes::Array{Int}
109117
isrow::Bool
110118
end
111119

120+
function BandedMatrixIndex(nsize,lowerbandwidth,upperbandwidth,isrow)
121+
upperbandwidth>-lowerbandwidth || throw(ErrorException("Invalid Bandwidths"))
122+
bandinds=upperbandwidth:-1:-lowerbandwidth
123+
bandsizes=[nsize-abs(band) for band in bandinds]
124+
BandedMatrixIndex(sum(bandsizes),nsize,bandinds,bandsizes,isrow)
125+
end
126+
112127
Base.firstindex(ind::MatrixIndex)=1
113128
Base.lastindex(ind::MatrixIndex)=ind.count
114129
Base.length(ind::MatrixIndex)=ind.count
@@ -135,6 +150,23 @@ function Base.getindex(ind::TridiagonalIndex,i::Int)
135150
end
136151
end
137152

153+
function Base.getindex(ind::BandedMatrixIndex,i::Int)
154+
1 <= i <= ind.count || throw(BoundsError(ind, i))
155+
_i=i
156+
p=1
157+
while _i-ind.bandsizes[p]>0
158+
_i-=ind.bandsizes[p]
159+
p+=1
160+
end
161+
bandind=ind.bandinds[p]
162+
startfromone=ind.isrow & (bandind>0)
163+
if startfromone
164+
return _i
165+
else
166+
return _i+abs(bandind)
167+
end
168+
end
169+
138170
function findstructralnz(x::Bidiagonal)
139171
n=size(x,1)
140172
isup= x.uplo=='U' ? true : false
@@ -225,11 +257,20 @@ function __init__()
225257
end
226258

227259
@require BandedMatrices="aae01518-5342-5314-be14-df237901396f" begin
260+
function findstructralnz(x::BandedMatrices.BandedMatrix)
261+
l,u=BandedMatrices.bandwidths(x)
262+
nsize=size(x,1)
263+
rowind=BandedMatrixIndex(nsize,l,u,true)
264+
colind=BandedMatrixIndex(nsize,l,u,false)
265+
(rowind,colind)
266+
end
267+
268+
has_sparsestruct(::Type{<:BandedMatrices.BandedMatrix}) = true
228269
is_structured(::Type{<:BandedMatrices.BandedMatrix}) = true
229270
fast_matrix_colors(::Type{<:BandedMatrices.BandedMatrix}) = true
230271

231272
function matrix_colors(A::BandedMatrices.BandedMatrix)
232-
u,l=bandwidths(A)
273+
l,u=BandedMatrices.bandwidths(A)
233274
width=u+l+1
234275
_cycle(1:width,size(A,2))
235276
end
@@ -243,10 +284,10 @@ function __init__()
243284
fast_matrix_colors(::Type{<:BlockBandedMatrices.BandedBlockBandedMatrix}) = true
244285

245286
function matrix_colors(A::BlockBandedMatrices.BlockBandedMatrix)
246-
l,u=blockbandwidths(A)
287+
l,u=BlockBandedMatrices.blockbandwidths(A)
247288
blockwidth=l+u+1
248-
nblock=nblocks(A,2)
249-
cols=[blocksize(A,(1,i))[2] for i in 1:nblock]
289+
nblock=BlockBandedMatrices.nblocks(A,2)
290+
cols=[BlockBandedMatrices.blocksize(A,(1,i))[2] for i in 1:nblock]
250291
blockcolors=_cycle(1:blockwidth,nblock)
251292
#the reserved number of colors of a block is the maximum length of columns of blocks with the same block color
252293
ncolors=[maximum(cols[i:blockwidth:nblock]) for i in 1:blockwidth]
@@ -257,12 +298,12 @@ function __init__()
257298
end
258299

259300
function matrix_colors(A::BlockBandedMatrices.BandedBlockBandedMatrix)
260-
l,u=blockbandwidths(A)
261-
lambda,mu=subblockbandwidths(A)
301+
l,u=BlockBandedMatrices.blockbandwidths(A)
302+
lambda,mu=BlockBandedMatrices.subblockbandwidths(A)
262303
blockwidth=l+u+1
263304
subblockwidth=lambda+mu+1
264-
nblock=nblocks(A,2)
265-
cols=[blocksize(A,(1,i))[2] for i in 1:nblock]
305+
nblock=BlockBandedMatrices.nblocks(A,2)
306+
cols=[BlockBandedMatrices.blocksize(A,(1,i))[2] for i in 1:nblock]
266307
blockcolors=_cycle(1:blockwidth,nblock)
267308
#the reserved number of colors of a block is the min of subblockwidth and the largest length of columns of blocks with the same block color
268309
ncolors=[min(subblockwidth,maximum(cols[i:blockwidth:nblock])) for i in 1:min(blockwidth,nblock)]

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,12 @@ Sp=sparse([1,2,3],[1,2,3],[1,2,3])
3737
@test has_sparsestruct(Sp)
3838
rowind,colind=findstructralnz(Sp)
3939
@test [Tri[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,2,3]
40+
41+
using BandedMatrices
42+
43+
B=BandedMatrix(Ones(5,5), (-1,2))
44+
B[band(1)].=[1,2,3,4]
45+
B[band(2)].=[5,6,7]
46+
@test has_sparsestruct(B)
47+
rowind,colind=findstructralnz(B)
48+
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,1,2,3,4]

0 commit comments

Comments
 (0)