From 5ded08eb301b68c11e2da3e94848ef2d5be21a25 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 19 Feb 2025 16:10:38 +0530 Subject: [PATCH 01/11] Add `convert` for banded matrix types --- src/bidiag.jl | 7 +++++++ src/tridiag.jl | 20 ++++++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/bidiag.jl b/src/bidiag.jl index bb5b8830..741707b1 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -118,6 +118,13 @@ Bidiagonal(A::Bidiagonal) = A Bidiagonal{T}(A::Bidiagonal{T}) where {T} = A Bidiagonal{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A.dv, A.ev, A.uplo) +function convert(::Type{T}, A::AbstractMatrix) where T<:Bidiagonal + checksquare(A) + isbanded(A, -1, 1) || throw(InexactError(:convert, T, A)) + iszero(diagview(A, 1)) ? T(A, :L) : + iszero(diagview(A, -1)) ? T(A, :U) : throw(InexactError(:convert, T, A)) +end + _offdiagind(uplo) = uplo == 'U' ? 1 : -1 @inline function Base.isassigned(A::Bidiagonal, i::Int, j::Int) diff --git a/src/tridiag.jl b/src/tridiag.jl index a24cc50b..732a1e07 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -109,16 +109,18 @@ julia> SymTridiagonal(B) """ function SymTridiagonal(A::AbstractMatrix) checksquare(A) - du = diag(A, 1) - d = diag(A) - dl = diag(A, -1) - if all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d) + if _checksymmetric(A) + du = diag(A, 1) + d = diag(A) SymTridiagonal(d, du) else throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal")) end end +_checksymmetric(d, du, dl) = all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d) +_checksymmetric(A::AbstractMatrix) = _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1)) + SymTridiagonal{T,V}(S::SymTridiagonal{T,V}) where {T,V<:AbstractVector{T}} = S SymTridiagonal{T,V}(S::SymTridiagonal) where {T,V<:AbstractVector{T}} = SymTridiagonal(convert(V, S.dv)::V, convert(V, S.ev)::V) @@ -128,6 +130,11 @@ SymTridiagonal{T}(S::SymTridiagonal) where {T} = convert(AbstractVector{T}, S.ev)::AbstractVector{T}) SymTridiagonal(S::SymTridiagonal) = S +function convert(::Type{T}, A::AbstractMatrix) where T<:SymTridiagonal + checksquare(A) + _checksymmetric(A) && isbanded(A, -1, 1) ? T(A) : throw(InexactError(:convert, T, A)) +end + AbstractMatrix{T}(S::SymTridiagonal) where {T} = SymTridiagonal{T}(S) AbstractMatrix{T}(S::SymTridiagonal{T}) where {T} = copy(S) @@ -605,6 +612,11 @@ function Tridiagonal{T,V}(A::Tridiagonal) where {T,V<:AbstractVector{T}} end end +function convert(::Type{T}, A::AbstractMatrix) where T<:Tridiagonal + checksquare(A) + isbanded(A, -1, 1) ? T(A) : throw(InexactError(:convert, T, A)) +end + size(M::Tridiagonal) = (n = length(M.d); (n, n)) axes(M::Tridiagonal) = (ax = axes(M.d,1); (ax, ax)) From 357cc701aad3fa5cd049abba25d4cb40c968819c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 19 Feb 2025 16:51:52 +0530 Subject: [PATCH 02/11] Replace duplicate `convert` method for `Bidiagonal` --- src/bidiag.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/bidiag.jl b/src/bidiag.jl index 741707b1..f83eb1e2 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -118,13 +118,6 @@ Bidiagonal(A::Bidiagonal) = A Bidiagonal{T}(A::Bidiagonal{T}) where {T} = A Bidiagonal{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A.dv, A.ev, A.uplo) -function convert(::Type{T}, A::AbstractMatrix) where T<:Bidiagonal - checksquare(A) - isbanded(A, -1, 1) || throw(InexactError(:convert, T, A)) - iszero(diagview(A, 1)) ? T(A, :L) : - iszero(diagview(A, -1)) ? T(A, :U) : throw(InexactError(:convert, T, A)) -end - _offdiagind(uplo) = uplo == 'U' ? 1 : -1 @inline function Base.isassigned(A::Bidiagonal, i::Int, j::Int) @@ -227,7 +220,12 @@ promote_rule(::Type{<:Tridiagonal}, ::Type{<:Bidiagonal}) = Tridiagonal AbstractMatrix{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A) AbstractMatrix{T}(A::Bidiagonal{T}) where {T} = copy(A) -convert(::Type{T}, m::AbstractMatrix) where {T<:Bidiagonal} = m isa T ? m : T(m)::T +function convert(::Type{T}, A::AbstractMatrix) where T<:Bidiagonal + checksquare(A) + isbanded(A, -1, 1) || throw(InexactError(:convert, T, A)) + iszero(diagview(A, 1)) ? T(A, :L) : + iszero(diagview(A, -1)) ? T(A, :U) : throw(InexactError(:convert, T, A)) +end similar(B::Bidiagonal, ::Type{T}) where {T} = Bidiagonal(similar(B.dv, T), similar(B.ev, T), B.uplo) similar(B::Bidiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(B.dv, T, dims) From b064b72b1de8405d14ef5d52c78a7193c5bf54dd Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 19 Feb 2025 17:53:12 +0530 Subject: [PATCH 03/11] Add tests for `convert` --- test/bidiag.jl | 11 +++++++++++ test/tridiag.jl | 22 ++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/test/bidiag.jl b/test/bidiag.jl index a39fa027..ab45b063 100644 --- a/test/bidiag.jl +++ b/test/bidiag.jl @@ -1158,4 +1158,15 @@ end @test opnorm(B, Inf) == opnorm(Matrix(B), Inf) end +@testset "convert to Bidiagonal" begin + M = diagm(0 => [1,2,3], 1=>[4,5]) + B = convert(Bidiagonal, M) + @test B == Bidiagonal(M, :U) + M = diagm(0 => [1,2,3], -1=>[4,5]) + B = convert(Bidiagonal, M) + @test B == Bidiagonal(M, :L) + M = diagm(-1 => [1,2], 1=>[4,5]) + @test_throws InexactError convert(Bidiagonal, M) +end + end # module TestBidiagonal diff --git a/test/tridiag.jl b/test/tridiag.jl index 4b592a87..f9ce3b1d 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1099,4 +1099,26 @@ end @test opnorm(S, Inf) == opnorm(Matrix(S), Inf) end +@testset "convert to Tridiagonal/SymTridiagonal" begin + @testset "Tridiagonal" begin + for M in [diagm(0 => [1,2,3], 1=>[4,5]), + diagm(0 => [1,2,3], 1=>[4,5], -1=>[6,7]), + diagm(-1 => [1,2], 1=>[4,5])] + B = convert(Tridiagonal, M) + @test B == Tridiagonal(M) + end + @test_throws InexactError convert(Tridiagonal, fill(5, 4, 4)) + end + @testset "SymTridiagonal" begin + for M in [diagm(0 => [1,2,3], 1=>[4,5], -1=>[4,5]), + diagm(0 => [1,2,3]), + diagm(-1 => [1,2], 1=>[1,2])] + B = convert(SymTridiagonal, M) + @test B == SymTridiagonal(M) + end + @test_throws InexactError convert(SymTridiagonal, fill(5, 4, 4)) + @test_throws InexactError convert(SymTridiagonal, diagm(0=>fill(NaN,4))) + end +end + end # module TestTridiagonal From e0f4bf11b2a6265839145e53321643a8985c3181 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 19 Feb 2025 19:03:45 +0530 Subject: [PATCH 04/11] Generalize `Bidiagonal`/`Tridiagonal` constructors --- src/bidiag.jl | 4 ++-- src/diagonal.jl | 4 +--- src/tridiag.jl | 2 +- test/bidiag.jl | 6 ++++++ test/tridiag.jl | 6 ++++++ 5 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/bidiag.jl b/src/bidiag.jl index f83eb1e2..43f1336e 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -109,8 +109,8 @@ julia> Bidiagonal(A, :L) # contains the main diagonal and first subdiagonal of A ⋅ ⋅ 4 4 ``` """ -function Bidiagonal(A::AbstractMatrix, uplo::Symbol) - Bidiagonal(diag(A, 0), diag(A, uplo === :U ? 1 : -1), uplo) +function (::Type{Bi})(A::AbstractMatrix, uplo::Symbol) where {Bi<:Bidiagonal} + Bi(diag(A, 0), diag(A, uplo === :U ? 1 : -1), uplo) end diff --git a/src/diagonal.jl b/src/diagonal.jl index 9f8d54e5..f6e1b55e 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -99,9 +99,7 @@ julia> Diagonal(A) ⋅ 5 ``` """ -Diagonal(A::AbstractMatrix) = Diagonal(diag(A)) -Diagonal{T}(A::AbstractMatrix) where T = Diagonal{T}(diag(A)) -Diagonal{T,V}(A::AbstractMatrix) where {T,V<:AbstractVector{T}} = Diagonal{T,V}(diag(A)) +(::Type{D})(A::AbstractMatrix) where {D<:Diagonal} = D(diag(A)) function convert(::Type{T}, A::AbstractMatrix) where T<:Diagonal checksquare(A) isdiag(A) ? T(A) : throw(InexactError(:convert, T, A)) diff --git a/src/tridiag.jl b/src/tridiag.jl index 732a1e07..03495096 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -590,7 +590,7 @@ julia> Tridiagonal(A) ⋅ ⋅ 3 4 ``` """ -Tridiagonal(A::AbstractMatrix) = Tridiagonal(diag(A,-1), diag(A,0), diag(A,1)) +(::Type{Tri})(A::AbstractMatrix) where {Tri<:Tridiagonal} = Tri(diag(A,-1), diag(A,0), diag(A,1)) Tridiagonal(A::Tridiagonal) = A Tridiagonal{T}(A::Tridiagonal{T}) where {T} = A diff --git a/test/bidiag.jl b/test/bidiag.jl index ab45b063..b331f666 100644 --- a/test/bidiag.jl +++ b/test/bidiag.jl @@ -1165,6 +1165,12 @@ end M = diagm(0 => [1,2,3], -1=>[4,5]) B = convert(Bidiagonal, M) @test B == Bidiagonal(M, :L) + B = convert(Bidiagonal{Int8}, M) + @test B == M + @test B isa Bidiagonal{Int8, Vector{Int8}} + B = convert(Bidiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M) + @test B == M + @test B isa Bidiagonal{Int8, OffsetVector{Int8, Vector{Int8}}} M = diagm(-1 => [1,2], 1=>[4,5]) @test_throws InexactError convert(Bidiagonal, M) end diff --git a/test/tridiag.jl b/test/tridiag.jl index f9ce3b1d..4f4a80ce 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1106,6 +1106,12 @@ end diagm(-1 => [1,2], 1=>[4,5])] B = convert(Tridiagonal, M) @test B == Tridiagonal(M) + B = convert(Tridiagonal{Int8}, M) + @test B == M + @test B isa Tridiagonal{Int8} + B = convert(Tridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M) + @test B == M + @test B isa Tridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}} end @test_throws InexactError convert(Tridiagonal, fill(5, 4, 4)) end From 327af4a2e8e7d0b809e49cf34aa0b7f706110488 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 19 Feb 2025 19:06:33 +0530 Subject: [PATCH 05/11] Generalize error type in special tests --- test/special.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/special.jl b/test/special.jl index ac76279f..abbfb303 100644 --- a/test/special.jl +++ b/test/special.jl @@ -79,7 +79,7 @@ Random.seed!(1) end A = UpperTriangular(triu(rand(n,n))) for newtype in [Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal] - @test_throws ArgumentError convert(newtype,A) + @test_throws Union{ArgumentError,InexactError} convert(newtype,A) end From c0e7f985871790ef23a632353ffcad14299ff3f5 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 19 Feb 2025 19:37:05 +0530 Subject: [PATCH 06/11] Restore `Diagonal` constructors --- src/diagonal.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diagonal.jl b/src/diagonal.jl index f6e1b55e..9f8d54e5 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -99,7 +99,9 @@ julia> Diagonal(A) ⋅ 5 ``` """ -(::Type{D})(A::AbstractMatrix) where {D<:Diagonal} = D(diag(A)) +Diagonal(A::AbstractMatrix) = Diagonal(diag(A)) +Diagonal{T}(A::AbstractMatrix) where T = Diagonal{T}(diag(A)) +Diagonal{T,V}(A::AbstractMatrix) where {T,V<:AbstractVector{T}} = Diagonal{T,V}(diag(A)) function convert(::Type{T}, A::AbstractMatrix) where T<:Diagonal checksquare(A) isdiag(A) ? T(A) : throw(InexactError(:convert, T, A)) From 676db8989aa7b7868406cb6a681a06bb2f9661c2 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 19 Feb 2025 20:09:59 +0530 Subject: [PATCH 07/11] Use `diag` in `SymTridiagonal` constructor --- src/tridiag.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/tridiag.jl b/src/tridiag.jl index 03495096..f98ae33e 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -109,9 +109,10 @@ julia> SymTridiagonal(B) """ function SymTridiagonal(A::AbstractMatrix) checksquare(A) - if _checksymmetric(A) - du = diag(A, 1) - d = diag(A) + du = diag(A, 1) + d = diag(A) + dl = diag(A,-1) + if _checksymmetric(d, du, dl) SymTridiagonal(d, du) else throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal")) From e63238135484330a0d35a1318f1348eeddca24af Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 19 Feb 2025 20:33:18 +0530 Subject: [PATCH 08/11] whitespace --- src/tridiag.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tridiag.jl b/src/tridiag.jl index f98ae33e..90600ec0 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -111,7 +111,7 @@ function SymTridiagonal(A::AbstractMatrix) checksquare(A) du = diag(A, 1) d = diag(A) - dl = diag(A,-1) + dl = diag(A,-1) if _checksymmetric(d, du, dl) SymTridiagonal(d, du) else From 6868e52c70bbbd1299a305cfe17a16b94fe29f6c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 19 Feb 2025 20:34:01 +0530 Subject: [PATCH 09/11] whitespace 2 --- src/tridiag.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tridiag.jl b/src/tridiag.jl index 90600ec0..6d637a36 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -111,7 +111,7 @@ function SymTridiagonal(A::AbstractMatrix) checksquare(A) du = diag(A, 1) d = diag(A) - dl = diag(A,-1) + dl = diag(A, -1) if _checksymmetric(d, du, dl) SymTridiagonal(d, du) else From 91886b1de9f12d7797a72ca14f1ae8835e5fc70a Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 19 Feb 2025 23:04:40 +0530 Subject: [PATCH 10/11] Relax more error types in special tests --- test/special.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/special.jl b/test/special.jl index abbfb303..d1d4cbd6 100644 --- a/test/special.jl +++ b/test/special.jl @@ -43,7 +43,7 @@ Random.seed!(1) @test Matrix(convert(newtype, A)) == Matrix(A) end for newtype in [Diagonal, Bidiagonal] - @test_throws ArgumentError convert(newtype,A) + @test_throws Union{ArgumentError,InexactError} convert(newtype,A) end A = SymTridiagonal(a, zeros(n-1)) @test Matrix(convert(Bidiagonal,A)) == Matrix(A) @@ -57,7 +57,7 @@ Random.seed!(1) @test Matrix(convert(newtype, A)) == Matrix(A) end for newtype in [Diagonal, Bidiagonal] - @test_throws ArgumentError convert(newtype,A) + @test_throws Union{ArgumentError,InexactError} convert(newtype,A) end A = Tridiagonal(zeros(n-1), [1.0:n;], fill(1., n-1)) #not morally Diagonal @test Matrix(convert(Bidiagonal, A)) == Matrix(A) From 6f4b663542d41a70d875c62ee8f69a16cbb17aa5 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 28 Feb 2025 10:18:59 +0100 Subject: [PATCH 11/11] extend to symtridiagonal case --- src/tridiag.jl | 4 ++-- test/tridiag.jl | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/tridiag.jl b/src/tridiag.jl index 6d637a36..f1c3b0e4 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -107,13 +107,13 @@ julia> SymTridiagonal(B) [1 2; 3 4] [1 2; 2 3] ``` """ -function SymTridiagonal(A::AbstractMatrix) +function (::Type{SymTri})(A::AbstractMatrix) where {SymTri <: SymTridiagonal} checksquare(A) du = diag(A, 1) d = diag(A) dl = diag(A, -1) if _checksymmetric(d, du, dl) - SymTridiagonal(d, du) + SymTri(d, du) else throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal")) end diff --git a/test/tridiag.jl b/test/tridiag.jl index 4f4a80ce..08e16f3d 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1121,6 +1121,12 @@ end diagm(-1 => [1,2], 1=>[1,2])] B = convert(SymTridiagonal, M) @test B == SymTridiagonal(M) + B = convert(SymTridiagonal{Int8}, M) + @test B == M + @test B isa SymTridiagonal{Int8} + B = convert(SymTridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M) + @test B == M + @test B isa SymTridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}} end @test_throws InexactError convert(SymTridiagonal, fill(5, 4, 4)) @test_throws InexactError convert(SymTridiagonal, diagm(0=>fill(NaN,4)))