Skip to content

Commit faea35f

Browse files
authored
Conditionally skip loops in two-term checkzerobands (#345)
* conditionally eliminate loops in two-term checkzerobands * tests for checkzerobands * Add tests * bump version to v0.17.20
1 parent c367853 commit faea35f

File tree

3 files changed

+56
-15
lines changed

3 files changed

+56
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BandedMatrices"
22
uuid = "aae01518-5342-5314-be14-df237901396f"
3-
version = "0.17.19"
3+
version = "0.17.20"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/generic/broadcast.jl

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,26 @@ function checkzerobands(dest, f, A::AbstractMatrix)
7878
l, u = bandwidths(A)
7979

8080
if !(d_l >= l && d_u >= u)
81-
for j = rowsupport(A)
82-
for k = max(1,j-u) : min(j-d_u-1,m)
83-
iszero(f(A[k,j])) || throw(BandError(dest,j-k))
81+
if d_l >= l
82+
for j = rowsupport(A)
83+
for k = max(1,j-u) : min(j-d_u-1,m)
84+
iszero(f(A[k,j])) || throw(BandError(dest,j-k))
85+
end
8486
end
85-
for k = max(1,j+d_l+1) : min(j+l,m)
86-
iszero(f(A[k,j])) || throw(BandError(dest,j-k))
87+
elseif d_u >= u
88+
for j = rowsupport(A)
89+
for k = max(1,j+d_l+1) : min(j+l,m)
90+
iszero(f(A[k,j])) || throw(BandError(dest,j-k))
91+
end
92+
end
93+
else
94+
for j = rowsupport(A)
95+
for k = max(1,j-u) : min(j-d_u-1,m)
96+
iszero(f(A[k,j])) || throw(BandError(dest,j-k))
97+
end
98+
for k = max(1,j+d_l+1) : min(j+l,m)
99+
iszero(f(A[k,j])) || throw(BandError(dest,j-k))
100+
end
87101
end
88102
end
89103
end
@@ -693,12 +707,29 @@ function checkzerobands(dest, f, (A,B)::NTuple{2,AbstractMatrix})
693707

694708
rspA = rowsupport(A)
695709
rspB = rowsupport(B)
696-
for j = min(first(rspA), first(rspB)):max(last(rspA), last(rspB))
697-
for k = max(1,j-u) : min(j-d_u-1,m)
698-
iszero(f(A[k,j], B[k,j])) || throw(BandError(dest,b))
699-
end
700-
for k = max(1,j+d_l+1) : min(j+l,m)
701-
iszero(f(A[k,j], B[k,j])) || throw(BandError(dest,b))
710+
# if the dest has more bands both above and below the diagonal, no checks are necessary
711+
if !(u <= d_u && l <= d_l)
712+
if u <= d_u # the dest has more bands above the diagonal, so we only check below
713+
for j = min(first(rspA), first(rspB)):max(last(rspA), last(rspB))
714+
for k = max(1,j+d_l+1) : min(j+l,m)
715+
iszero(f(A[k,j], B[k,j])) || throw(BandError(dest,j-k))
716+
end
717+
end
718+
elseif l <= d_l # the dest has more bands below the diagonal, so we only check above
719+
for j = min(first(rspA), first(rspB)):max(last(rspA), last(rspB))
720+
for k = max(1,j-u) : min(j-d_u-1,m)
721+
iszero(f(A[k,j], B[k,j])) || throw(BandError(dest,j-k))
722+
end
723+
end
724+
else # check both above and below the diagonal
725+
for j = min(first(rspA), first(rspB)):max(last(rspA), last(rspB))
726+
for k = max(1,j-u) : min(j-d_u-1,m)
727+
iszero(f(A[k,j], B[k,j])) || throw(BandError(dest,j-k))
728+
end
729+
for k = max(1,j+d_l+1) : min(j+l,m)
730+
iszero(f(A[k,j], B[k,j])) || throw(BandError(dest,j-k))
731+
end
732+
end
702733
end
703734
end
704735
end
@@ -727,9 +758,8 @@ function _banded_broadcast!(dest::AbstractMatrix, f, (A,B)::NTuple{2,AbstractMat
727758
if (d_l,d_u) == (A_l,A_u) == (B_l,B_u)
728759
data_d .= f.(data_A,data_B)
729760
else
730-
max_l,max_u = max(A_l,B_l,d_l),max(A_u,B_u,d_u)
731-
min_l,min_u = min(A_l,B_l,d_l),min(A_u,B_u,d_u)
732761
checkzerobands(dest, f, (A,B))
762+
min_l,min_u = min(A_l,B_l,d_l),min(A_u,B_u,d_u)
733763

734764
# fill extra bands in dest
735765
fill!(view(data_d,1:d_u-max(A_u,B_u),:), z)

test/test_broadcasting.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using BandedMatrices, LinearAlgebra, ArrayLayouts, FillArrays, Test
22
import Base: BroadcastStyle
33
import Base.Broadcast: broadcasted
4-
import BandedMatrices: BandedStyle, BandedRows
4+
import BandedMatrices: BandedStyle, BandedRows, BandError
55

66
@testset "broadcasting" begin
77
@testset "general" begin
@@ -47,6 +47,17 @@ import BandedMatrices: BandedStyle, BandedRows
4747
@test norm(A .- A[:,1]) == 0
4848
@test A A[:,1]
4949
end
50+
51+
@testset "checkzerobands" begin
52+
A = brand(10,10, 2,2)
53+
B = brand(10,10, 1,1)
54+
for (l,u) in ((0,0), (0,1), (1,0))
55+
dest = brand(size(A)..., l,u)
56+
@test_throws BandError dest .= A .+ A
57+
@test_throws BandError dest .= A .+ B
58+
@test_throws BandError dest .= A
59+
end
60+
end
5061
end
5162

5263
function test_empty(f!)

0 commit comments

Comments
 (0)