diff --git a/src/bidiag.jl b/src/bidiag.jl index bb5b8830..43f1336e 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -109,8 +109,8 @@ julia> Bidiagonal(A, :L) # contains the main diagonal and first subdiagonal of A ⋅ ⋅ 4 4 ``` """ -function Bidiagonal(A::AbstractMatrix, uplo::Symbol) - Bidiagonal(diag(A, 0), diag(A, uplo === :U ? 1 : -1), uplo) +function (::Type{Bi})(A::AbstractMatrix, uplo::Symbol) where {Bi<:Bidiagonal} + Bi(diag(A, 0), diag(A, uplo === :U ? 1 : -1), uplo) end @@ -220,7 +220,12 @@ promote_rule(::Type{<:Tridiagonal}, ::Type{<:Bidiagonal}) = Tridiagonal AbstractMatrix{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A) AbstractMatrix{T}(A::Bidiagonal{T}) where {T} = copy(A) -convert(::Type{T}, m::AbstractMatrix) where {T<:Bidiagonal} = m isa T ? m : T(m)::T +function convert(::Type{T}, A::AbstractMatrix) where T<:Bidiagonal + checksquare(A) + isbanded(A, -1, 1) || throw(InexactError(:convert, T, A)) + iszero(diagview(A, 1)) ? T(A, :L) : + iszero(diagview(A, -1)) ? T(A, :U) : throw(InexactError(:convert, T, A)) +end similar(B::Bidiagonal, ::Type{T}) where {T} = Bidiagonal(similar(B.dv, T), similar(B.ev, T), B.uplo) similar(B::Bidiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(B.dv, T, dims) diff --git a/src/tridiag.jl b/src/tridiag.jl index a24cc50b..f1c3b0e4 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -107,18 +107,21 @@ julia> SymTridiagonal(B) [1 2; 3 4] [1 2; 2 3] ``` """ -function SymTridiagonal(A::AbstractMatrix) +function (::Type{SymTri})(A::AbstractMatrix) where {SymTri <: SymTridiagonal} checksquare(A) du = diag(A, 1) d = diag(A) dl = diag(A, -1) - if all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d) - SymTridiagonal(d, du) + if _checksymmetric(d, du, dl) + SymTri(d, du) else throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal")) end end +_checksymmetric(d, du, dl) = all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d) +_checksymmetric(A::AbstractMatrix) = _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1)) + SymTridiagonal{T,V}(S::SymTridiagonal{T,V}) where {T,V<:AbstractVector{T}} = S SymTridiagonal{T,V}(S::SymTridiagonal) where {T,V<:AbstractVector{T}} = SymTridiagonal(convert(V, S.dv)::V, convert(V, S.ev)::V) @@ -128,6 +131,11 @@ SymTridiagonal{T}(S::SymTridiagonal) where {T} = convert(AbstractVector{T}, S.ev)::AbstractVector{T}) SymTridiagonal(S::SymTridiagonal) = S +function convert(::Type{T}, A::AbstractMatrix) where T<:SymTridiagonal + checksquare(A) + _checksymmetric(A) && isbanded(A, -1, 1) ? T(A) : throw(InexactError(:convert, T, A)) +end + AbstractMatrix{T}(S::SymTridiagonal) where {T} = SymTridiagonal{T}(S) AbstractMatrix{T}(S::SymTridiagonal{T}) where {T} = copy(S) @@ -583,7 +591,7 @@ julia> Tridiagonal(A) ⋅ ⋅ 3 4 ``` """ -Tridiagonal(A::AbstractMatrix) = Tridiagonal(diag(A,-1), diag(A,0), diag(A,1)) +(::Type{Tri})(A::AbstractMatrix) where {Tri<:Tridiagonal} = Tri(diag(A,-1), diag(A,0), diag(A,1)) Tridiagonal(A::Tridiagonal) = A Tridiagonal{T}(A::Tridiagonal{T}) where {T} = A @@ -605,6 +613,11 @@ function Tridiagonal{T,V}(A::Tridiagonal) where {T,V<:AbstractVector{T}} end end +function convert(::Type{T}, A::AbstractMatrix) where T<:Tridiagonal + checksquare(A) + isbanded(A, -1, 1) ? T(A) : throw(InexactError(:convert, T, A)) +end + size(M::Tridiagonal) = (n = length(M.d); (n, n)) axes(M::Tridiagonal) = (ax = axes(M.d,1); (ax, ax)) diff --git a/test/bidiag.jl b/test/bidiag.jl index a39fa027..b331f666 100644 --- a/test/bidiag.jl +++ b/test/bidiag.jl @@ -1158,4 +1158,21 @@ end @test opnorm(B, Inf) == opnorm(Matrix(B), Inf) end +@testset "convert to Bidiagonal" begin + M = diagm(0 => [1,2,3], 1=>[4,5]) + B = convert(Bidiagonal, M) + @test B == Bidiagonal(M, :U) + M = diagm(0 => [1,2,3], -1=>[4,5]) + B = convert(Bidiagonal, M) + @test B == Bidiagonal(M, :L) + B = convert(Bidiagonal{Int8}, M) + @test B == M + @test B isa Bidiagonal{Int8, Vector{Int8}} + B = convert(Bidiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M) + @test B == M + @test B isa Bidiagonal{Int8, OffsetVector{Int8, Vector{Int8}}} + M = diagm(-1 => [1,2], 1=>[4,5]) + @test_throws InexactError convert(Bidiagonal, M) +end + end # module TestBidiagonal diff --git a/test/special.jl b/test/special.jl index ac76279f..d1d4cbd6 100644 --- a/test/special.jl +++ b/test/special.jl @@ -43,7 +43,7 @@ Random.seed!(1) @test Matrix(convert(newtype, A)) == Matrix(A) end for newtype in [Diagonal, Bidiagonal] - @test_throws ArgumentError convert(newtype,A) + @test_throws Union{ArgumentError,InexactError} convert(newtype,A) end A = SymTridiagonal(a, zeros(n-1)) @test Matrix(convert(Bidiagonal,A)) == Matrix(A) @@ -57,7 +57,7 @@ Random.seed!(1) @test Matrix(convert(newtype, A)) == Matrix(A) end for newtype in [Diagonal, Bidiagonal] - @test_throws ArgumentError convert(newtype,A) + @test_throws Union{ArgumentError,InexactError} convert(newtype,A) end A = Tridiagonal(zeros(n-1), [1.0:n;], fill(1., n-1)) #not morally Diagonal @test Matrix(convert(Bidiagonal, A)) == Matrix(A) @@ -79,7 +79,7 @@ Random.seed!(1) end A = UpperTriangular(triu(rand(n,n))) for newtype in [Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal] - @test_throws ArgumentError convert(newtype,A) + @test_throws Union{ArgumentError,InexactError} convert(newtype,A) end diff --git a/test/tridiag.jl b/test/tridiag.jl index 4b592a87..08e16f3d 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1099,4 +1099,38 @@ end @test opnorm(S, Inf) == opnorm(Matrix(S), Inf) end +@testset "convert to Tridiagonal/SymTridiagonal" begin + @testset "Tridiagonal" begin + for M in [diagm(0 => [1,2,3], 1=>[4,5]), + diagm(0 => [1,2,3], 1=>[4,5], -1=>[6,7]), + diagm(-1 => [1,2], 1=>[4,5])] + B = convert(Tridiagonal, M) + @test B == Tridiagonal(M) + B = convert(Tridiagonal{Int8}, M) + @test B == M + @test B isa Tridiagonal{Int8} + B = convert(Tridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M) + @test B == M + @test B isa Tridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}} + end + @test_throws InexactError convert(Tridiagonal, fill(5, 4, 4)) + end + @testset "SymTridiagonal" begin + for M in [diagm(0 => [1,2,3], 1=>[4,5], -1=>[4,5]), + diagm(0 => [1,2,3]), + diagm(-1 => [1,2], 1=>[1,2])] + B = convert(SymTridiagonal, M) + @test B == SymTridiagonal(M) + B = convert(SymTridiagonal{Int8}, M) + @test B == M + @test B isa SymTridiagonal{Int8} + B = convert(SymTridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M) + @test B == M + @test B isa SymTridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}} + end + @test_throws InexactError convert(SymTridiagonal, fill(5, 4, 4)) + @test_throws InexactError convert(SymTridiagonal, diagm(0=>fill(NaN,4))) + end +end + end # module TestTridiagonal