Skip to content

Commit 4ff8ff0

Browse files
Implement vcat(::AbstractBandedMatrix...) (#448)
* implement bandwidths for OneElement * make improvements * fix sparse(::SparseMatrixCSC) * fix bandwidths for SparseMatrixCSC, add for SparseVector * add bandwidths(::Zeros) behaviour for empty sparse structures * add unit tests * overload vcat(::AbstractBandedMatrix...) * style * include tests in runtests.jl * fix issue involving LazyBandedMatrices * fixed mistake * make improvements * add vcat between BandedMatrices and OneElements * fix issue involving calculation of bandwidths. Add unit tests for OneElement * fix issue involving bandwidths larger than dimensions * restore vcat * v1.7.4 --------- Co-authored-by: Sheehan Olver <[email protected]>
1 parent 57a70a5 commit 4ff8ff0

File tree

5 files changed

+96
-1
lines changed

5 files changed

+96
-1
lines changed

src/BandedMatrices.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import ArrayLayouts: AbstractTridiagonalLayout, BidiagonalLayout, BlasMatLdivVec
3434
symmetricuplo, transposelayout, triangulardata, triangularlayout, zero!,
3535
QRPackedQLayout, AdjQRPackedQLayout
3636

37-
import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal, OneElementMatrix, OneElementVector
37+
import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal, OneElementMatrix, OneElementVector, ZerosMatrix, ZerosVector
3838

3939
const libblas = LinearAlgebra.BLAS.libblas
4040
const liblapack = LinearAlgebra.BLAS.liblapack

src/generic/AbstractBandedMatrix.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,57 @@ function sum(A::AbstractBandedMatrix; dims=:)
401401
throw(ArgumentError("dimension must be ≥ 1, got $dims"))
402402
end
403403
end
404+
405+
###
406+
# vcat
407+
###
408+
409+
function LinearAlgebra.vcat(x::AbstractBandedMatrix...)
410+
#avoid unnecessary steps for singleton
411+
if length(x) == 1
412+
return x[1]
413+
end
414+
415+
#instantiate the returned banded matrix with zeros and required bandwidths/dimensions
416+
m = size(x[1], 2)
417+
l,u = -m, typemin(Int64)
418+
n = 0
419+
isempty = true
420+
421+
#Check for dimension error and calculate bandwidths
422+
for A in x
423+
if size(A, 2) != m
424+
sizes = Tuple(size(b, 2) for b in x)
425+
throw(DimensionMismatch("number of columns of each matrix must match (got $sizes)"))
426+
end
427+
428+
l_A, u_A = bandwidths(A)
429+
if l_A + u_A >= 0
430+
isempty = false
431+
u = max(u, min(m - 1, u_A) - n)
432+
l = max(l, min(size(A, 1) - 1, l_A) + n)
433+
end
434+
435+
n += size(A, 1)
436+
end
437+
438+
type = promote_type(eltype.(x)...)
439+
if isempty
440+
return BandedMatrix{type}(undef, (n, m), bandwidths(Zeros(1)))
441+
end
442+
ret = BandedMatrix(Zeros{type}(n, m), (l, u))
443+
444+
#Populate the banded matrix
445+
row_offset = 0
446+
for A in x
447+
n_A = size(A, 1)
448+
449+
for i = 1:n_A, j = rowrange(A, i)
450+
ret[row_offset + i, j] = A[i, j]
451+
end
452+
453+
row_offset += n_A
454+
end
455+
456+
ret
457+
end

src/interfaceimpl.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,6 @@ function getindex(D::Bidiagonal{T,V}, b::Band) where {T,V}
116116
D.uplo == 'U' && b.i == 1 && return copy(D.ev)
117117
convert(V, Zeros{T}(size(D,1)-abs(b.i)))
118118
end
119+
120+
121+
Base.vcat(x::Union{OneElement, ZerosMatrix, AdjOrTrans{<:Any,<:ZerosVector}, AbstractBandedMatrix}...) = vcat(BandedMatrix.(x)...)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ include("test_tribanded.jl")
2727
include("test_interface.jl")
2828
include("test_miscs.jl")
2929
include("test_sum.jl")
30+
include("test_cat.jl")

test/test_cat.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
module TestCat
2+
3+
using BandedMatrices, LinearAlgebra, Test, Random, FillArrays, SparseArrays
4+
5+
@testset "vcat" begin
6+
@testset "banded matrices" begin
7+
a = BandedMatrix(0 => 1:2)
8+
@test vcat(a) == a
9+
10+
b = BandedMatrix(0 => 1:3,-1 => 1:2, -2 => 1:1)
11+
@test_throws DimensionMismatch vcat(a,b)
12+
13+
c = BandedMatrix(0 => [1.0, 2.0, 3.0], 1 => [1.0, 2.0], 2 => [1.0])
14+
@test eltype(vcat(b, c)) == Float64
15+
@test vcat(b, c) == vcat(Matrix(b), Matrix(c))
16+
17+
for i in ((1,2), (-3,4), (0,-1))
18+
a = BandedMatrix(ones(Float64, rand(1:10), 5), i)
19+
b = BandedMatrix(ones(Int64, rand(1:10), 5), i)
20+
c = BandedMatrix(ones(Int32, rand(1:10), 5), i)
21+
d = vcat(a, b, c)
22+
sd = vcat(sparse(a), sparse(b), sparse(c))
23+
@test eltype(d) == Float64
24+
@test d == sd
25+
@test bandwidths(d) == bandwidths(sd)
26+
end
27+
end
28+
29+
@testset "one element" begin
30+
n = rand(3:20)
31+
x,y = OneElement(1, (1,1), (1,n)), OneElement(1, (1,n), (1,n))
32+
b = BandedMatrix((0 => ones(n-2), 1 => -2ones(n - 2), 2 => ones(n - 2)), (n-2, n))
33+
@test vcat(x,b,y) == Tridiagonal([ones(n - 2); 0], [1 ; -2ones(n - 2); 1], [0; ones(n - 2)])
34+
end
35+
end
36+
37+
end

0 commit comments

Comments
 (0)