Skip to content

Commit bcedda1

Browse files
authored
Rowsupport in banded axpy methods (#316)
* short-circuit in axpy * rowsupport in banded_dense_axpy
1 parent acd0ab5 commit bcedda1

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

src/generic/broadcast.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -951,27 +951,36 @@ _banded_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix, notbandedX, notba
951951
Xl, Xu = bandwidths(X)
952952
Yl, Yu = bandwidths(Y)
953953

954+
if -Xl > Xu #= no bands in X =#
955+
return Y
956+
end
957+
954958
@boundscheck if Xl > Yl
955-
# test that all entries are zero in extra bands
956-
for j=1:size(X,2),k=max(1,j+Yl+1):min(j+Xl,n)
959+
# test that all entries are zero in extra bands below the diagonal
960+
for j=rowsupport(X),k=max(1,j+Yl+1):min(j+Xl,n)
957961
if inbands_getindex(X, k, j) 0
958-
throw(BandError(X, (k,j)))
962+
throw(BandError(Y, (k,j)))
959963
end
960964
end
961965
end
962966
@boundscheck if Xu > Yu
963-
# test that all entries are zero in extra bands
964-
for j=1:size(X,2),k=max(1,j-Xu):min(j-Yu-1,n)
967+
# test that all entries are zero in extra bands above the diagonal
968+
for j=rowsupport(X),k=max(1,j-Xu):min(j-Yu-1,n)
965969
if inbands_getindex(X, k, j) 0
966-
throw(BandError(X, (k,j)))
970+
throw(BandError(Y, (k,j)))
967971
end
968972
end
969973
end
970974

975+
if -Yl > Yu #= no bands in Y =#
976+
return Y
977+
end
978+
979+
# only fill overlapping bands
971980
l = min(Xl,Yl)
972981
u = min(Xu,Yu)
973982

974-
@inbounds for j=1:m,k=max(1,j-u):min(n,j+l)
983+
@inbounds for j=rowsupport(X), k=max(1,j-u):min(n,j+l)
975984
inbands_setindex!(Y, a*inbands_getindex(X,k,j) + inbands_getindex(Y,k,j) ,k, j)
976985
end
977986
Y
@@ -981,7 +990,7 @@ function banded_dense_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix)
981990
if size(X) != size(Y)
982991
throw(DimensionMismatch("+"))
983992
end
984-
@inbounds for j=1:size(X,2),k=colrange(X,j)
993+
@inbounds for j=rowsupport(X), k=colrange(X,j)
985994
Y[k,j] += a*inbands_getindex(X,k,j)
986995
end
987996
Y

test/test_broadcasting.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,21 @@ import BandedMatrices: BandedStyle, BandedRows
306306
@test bandwidths(2A+B) == bandwidths(2.0.*A .+ B) == (2,2)
307307
B .= 2.0 .* A .+ B
308308
@test B == C
309+
310+
@testset "trivial cases" begin
311+
B = brand(2,4,-1,0) # no bands in B
312+
B2 = brand(2,4,0,-1) # no bands in B2
313+
C = brand(size(B)...,1,1)
314+
D = copy(C)
315+
axpy!(0.1, B, C) # no bands in src
316+
@test C == D
317+
@test_throws BandError axpy!(0.1, C, B)
318+
@test_throws BandError axpy!(0.1, C, B2)
319+
D = copy(B)
320+
C .= 0
321+
axpy!(0.1, C, B) # no bands in dest, but src is zero
322+
@test B == D
323+
end
309324
end
310325

311326
@testset "gbmv!" begin

0 commit comments

Comments
 (0)