Skip to content

Commit 5f67cb5

Browse files
authored
error in broadcasting with fewer bands (#305)
1 parent e74f057 commit 5f67cb5

File tree

3 files changed

+60
-20
lines changed

3 files changed

+60
-20
lines changed

src/generic/Band.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ BandError(A::AbstractMatrix) = BandError(A, max(size(A)...)-1)
7373

7474
function showerror(io::IO, e::BandError)
7575
A, i = e.A, e.i
76-
print(io, "attempt to access $(typeof(A)) with bandwidths " *
76+
print(io, "BandError: attempt to access $(typeof(A)) with bandwidths " *
7777
"($(bandwidth(A, 1)), $(bandwidth(A, 2))) at band $i")
7878
end
7979

src/generic/broadcast.jl

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function checkzerobands(dest, f, A::AbstractMatrix)
7878
l, u = bandwidths(A)
7979

8080
if (l,u) (d_l,d_u)
81-
for j = 1:n
81+
for j = rowsupport(A)
8282
for k = max(1,j-u) : min(j-d_u-1,m)
8383
iszero(f(A[k,j])) || throw(BandError(dest,j-k))
8484
end
@@ -93,21 +93,12 @@ function _banded_broadcast!(dest::AbstractMatrix, f, src::AbstractMatrix{T}, _1,
9393
z = f(zero(T))
9494
iszero(z) || checkbroadcastband(dest, size(src), bandwidths(broadcasted(f, src)))
9595
m,n = size(dest)
96+
m == n == 0 && return dest
9697

9798
d_l, d_u = bandwidths(dest)
9899
s_l, s_u = bandwidths(src)
99-
if d_l < min(s_l,m-1)
100-
# check zeros
101-
for j = 1:n, k = max(1,j+d_l+1):min(j+s_l,j+d_l,m)
102-
iszero(f(inbands_getindex(src, k, j))) || throw(BandError(dest))
103-
end
104-
end
105-
if d_u < min(s_u,n-1)
106-
# check zeros
107-
for j = 1:n, k = max(1,j-d_u,j-s_u):min(j-d_u-1,m)
108-
iszero(f(inbands_getindex(src, k, j))) || throw(BandError(dest))
109-
end
110-
end
100+
101+
checkzerobands(dest, f, src)
111102

112103
_banded_broadcast_anylayout!(dest, src, f, z, (s_l, s_u), (d_l, d_u), m)
113104

@@ -252,13 +243,15 @@ end
252243
min_su_du = min(s_u, d_u)
253244
min_sl_dl = min(s_l, d_l)
254245
for j = rowsupport(dest)
255-
for k = max(1,j-d_u):min(j-s_u-1,m)
246+
# if s_u < d_u, set extra bands in dest above the diagonal to zero
247+
for k = max(1,j-d_u):min(j+min(d_l, -(s_u+1)),m)
256248
inbands_setindex!(dest, z, k, j)
257249
end
258250

259251
_banded_broadcast_loop_overlap!(dest, src, f, (min_sl_dl, min_su_du), m, j)
260252

261-
for k = max(1,j+s_l+1):min(j+d_l,m)
253+
# if s_l < d_l, set extra bands in dest below the diagonal to zero
254+
for k = max(1,j+max(s_l+1, -d_u)):min(j+d_l,m)
262255
inbands_setindex!(dest, z, k, j)
263256
end
264257
end
@@ -276,12 +269,17 @@ function _banded_broadcast!(dest::AbstractMatrix, f, (src,x)::Tuple{AbstractMatr
276269
z = f(zero(T), x)
277270
iszero(z) || checkbroadcastband(dest, size(src), bandwidths(broadcasted(f, src,x)))
278271
m,n = size(dest)
272+
m == n == 0 && return dest
273+
274+
f_x = Base.Fix2(f, x)
275+
276+
# if dest has fewer bands (either above or below the diagonal) than dest,
277+
# then f.(x, dest) must be zero in these bands
278+
checkzerobands(dest, f_x, src)
279279

280280
d_l, d_u = bandwidths(dest)
281281
s_l, s_u = bandwidths(src)
282-
(d_l min(s_l,m-1) && d_u min(s_u,n-1)) || throw(BandError(dest))
283282

284-
f_x = Base.Fix2(f, x)
285283
_banded_broadcast_anylayout!(dest, src, f_x, z, (s_l, s_u), (d_l, d_u), m)
286284

287285
dest
@@ -291,12 +289,19 @@ function _banded_broadcast!(dest::AbstractMatrix, f, (x,src)::Tuple{Number,Abstr
291289
z = f(x, zero(T))
292290
iszero(z) || checkbroadcastband(dest, size(src), bandwidths(broadcasted(f, x,src)))
293291
m,n = size(dest)
292+
m == n == 0 && return dest
293+
294+
f_x = Base.Fix1(f, x)
295+
296+
# if dest has fewer bands (either above or below the diagonal) than dest,
297+
# then f.(x, src) must be zero in these bands
298+
checkzerobands(dest, f_x, src)
294299

295300
d_l, d_u = bandwidths(dest)
296301
s_l, s_u = bandwidths(src)
297-
(d_l min(s_l,m-1) && d_u min(s_u,n-1)) || throw(BandError(dest))
298302

299-
f_x = Base.Fix1(f, x)
303+
checkzerobands(dest, f, src)
304+
300305
_banded_broadcast_anylayout!(dest, src, f_x, z, (s_l, s_u), (d_l, d_u), m)
301306

302307
dest

test/test_broadcasting.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,17 @@ import BandedMatrices: BandedStyle, BandedRows
4949
end
5050
end
5151

52+
function test_empty(f!)
53+
# the function f! must not set the RHS to zero for the error check to work
54+
@testset "empty dest" begin
55+
D = brand(4, 4, -2, 1) # empty
56+
B = brand(size(D)...,1,2) # non-empty and non-zero bands
57+
@test_throws BandError f!(D, B)
58+
f!(D, zero(B))
59+
@test all(iszero, D)
60+
end
61+
end
62+
5263
@testset "identity" begin
5364
n = 100
5465
A = brand(n,n,2,2)
@@ -87,6 +98,16 @@ import BandedMatrices: BandedStyle, BandedRows
8798
B = brand(n,n,1,-1)
8899
A .= B
89100
@test A == B
101+
102+
@testset "empty dest" begin
103+
test_empty((D,B) -> D .= B)
104+
end
105+
106+
@testset "adjtrans" begin
107+
@testset "empty dest" begin
108+
test_empty((D,B) -> D' .= B')
109+
end
110+
end
90111
end
91112

92113
@testset "lmul!/rmul!" begin
@@ -187,6 +208,11 @@ import BandedMatrices: BandedStyle, BandedRows
187208
B .= 2.0 .\ A
188209
@test B == 2.0 \ A == 2.0 \ Matrix(A)
189210

211+
@testset "empty dest" begin
212+
test_empty((D,B) -> D .= 2 .* B)
213+
test_empty((D,B) -> D .= B .* 2)
214+
end
215+
190216
@testset "trans-adj" begin
191217
A = brand(5,5,1,1)
192218
= copy(A)
@@ -207,6 +233,15 @@ import BandedMatrices: BandedStyle, BandedRows
207233
@test A == 2
208234
rmul!(A', 1/2)
209235
@test A ==
236+
237+
@testset "empty dest" begin
238+
test_empty((D,B) -> D' .= 2 .* B')
239+
test_empty((D,B) -> D' .= B' .* 2)
240+
test_empty((D,B) -> D' .= 2 .* B' .* 2)
241+
test_empty((D,B) -> D' .= 2 .* B)
242+
test_empty((D,B) -> D' .= B .* 2)
243+
test_empty((D,B) -> D' .= 2 .* B .* 2)
244+
end
210245
end
211246
end
212247

0 commit comments

Comments
 (0)