Skip to content

Commit a7981c8

Browse files
committed
add findstructralnz support for BlockBandedMatrix
1 parent 64fab10 commit a7981c8

File tree

2 files changed

+123
-14
lines changed

2 files changed

+123
-14
lines changed

src/ArrayInterface.jl

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,85 @@ end
111111

112112
struct BandedMatrixIndex <: MatrixIndex
113113
count::Int
114-
nsize::Int
115-
bandinds::Array{Int}
116-
bandsizes::Array{Int}
114+
rowsize::Int
115+
colsize::Int
116+
bandinds::Array{Int,1}
117+
bandsizes::Array{Int,1}
117118
isrow::Bool
118119
end
119120

120-
function BandedMatrixIndex(nsize,lowerbandwidth,upperbandwidth,isrow)
121+
function _bandsize(bandind,rowsize,colsize)
122+
-(rowsize-1) <= bandind <= colsize-1 || throw(ErrorException("Invalid Bandind"))
123+
if (bandind*(colsize-rowsize)>0) & (abs(bandind)<=abs(colsize-rowsize))
124+
return min(rowsize,colsize)
125+
elseif bandind*(colsize-rowsize)<=0
126+
return min(rowsize,colsize)-abs(bandind)
127+
else
128+
return min(rowsize,colsize)-abs(bandind)+abs(colsize-rowsize)
129+
end
130+
end
131+
132+
function BandedMatrixIndex(rowsize,colsize,lowerbandwidth,upperbandwidth,isrow)
121133
upperbandwidth>-lowerbandwidth || throw(ErrorException("Invalid Bandwidths"))
122134
bandinds=upperbandwidth:-1:-lowerbandwidth
123-
bandsizes=[nsize-abs(band) for band in bandinds]
124-
BandedMatrixIndex(sum(bandsizes),nsize,bandinds,bandsizes,isrow)
135+
bandsizes=[_bandsize(band,rowsize,colsize) for band in bandinds]
136+
BandedMatrixIndex(sum(bandsizes),rowsize,colsize,bandinds,bandsizes,isrow)
137+
end
138+
139+
struct BlockBandedMatrixIndex <: MatrixIndex
140+
count::Int
141+
refinds::Array{Int,1}
142+
refcoords::Array{Int,1}#storing col or row inds at ref points
143+
isrow::Bool
144+
end
145+
146+
function BlockBandedMatrixIndex(nrowblock,ncolblock,rowsizes,colsizes,l,u)
147+
blockrowind=BandedMatrixIndex(nrowblock,ncolblock,l,u,true)
148+
blockcolind=BandedMatrixIndex(nrowblock,ncolblock,l,u,false)
149+
sortedinds=sort([(blockrowind[i],blockcolind[i]) for i in 1:length(blockrowind)],by=x->x[2])
150+
sort!(sortedinds,by=x->x[1],alg=InsertionSort)#stable sort keeps the second index in order
151+
refinds=Array{Int,1}()
152+
refrowcoords=Array{Int,1}()
153+
refcolcoords=Array{Int,1}()
154+
rowheights=cumsum(pushfirst!(copy(rowsizes),1))
155+
blockheight=0
156+
blockrow=1
157+
blockcol=1
158+
currenti=1
159+
lastrowind=sortedinds[1][1]-1
160+
lastcolind=sortedinds[1][2]
161+
for ind in sortedinds
162+
rowind,colind=ind
163+
if colind==lastcolind
164+
if rowind>lastrowind
165+
blockheight+=rowsizes[rowind]
166+
end
167+
else
168+
for j in blockcol:blockcol+colsizes[lastcolind]-1
169+
push!(refinds,currenti)
170+
push!(refrowcoords,blockrow)
171+
push!(refcolcoords,j)
172+
currenti+=blockheight
173+
end
174+
blockcol+=colsizes[lastcolind]
175+
blockrow=rowheights[rowind]
176+
blockheight=rowsizes[rowind]
177+
end
178+
lastcolind=colind
179+
lastrowind=rowind
180+
end
181+
for j in blockcol:blockcol+colsizes[lastcolind]-1
182+
push!(refinds,currenti)
183+
push!(refrowcoords,blockrow)
184+
push!(refcolcoords,j)
185+
currenti+=blockheight
186+
end
187+
push!(refinds,currenti)#guard
188+
push!(refrowcoords,-1)
189+
push!(refcolcoords,-1)
190+
rowindobj=BlockBandedMatrixIndex(currenti-1,refinds,refrowcoords,true)
191+
colindobj=BlockBandedMatrixIndex(currenti-1,refinds,refcolcoords,false)
192+
rowindobj,colindobj
125193
end
126194

127195
Base.firstindex(ind::MatrixIndex)=1
@@ -159,14 +227,29 @@ function Base.getindex(ind::BandedMatrixIndex,i::Int)
159227
p+=1
160228
end
161229
bandind=ind.bandinds[p]
162-
startfromone=ind.isrow & (bandind>0)
230+
startfromone=!xor(ind.isrow,(bandind>0))
163231
if startfromone
164232
return _i
165233
else
166234
return _i+abs(bandind)
167235
end
168236
end
169237

238+
function Base.getindex(ind::BlockBandedMatrixIndex,i::Int)
239+
1 <= i <= ind.count || throw(BoundsError(ind, i))
240+
p=1
241+
while i-ind.refinds[p]>=0
242+
p+=1
243+
end
244+
p-=1
245+
_i=i-ind.refinds[p]
246+
if ind.isrow
247+
return ind.refcoords[p]+_i
248+
else
249+
return ind.refcoords[p]
250+
end
251+
end
252+
170253
function findstructralnz(x::Bidiagonal)
171254
n=size(x,1)
172255
isup= x.uplo=='U' ? true : false
@@ -259,9 +342,9 @@ function __init__()
259342
@require BandedMatrices="aae01518-5342-5314-be14-df237901396f" begin
260343
function findstructralnz(x::BandedMatrices.BandedMatrix)
261344
l,u=BandedMatrices.bandwidths(x)
262-
nsize=size(x,1)
263-
rowind=BandedMatrixIndex(nsize,l,u,true)
264-
colind=BandedMatrixIndex(nsize,l,u,false)
345+
rowsize,colsize=size(x)
346+
rowind=BandedMatrixIndex(rowsize,colsize,l,u,true)
347+
colind=BandedMatrixIndex(rowsize,colsize,l,u,false)
265348
(rowind,colind)
266349
end
267350

@@ -277,9 +360,19 @@ function __init__()
277360

278361
end
279362

280-
@require BlockBandedMatrices="aae01518-5342-5314-be14-df237901396f" begin
281-
is_structured(::Type{<:BandedMatrices.BlockBandedMatrix}) = true
282-
is_structured(::Type{<:BandedMatrices.BandedBlockBandedMatrix}) = true
363+
@require BlockBandedMatrices="ffab5731-97b5-5995-9138-79e8c1846df0" begin
364+
function findstructralnz(x::BlockBandedMatrices.BlockBandedMatrix)
365+
l,u=BlockBandedMatrices.blockbandwidths(x)
366+
nrowblock=BlockBandedMatrices.nblocks(x,1)
367+
ncolblock=BlockBandedMatrices.nblocks(x,2)
368+
rowsizes=[BlockBandedMatrices.blocksize(x,(i,1))[1] for i in 1:nrowblock]
369+
colsizes=[BlockBandedMatrices.blocksize(x,(1,i))[2] for i in 1:ncolblock]
370+
BlockBandedMatrixIndex(nrowblock,ncolblock,rowsizes,colsizes,l,u)
371+
end
372+
373+
has_sparsestruct(::Type{<:BlockBandedMatrices.BlockBandedMatrix}) = true
374+
is_structured(::Type{<:BlockBandedMatrices.BlockBandedMatrix}) = true
375+
is_structured(::Type{<:BlockBandedMatrices.BandedBlockBandedMatrix}) = true
283376
fast_matrix_colors(::Type{<:BlockBandedMatrices.BlockBandedMatrix}) = true
284377
fast_matrix_colors(::Type{<:BlockBandedMatrices.BandedBlockBandedMatrix}) = true
285378

test/runtests.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,20 @@ B[band(1)].=[1,2,3,4]
4545
B[band(2)].=[5,6,7]
4646
@test has_sparsestruct(B)
4747
rowind,colind=findstructralnz(B)
48-
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,1,2,3,4]
48+
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,1,2,3,4]
49+
B=BandedMatrix(Ones(4,6), (-1,2))
50+
B[band(1)].=[1,2,3,4]
51+
B[band(2)].=[5,6,7,8]
52+
rowind,colind=findstructralnz(B)
53+
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,8,1,2,3,4]
54+
55+
using BlockBandedMatrices
56+
BB=BlockBandedMatrix(Ones(10,10),([1,2,3,4],[4,3,2,1]),(1,0))
57+
BB[Block(1,1)].=[1 2 3 4]
58+
BB[Block(2,1)].=[5 6 7 8;9 10 11 12]
59+
rowind,colind=findstructralnz(BB)
60+
@test [BB[rowind[i],colind[i]] for i in 1:length(rowind)]==
61+
[1,5,9,2,6,10,3,7,11,4,8,12,
62+
1,1,1,1,1,1,1,1,1,1,1,1,1,1,
63+
1,1,1,1,1,1,1,1,1,1,1,1,1,1,
64+
1,1,1,1,1]

0 commit comments

Comments
 (0)