Skip to content

Commit 8a5f7b4

Browse files
committed
add findstructralnz for (bi/tri-)diagonal matrices
1 parent 2d00a0c commit 8a5f7b4

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed

src/ArrayInterface.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
module ArrayInterface
22

33
using Requires
4+
using LinearAlgebra
5+
using SparseArrays
6+
7+
export findstructralnz
48

59
function ismutable end
610

@@ -16,6 +20,69 @@ ismutable(x) = ismutable(typeof(x))
1620
ismutable(::Type{<:Array}) = true
1721
ismutable(::Type{<:Number}) = false
1822

23+
"""
24+
findstructralnz(x::AbstractArray)
25+
26+
Return: (I,J) #indexable objects
27+
Find sparsity pattern of special matrices, similar to first two elements of findnz(::SparseCSCMatrix)
28+
"""
29+
function findstructralnz(x::Diagonal)
30+
n=size(x,1)
31+
(1:n,1:n)
32+
end
33+
34+
abstract type MatrixIndex end
35+
36+
struct BidiagonalIndex <: MatrixIndex
37+
count::Int
38+
isup::Bool
39+
end
40+
41+
struct TridiagonalIndex <: MatrixIndex
42+
count::Int
43+
nsize::Int
44+
isrow::Bool
45+
end
46+
47+
Base.firstindex(ind::MatrixIndex)=1
48+
Base.lastindex(ind::MatrixIndex)=ind.count
49+
function Base.getindex(ind::BidiagonalIndex,i::Int)
50+
1 <= i <= ind.count || throw(BoundsError(ind, i))
51+
if ind.isup
52+
ii=i+1
53+
else
54+
ii=i+1+1
55+
end
56+
convert(Int,floor(ii/2))
57+
end
58+
59+
function Base.getindex(ind::TridiagonalIndex,i::Int)
60+
1 <= i <= ind.count || throw(BoundsError(ind, i))
61+
offsetu= ind.isrow ? 0 : 1
62+
offsetl= ind.isrow ? 1 : 0
63+
if 1 <= i <= ind.nsize
64+
return i
65+
elseif ind.nsize < i <= ind.nsize+ind.nsize-1
66+
return i-ind.nsize+offsetu
67+
else
68+
return i-(ind.nsize+ind.nsize-1)+offsetl
69+
end
70+
end
71+
72+
function findstructralnz(x::Bidiagonal)
73+
n=size(x,1)
74+
isup= x.uplo=='U' ? true : false
75+
rowind=BidiagonalIndex(n+n-1,isup)
76+
colind=BidiagonalIndex(n+n-1,!isup)
77+
(rowind,colind)
78+
end
79+
80+
function findstructralnz(x::Union{Tridiagonal,SymTridiagonal})
81+
n=size(x,1)
82+
rowind=TridiagonalIndex(n+n-1+n-1,n,true)
83+
colind=TridiagonalIndex(n+n-1+n-1,n,false)
84+
(rowind,colind)
85+
end
1986

2087
function __init__()
2188

test/runtests.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,24 @@ using ArrayInterface, Test
33
@test ArrayInterface.ismutable(rand(3))
44

55
using StaticArrays
6-
ArrayInterface.ismutable(@SVector [1,2,3]) == false
7-
ArrayInterface.ismutable(@MVector [1,2,3]) == true
6+
@test ArrayInterface.ismutable(@SVector [1,2,3]) == false
7+
@test ArrayInterface.ismutable(@MVector [1,2,3]) == true
8+
9+
using LinearAlgebra
10+
D=Diagonal([1,2,3,4])
11+
rowind,colind=findstructralnz(D)
12+
@test [D[rowind[i],colind[i]] for i in 1:4]==[1,2,3,4]
13+
14+
Bu = Bidiagonal([1,2,3,4], [7,8,9], :U)
15+
rowind,colind=findstructralnz(Bu)
16+
@test [Bu[rowind[i],colind[i]] for i in 1:7]==[1,7,2,8,3,9,4]
17+
Bl = Bidiagonal([1,2,3,4], [7,8,9], :L)
18+
rowind,colind=findstructralnz(Bl)
19+
@test [Bl[rowind[i],colind[i]] for i in 1:7]==[1,7,2,8,3,9,4]
20+
21+
Tri=Tridiagonal([1,2,3],[1,2,3,4],[4,5,6])
22+
rowind,colind=findstructralnz(Tri)
23+
@test [Tri[rowind[i],colind[i]] for i in 1:10]==[1,2,3,4,4,5,6,1,2,3]
24+
STri=SymTridiagonal([1,2,3,4],[5,6,7])
25+
rowind,colind=findstructralnz(STri)
26+
@test [STri[rowind[i],colind[i]] for i in 1:10]==[1,2,3,4,5,6,7,5,6,7]

0 commit comments

Comments
 (0)