Skip to content

Commit 29079ab

Browse files
Merge pull request #9 from huanglangwen/master
add findstructralnz for (bi/tri-)diagonal matrices
2 parents 2d00a0c + ff564e8 commit 29079ab

File tree

2 files changed

+119
-2
lines changed

2 files changed

+119
-2
lines changed

src/ArrayInterface.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
module ArrayInterface
22

33
using Requires
4+
using LinearAlgebra
5+
using SparseArrays
6+
7+
export findstructralnz,has_sparsestruct
48

59
function ismutable end
610

@@ -16,6 +20,88 @@ ismutable(x) = ismutable(typeof(x))
1620
ismutable(::Type{<:Array}) = true
1721
ismutable(::Type{<:Number}) = false
1822

23+
"""
24+
has_sparsestruct(x::AbstractArray)
25+
26+
determine whether `findstructralnz` accepts the parameter `x`
27+
"""
28+
has_sparsestruct(x)=false
29+
has_sparsestruct(x::AbstractArray)=false
30+
has_sparsestruct(x::SparseMatrixCSC)=true
31+
has_sparsestruct(x::Diagonal)=true
32+
has_sparsestruct(x::Bidiagonal)=true
33+
has_sparsestruct(x::Tridiagonal)=true
34+
has_sparsestruct(x::SymTridiagonal)=true
35+
36+
"""
37+
findstructralnz(x::AbstractArray)
38+
39+
Return: (I,J) #indexable objects
40+
Find sparsity pattern of special matrices, similar to first two elements of findnz(::SparseMatrixCSC)
41+
"""
42+
function findstructralnz(x::Diagonal)
43+
n=size(x,1)
44+
(1:n,1:n)
45+
end
46+
47+
abstract type MatrixIndex end
48+
49+
struct BidiagonalIndex <: MatrixIndex
50+
count::Int
51+
isup::Bool
52+
end
53+
54+
struct TridiagonalIndex <: MatrixIndex
55+
count::Int
56+
nsize::Int
57+
isrow::Bool
58+
end
59+
60+
Base.firstindex(ind::MatrixIndex)=1
61+
Base.lastindex(ind::MatrixIndex)=ind.count
62+
Base.length(ind::MatrixIndex)=ind.count
63+
function Base.getindex(ind::BidiagonalIndex,i::Int)
64+
1 <= i <= ind.count || throw(BoundsError(ind, i))
65+
if ind.isup
66+
ii=i+1
67+
else
68+
ii=i+1+1
69+
end
70+
convert(Int,floor(ii/2))
71+
end
72+
73+
function Base.getindex(ind::TridiagonalIndex,i::Int)
74+
1 <= i <= ind.count || throw(BoundsError(ind, i))
75+
offsetu= ind.isrow ? 0 : 1
76+
offsetl= ind.isrow ? 1 : 0
77+
if 1 <= i <= ind.nsize
78+
return i
79+
elseif ind.nsize < i <= ind.nsize+ind.nsize-1
80+
return i-ind.nsize+offsetu
81+
else
82+
return i-(ind.nsize+ind.nsize-1)+offsetl
83+
end
84+
end
85+
86+
function findstructralnz(x::Bidiagonal)
87+
n=size(x,1)
88+
isup= x.uplo=='U' ? true : false
89+
rowind=BidiagonalIndex(n+n-1,isup)
90+
colind=BidiagonalIndex(n+n-1,!isup)
91+
(rowind,colind)
92+
end
93+
94+
function findstructralnz(x::Union{Tridiagonal,SymTridiagonal})
95+
n=size(x,1)
96+
rowind=TridiagonalIndex(n+n-1+n-1,n,true)
97+
colind=TridiagonalIndex(n+n-1+n-1,n,false)
98+
(rowind,colind)
99+
end
100+
101+
function findstructralnz(x::SparseMatrixCSC)
102+
rowind,colind,_=findnz(x)
103+
(rowind,colind)
104+
end
19105

20106
function __init__()
21107

test/runtests.jl

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,36 @@ using ArrayInterface, Test
33
@test ArrayInterface.ismutable(rand(3))
44

55
using StaticArrays
6-
ArrayInterface.ismutable(@SVector [1,2,3]) == false
7-
ArrayInterface.ismutable(@MVector [1,2,3]) == true
6+
@test ArrayInterface.ismutable(@SVector [1,2,3]) == false
7+
@test ArrayInterface.ismutable(@MVector [1,2,3]) == true
8+
9+
using LinearAlgebra, SparseArrays
10+
D=Diagonal([1,2,3,4])
11+
@test has_sparsestruct(D)
12+
rowind,colind=findstructralnz(D)
13+
@test [D[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,2,3,4]
14+
@test length(rowind)==4
15+
@test length(rowind)==length(colind)
16+
17+
Bu = Bidiagonal([1,2,3,4], [7,8,9], :U)
18+
@test has_sparsestruct(Bu)
19+
rowind,colind=findstructralnz(Bu)
20+
@test [Bu[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,7,2,8,3,9,4]
21+
Bl = Bidiagonal([1,2,3,4], [7,8,9], :L)
22+
@test has_sparsestruct(Bl)
23+
rowind,colind=findstructralnz(Bl)
24+
@test [Bl[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,7,2,8,3,9,4]
25+
26+
Tri=Tridiagonal([1,2,3],[1,2,3,4],[4,5,6])
27+
@test has_sparsestruct(Tri)
28+
rowind,colind=findstructralnz(Tri)
29+
@test [Tri[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,2,3,4,4,5,6,1,2,3]
30+
STri=SymTridiagonal([1,2,3,4],[5,6,7])
31+
@test has_sparsestruct(STri)
32+
rowind,colind=findstructralnz(STri)
33+
@test [STri[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,2,3,4,5,6,7,5,6,7]
34+
35+
Sp=sparse([1,2,3],[1,2,3],[1,2,3])
36+
@test has_sparsestruct(Sp)
37+
rowind,colind=findstructralnz(Sp)
38+
@test [Tri[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,2,3]

0 commit comments

Comments
 (0)