Skip to content

Commit 4fd3aad

Browse files
authored
Generalize istriu/istril to accept a band index (#590)
Currently, only `istriu(S)` and `istril(S)` are specialized for sparse matrices, and this PR generalizes these to accept the band index `k`. This improves performance.
1 parent 780c4de commit 4fd3aad

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

src/sparsematrix.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4141,7 +4141,7 @@ function is_hermsym(A::AbstractSparseMatrixCSC, check::Function)
41414141
return true
41424142
end
41434143

4144-
function istriu(A::AbstractSparseMatrixCSC)
4144+
function istriu(A::AbstractSparseMatrixCSC, k::Integer=0)
41454145
m, n = size(A)
41464146
colptr = getcolptr(A)
41474147
rowval = rowvals(A)
@@ -4150,7 +4150,8 @@ function istriu(A::AbstractSparseMatrixCSC)
41504150
for col = 1:min(n, m-1)
41514151
l1 = colptr[col+1]-1
41524152
for i = 0 : (l1 - colptr[col])
4153-
if rowval[l1-i] <= col
4153+
if rowval[l1-i] <= col - k
4154+
# rows preceeding the index would also lie above the band
41544155
break
41554156
end
41564157
if _isnotzero(nzval[l1-i])
@@ -4161,15 +4162,16 @@ function istriu(A::AbstractSparseMatrixCSC)
41614162
return true
41624163
end
41634164

4164-
function istril(A::AbstractSparseMatrixCSC)
4165+
function istril(A::AbstractSparseMatrixCSC, k::Integer=0)
41654166
m, n = size(A)
41664167
colptr = getcolptr(A)
41674168
rowval = rowvals(A)
41684169
nzval = nonzeros(A)
41694170

41704171
for col = 2:n
41714172
for i = colptr[col] : (colptr[col+1]-1)
4172-
if rowval[i] >= col
4173+
if rowval[i] >= col - k
4174+
# subsequent rows would also lie below the band
41734175
break
41744176
end
41754177
if _isnotzero(nzval[i])

test/sparsematrix_ops.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,4 +626,17 @@ end
626626
@test_throws ArgumentError copytrito!(M, S, 'M')
627627
end
628628

629+
@testset "istriu/istril" begin
630+
for T in Any[Tridiagonal(1:3, 1:4, 1:3),
631+
Bidiagonal(1:4, 1:3, :U), Bidiagonal(1:4, 1:3, :L),
632+
Diagonal(1:4),
633+
diagm(-2=>1:2, 2=>1:2)]
634+
S = sparse(T)
635+
for k in -5:5
636+
@test istriu(S, k) == istriu(T, k)
637+
@test istril(S, k) == istril(T, k)
638+
end
639+
end
640+
end
641+
629642
end # module

0 commit comments

Comments
 (0)