From 37090e052fa2a705b25bd459930953e80619517c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 1 Sep 2025 17:51:40 -0400 Subject: [PATCH] Add `isstoredband` to check if a band is stored as a vector --- src/bidiag.jl | 2 ++ src/dense.jl | 12 ++++++++++++ src/diagonal.jl | 3 +++ src/hessenberg.jl | 2 ++ src/triangular.jl | 5 +++++ src/tridiag.jl | 2 ++ test/hessenberg.jl | 27 +++++++++++++++++++++++++++ test/triangular.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 96 insertions(+) diff --git a/src/bidiag.jl b/src/bidiag.jl index cc5e6de7..63600f08 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -435,6 +435,8 @@ function diag(M::Bidiagonal, n::Integer=0) return v end +isstoredband(A::Bidiagonal, k::Integer) = k == 0 || k == _offdiagind(A.uplo) + function +(A::Bidiagonal, B::Bidiagonal) if A.uplo == B.uplo || length(A.dv) == 0 Bidiagonal(A.dv+B.dv, A.ev+B.ev, A.uplo) diff --git a/src/dense.jl b/src/dense.jl index 5fa28f85..865d3e0c 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -237,6 +237,18 @@ end fillstored!(A::AbstractMatrix, v) = fill!(A, v) +""" + isstoredband(A::AbstractMatrix, k::Integer) + +Return whether the `k`-th band of `A` is stored as a vector, as opposed to +being generated during indexing. +For example, only the principal diagonal would be stored for a `Diagonal`. + +!!! note + This is a conservative check that may have false negatives but should not have false positives. +""" +isstoredband(A::AbstractMatrix, k::Integer) = false + diagind(m::Integer, n::Integer, k::Integer=0) = diagind(IndexLinear(), m, n, k) diagind(::IndexLinear, m::Integer, n::Integer, k::Integer=0) = k <= 0 ? range(1-k, step=m+1, length=min(m+k, n)) : range(k*m+1, step=m+1, length=min(m, n-k)) diff --git a/src/diagonal.jl b/src/diagonal.jl index e60fb009..e4f76685 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -940,6 +940,9 @@ function diag(D::Diagonal, k::Integer=0) end return v end + +isstoredband(::Diagonal, k::Integer) = k == 0 + tr(D::Diagonal) = sum(tr, D.diag) det(D::Diagonal) = prod(det, D.diag) function logdet(D::Diagonal{<:Complex}) # make sure branch cut is correct diff --git a/src/hessenberg.jl b/src/hessenberg.jl index 7aab1c77..710f41be 100644 --- a/src/hessenberg.jl +++ b/src/hessenberg.jl @@ -667,3 +667,5 @@ function logdet(F::Hessenberg) d,s = logabsdet(F) return d + log(s) end + +isstoredband(U::UpperHessenberg, k::Integer) = k >= -1 && isstoredband(parent(U), k) \ No newline at end of file diff --git a/src/triangular.jl b/src/triangular.jl index d82ddd87..05f55dac 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -238,6 +238,11 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) = Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) = _shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false +isstoredband(U::UpperTriangular, k::Integer) = k >= 0 && isstoredband(parent(U), k) +isstoredband(L::LowerTriangular, k::Integer) = k <= 0 && isstoredband(parent(L), k) +isstoredband(U::UnitUpperTriangular, k::Integer) = k > 0 && isstoredband(parent(U), k) +isstoredband(L::UnitLowerTriangular, k::Integer) = k < 0 && isstoredband(parent(L), k) + @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} if _shouldforwardindex(A, i, j) A.data[i,j] diff --git a/src/tridiag.jl b/src/tridiag.jl index a0e3d821..4664d93b 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -1230,3 +1230,5 @@ function fillband!(T::SymTridiagonal, x, l, u) end return T end + +isstoredband(::Union{SymTridiagonal, Tridiagonal}, k::Integer) = -1 <= k <= 1 diff --git a/test/hessenberg.jl b/test/hessenberg.jl index 5b61c798..871c9d9b 100644 --- a/test/hessenberg.jl +++ b/test/hessenberg.jl @@ -320,4 +320,31 @@ end @test U == U2 end +@testset "isstoredband" begin + U = UpperHessenberg(Diagonal(1:4)) + @test LinearAlgebra.isstoredband(U, 0) + @test !LinearAlgebra.isstoredband(U, 1) + @test !LinearAlgebra.isstoredband(U, -1) + + U = UpperHessenberg(Bidiagonal(1:4, 1:3, :U)) + @test LinearAlgebra.isstoredband(U, 0) + @test LinearAlgebra.isstoredband(U, 1) + @test !LinearAlgebra.isstoredband(U, 2) + @test !LinearAlgebra.isstoredband(U, -1) + + U = UpperHessenberg(Bidiagonal(1:4, 1:3, :L)) + @test LinearAlgebra.isstoredband(U, 0) + @test LinearAlgebra.isstoredband(U, -1) + @test !LinearAlgebra.isstoredband(U, 1) + @test !LinearAlgebra.isstoredband(U, 2) + + for Tri in (Tridiagonal(1:3, 1:4, 1:3), SymTridiagonal(1:4, 1:3)) + U = UpperHessenberg(Tri) + @test LinearAlgebra.isstoredband(U, 0) + @test LinearAlgebra.isstoredband(U, 1) + @test LinearAlgebra.isstoredband(U, -1) + @test !LinearAlgebra.isstoredband(U, 2) + end +end + end # module TestHessenberg diff --git a/test/triangular.jl b/test/triangular.jl index e823f698..b60dc767 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -1110,4 +1110,47 @@ end end end +@testset "isstoredband" begin + U = UpperTriangular(Diagonal(1:4)) + @test LinearAlgebra.isstoredband(U, 0) + @test !LinearAlgebra.isstoredband(U, 1) + @test !LinearAlgebra.isstoredband(U, -1) + L = LowerTriangular(Diagonal(1:4)) + @test LinearAlgebra.isstoredband(L, 0) + @test !LinearAlgebra.isstoredband(L, 1) + @test !LinearAlgebra.isstoredband(L, -1) + for T in (UnitUpperTriangular, UnitLowerTriangular) + U = T(Diagonal(1:4)) + @test !LinearAlgebra.isstoredband(U, 0) + @test !LinearAlgebra.isstoredband(U, 1) + @test !LinearAlgebra.isstoredband(U, -1) + end + + U = UpperTriangular(Bidiagonal(1:4, 1:3, :U)) + @test LinearAlgebra.isstoredband(U, 0) + @test LinearAlgebra.isstoredband(U, 1) + @test !LinearAlgebra.isstoredband(U, 2) + @test !LinearAlgebra.isstoredband(U, -1) + + U = UpperTriangular(Bidiagonal(1:4, 1:3, :L)) + @test LinearAlgebra.isstoredband(U, 0) + @test !LinearAlgebra.isstoredband(U, 1) + @test !LinearAlgebra.isstoredband(U, 2) + @test !LinearAlgebra.isstoredband(U, -1) + + for Tri in (Tridiagonal(1:3, 1:4, 1:3), SymTridiagonal(1:4, 1:3)) + U = UpperTriangular(Tri) + @test LinearAlgebra.isstoredband(U, 0) + @test LinearAlgebra.isstoredband(U, 1) + @test !LinearAlgebra.isstoredband(U, 2) + @test !LinearAlgebra.isstoredband(U, -1) + + L = LowerTriangular(Tri) + @test LinearAlgebra.isstoredband(L, 0) + @test LinearAlgebra.isstoredband(L, -1) + @test !LinearAlgebra.isstoredband(L, -2) + @test !LinearAlgebra.isstoredband(L, 1) + end +end + end # module TestTriangular