Skip to content

Commit aac8d7e

Browse files
authored
axpy! for BandedMatrix (#336)
* axpy for BandedMatrix * test for matching bandwidth * Add dimension mismatch tests
1 parent 331fe53 commit aac8d7e

File tree

3 files changed

+30
-9
lines changed

3 files changed

+30
-9
lines changed

src/banded/BandedMatrix.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,3 +830,14 @@ end
830830

831831
bandedbroadcaststyle(_) = BandedStyle()
832832
BroadcastStyle(::Type{<:BandedMatrix{<:Any,Dat}}) where Dat = bandedbroadcaststyle(BroadcastStyle(Dat))
833+
834+
function banded_axpy!(a::Number, X::BandedMatrix, Y::BandedMatrix)
835+
bx = bandwidths(X)
836+
by = bandwidths(Y)
837+
if bx == by
838+
axpy!(a, X.data, Y.data)
839+
else
840+
banded_generic_axpy!(a, X, Y)
841+
end
842+
return Y
843+
end

src/generic/broadcast.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -932,10 +932,11 @@ _banded_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix, notbandedX, notba
932932
banded_dense_axpy!(a, X, Y)
933933

934934
# additions and subtractions
935-
@propagate_inbounds function banded_generic_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix)
935+
function banded_generic_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix)
936936
n,m = size(X)
937-
if (n,m) size(Y)
938-
throw(BoundsError())
937+
ny,my = size(Y)
938+
if (n,m) (ny,my)
939+
throw(DimensionMismatch("X has size $((n,m)) but $Y has size $((ny,my))"))
939940
end
940941
Xl, Xu = bandwidths(X)
941942
Yl, Yu = bandwidths(Y)
@@ -944,17 +945,17 @@ _banded_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix, notbandedX, notba
944945
return Y
945946
end
946947

947-
@boundscheck if Xl > Yl
948+
if Xl > Yl
948949
# test that all entries are zero in extra bands below the diagonal
949-
for j=rowsupport(X),k=max(1,j+Yl+1):min(j+Xl,n)
950+
@inbounds for j=rowsupport(X),k=max(1,j+Yl+1):min(j+Xl,n)
950951
if inbands_getindex(X, k, j) 0
951952
throw(BandError(Y, (k,j)))
952953
end
953954
end
954955
end
955-
@boundscheck if Xu > Yu
956+
if Xu > Yu
956957
# test that all entries are zero in extra bands above the diagonal
957-
for j=rowsupport(X),k=max(1,j-Xu):min(j-Yu-1,n)
958+
@inbounds for j=rowsupport(X),k=max(1,j-Xu):min(j-Yu-1,n)
958959
if inbands_getindex(X, k, j) 0
959960
throw(BandError(Y, (k,j)))
960961
end
@@ -976,8 +977,10 @@ _banded_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix, notbandedX, notba
976977
end
977978

978979
function banded_dense_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix)
979-
if size(X) != size(Y)
980-
throw(DimensionMismatch("+"))
980+
n,m = size(X)
981+
ny,my = size(Y)
982+
if (n,m) (ny,my)
983+
throw(DimensionMismatch("X has size $((n,m)) but $Y has size $((ny,my))"))
981984
end
982985
@inbounds for j=rowsupport(X), k=colrange(X,j)
983986
Y[k,j] += a*inbands_getindex(X,k,j)

test/test_broadcasting.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,9 @@ import BandedMatrices: BandedStyle, BandedRows
312312
B .= 2.0 .* A .+ B
313313
@test B == C
314314

315+
# test with identical bandwidth
316+
@test axpy!(3, A, copy(A)) 4A
317+
315318
@testset "trivial cases" begin
316319
B = brand(2,4,-1,0) # no bands in B
317320
B2 = brand(2,4,0,-1) # no bands in B2
@@ -326,6 +329,10 @@ import BandedMatrices: BandedStyle, BandedRows
326329
axpy!(0.1, C, B) # no bands in dest, but src is zero
327330
@test B == D
328331
end
332+
333+
@test_throws DimensionMismatch axpy!(2, brand(2,2,1,1), brand(3,3,1,1))
334+
@test_throws DimensionMismatch axpy!(2, brand(2,2,1,1), brand(3,3,2,2))
335+
@test_throws DimensionMismatch axpy!(2, brand(2,2,1,1), zeros(3,3))
329336
end
330337

331338
@testset "gbmv!" begin

0 commit comments

Comments
 (0)