Skip to content

Commit 631883e

Browse files
authored
Avoid recursion in _banded_muladd! for vectors (#293)
* Avoid recursion in _banded_muladd for vector * Handle zero-sized arrays in _banded_gbmv * Add comma in comment * Add tests * condense _banded_gbmv! branches
1 parent 663514f commit 631883e

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

src/generic/matmul.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,18 @@ banded_gbmv!(tA, α, A, x, β, y) =
2626

2727

2828
@inline function _banded_gbmv!(tA, α, A, x, β, y)
29-
if x y
30-
banded_gbmv!(tA, α, A, copy(x), β, y)
29+
#= Some BLAS implementations throw warnings
30+
with zero-sized arrays, so we handle
31+
these cases separately.
32+
=#
33+
length(y) == 0 && return y
34+
if length(x) == 0
35+
_fill_lmul!(β, y)
3136
else
32-
banded_gbmv!(tA, α, A, x, β, y)
37+
xc = x y ? copy(x) : x
38+
banded_gbmv!(tA, α, A, xc, β, y)
3339
end
40+
return y
3441
end
3542

3643

@@ -43,11 +50,15 @@ function _banded_muladd!(α::T, A, x::AbstractVector, β, y) where T
4350
l, u = bandwidths(A)
4451
if -l > u # no bands
4552
_fill_lmul!(β, y)
46-
elseif l < 0
47-
_banded_muladd!(α, view(A, :, 1-l:n), view(x, 1-l:n), β, y)
48-
elseif u < 0
53+
elseif l < 0 # with u >= -l > 0, that is, all bands lie above the diagonal
54+
# E.g. (l,u) = (-1,2)
55+
# set lview = 0 and uview = u + l >= 0
56+
_banded_gbmv!('N', α, view(A, :, 1-l:n), view(x, 1-l:n), β, y)
57+
elseif u < 0 # with -l <= u < 0, that is, all bands lie below the diagnoal.
58+
# E.g. (l,u) = (2,-1)
59+
# set lview = l + u >= 0 and uview = 0
4960
y[1:-u] .= zero(T)
50-
_banded_muladd!(α, view(A, 1-u:m, :), x, β, view(y, 1-u:m))
61+
_banded_gbmv!('N', α, view(A, 1-u:m, :), x, β, view(y, 1-u:m))
5162
y
5263
else
5364
_banded_gbmv!('N', α, A, x, β, y)

test/test_banded.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ Base.similar(::MyMatrix, ::Type{T}, m::Int, n::Int) where T = MyMatrix{T}(undef,
111111
@test A*v Matrix(A)*v
112112
@test A'*w Matrix(A)'*w
113113
end
114+
115+
@testset "empty" begin
116+
let B=BandedMatrix((0=>ones(0),), (10,0)), v = ones(size(B,2))
117+
@test B * v == zeros(size(B,1))
118+
end
119+
let B=BandedMatrix((0=>ones(0),), (0,10)), v = ones(size(B,2))
120+
@test B * v == zeros(size(B,1))
121+
end
122+
end
114123
end
115124

116125
@testset "Banded * Dense" begin

0 commit comments

Comments
 (0)