Skip to content
11 changes: 8 additions & 3 deletions src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ julia> Bidiagonal(A, :L) # contains the main diagonal and first subdiagonal of A
⋅ ⋅ 4 4
```
"""
function Bidiagonal(A::AbstractMatrix, uplo::Symbol)
Bidiagonal(diag(A, 0), diag(A, uplo === :U ? 1 : -1), uplo)
function (::Type{Bi})(A::AbstractMatrix, uplo::Symbol) where {Bi<:Bidiagonal}
Bi(diag(A, 0), diag(A, uplo === :U ? 1 : -1), uplo)
end


Expand Down Expand Up @@ -220,7 +220,12 @@ promote_rule(::Type{<:Tridiagonal}, ::Type{<:Bidiagonal}) = Tridiagonal
AbstractMatrix{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A)
AbstractMatrix{T}(A::Bidiagonal{T}) where {T} = copy(A)

convert(::Type{T}, m::AbstractMatrix) where {T<:Bidiagonal} = m isa T ? m : T(m)::T
function convert(::Type{T}, A::AbstractMatrix) where T<:Bidiagonal
checksquare(A)
isbanded(A, -1, 1) || throw(InexactError(:convert, T, A))
iszero(diagview(A, 1)) ? T(A, :L) :
iszero(diagview(A, -1)) ? T(A, :U) : throw(InexactError(:convert, T, A))
end

similar(B::Bidiagonal, ::Type{T}) where {T} = Bidiagonal(similar(B.dv, T), similar(B.ev, T), B.uplo)
similar(B::Bidiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(B.dv, T, dims)
Expand Down
21 changes: 17 additions & 4 deletions src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,21 @@ julia> SymTridiagonal(B)
[1 2; 3 4] [1 2; 2 3]
```
"""
function SymTridiagonal(A::AbstractMatrix)
function (::Type{SymTri})(A::AbstractMatrix) where {SymTri <: SymTridiagonal}
checksquare(A)
du = diag(A, 1)
d = diag(A)
dl = diag(A, -1)
if all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d)
SymTridiagonal(d, du)
if _checksymmetric(d, du, dl)
SymTri(d, du)
else
throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal"))
end
end

_checksymmetric(d, du, dl) = all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d)
_checksymmetric(A::AbstractMatrix) = _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1))

SymTridiagonal{T,V}(S::SymTridiagonal{T,V}) where {T,V<:AbstractVector{T}} = S
SymTridiagonal{T,V}(S::SymTridiagonal) where {T,V<:AbstractVector{T}} =
SymTridiagonal(convert(V, S.dv)::V, convert(V, S.ev)::V)
Expand All @@ -128,6 +131,11 @@ SymTridiagonal{T}(S::SymTridiagonal) where {T} =
convert(AbstractVector{T}, S.ev)::AbstractVector{T})
SymTridiagonal(S::SymTridiagonal) = S

function convert(::Type{T}, A::AbstractMatrix) where T<:SymTridiagonal
checksquare(A)
_checksymmetric(A) && isbanded(A, -1, 1) ? T(A) : throw(InexactError(:convert, T, A))
end

AbstractMatrix{T}(S::SymTridiagonal) where {T} = SymTridiagonal{T}(S)
AbstractMatrix{T}(S::SymTridiagonal{T}) where {T} = copy(S)

Expand Down Expand Up @@ -583,7 +591,7 @@ julia> Tridiagonal(A)
⋅ ⋅ 3 4
```
"""
Tridiagonal(A::AbstractMatrix) = Tridiagonal(diag(A,-1), diag(A,0), diag(A,1))
(::Type{Tri})(A::AbstractMatrix) where {Tri<:Tridiagonal} = Tri(diag(A,-1), diag(A,0), diag(A,1))

Tridiagonal(A::Tridiagonal) = A
Tridiagonal{T}(A::Tridiagonal{T}) where {T} = A
Expand All @@ -605,6 +613,11 @@ function Tridiagonal{T,V}(A::Tridiagonal) where {T,V<:AbstractVector{T}}
end
end

function convert(::Type{T}, A::AbstractMatrix) where T<:Tridiagonal
checksquare(A)
isbanded(A, -1, 1) ? T(A) : throw(InexactError(:convert, T, A))
end

size(M::Tridiagonal) = (n = length(M.d); (n, n))
axes(M::Tridiagonal) = (ax = axes(M.d,1); (ax, ax))

Expand Down
17 changes: 17 additions & 0 deletions test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1158,4 +1158,21 @@ end
@test opnorm(B, Inf) == opnorm(Matrix(B), Inf)
end

@testset "convert to Bidiagonal" begin
M = diagm(0 => [1,2,3], 1=>[4,5])
B = convert(Bidiagonal, M)
@test B == Bidiagonal(M, :U)
M = diagm(0 => [1,2,3], -1=>[4,5])
B = convert(Bidiagonal, M)
@test B == Bidiagonal(M, :L)
B = convert(Bidiagonal{Int8}, M)
@test B == M
@test B isa Bidiagonal{Int8, Vector{Int8}}
B = convert(Bidiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M)
@test B == M
@test B isa Bidiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}
M = diagm(-1 => [1,2], 1=>[4,5])
@test_throws InexactError convert(Bidiagonal, M)
end

end # module TestBidiagonal
6 changes: 3 additions & 3 deletions test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Random.seed!(1)
@test Matrix(convert(newtype, A)) == Matrix(A)
end
for newtype in [Diagonal, Bidiagonal]
@test_throws ArgumentError convert(newtype,A)
@test_throws Union{ArgumentError,InexactError} convert(newtype,A)
end
A = SymTridiagonal(a, zeros(n-1))
@test Matrix(convert(Bidiagonal,A)) == Matrix(A)
Expand All @@ -57,7 +57,7 @@ Random.seed!(1)
@test Matrix(convert(newtype, A)) == Matrix(A)
end
for newtype in [Diagonal, Bidiagonal]
@test_throws ArgumentError convert(newtype,A)
@test_throws Union{ArgumentError,InexactError} convert(newtype,A)
end
A = Tridiagonal(zeros(n-1), [1.0:n;], fill(1., n-1)) #not morally Diagonal
@test Matrix(convert(Bidiagonal, A)) == Matrix(A)
Expand All @@ -79,7 +79,7 @@ Random.seed!(1)
end
A = UpperTriangular(triu(rand(n,n)))
for newtype in [Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal]
@test_throws ArgumentError convert(newtype,A)
@test_throws Union{ArgumentError,InexactError} convert(newtype,A)
end


Expand Down
34 changes: 34 additions & 0 deletions test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1099,4 +1099,38 @@ end
@test opnorm(S, Inf) == opnorm(Matrix(S), Inf)
end

@testset "convert to Tridiagonal/SymTridiagonal" begin
@testset "Tridiagonal" begin
for M in [diagm(0 => [1,2,3], 1=>[4,5]),
diagm(0 => [1,2,3], 1=>[4,5], -1=>[6,7]),
diagm(-1 => [1,2], 1=>[4,5])]
B = convert(Tridiagonal, M)
@test B == Tridiagonal(M)
B = convert(Tridiagonal{Int8}, M)
@test B == M
@test B isa Tridiagonal{Int8}
B = convert(Tridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M)
@test B == M
@test B isa Tridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}
end
@test_throws InexactError convert(Tridiagonal, fill(5, 4, 4))
end
@testset "SymTridiagonal" begin
for M in [diagm(0 => [1,2,3], 1=>[4,5], -1=>[4,5]),
diagm(0 => [1,2,3]),
diagm(-1 => [1,2], 1=>[1,2])]
B = convert(SymTridiagonal, M)
@test B == SymTridiagonal(M)
B = convert(SymTridiagonal{Int8}, M)
@test B == M
@test B isa SymTridiagonal{Int8}
B = convert(SymTridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M)
@test B == M
@test B isa SymTridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}
end
@test_throws InexactError convert(SymTridiagonal, fill(5, 4, 4))
@test_throws InexactError convert(SymTridiagonal, diagm(0=>fill(NaN,4)))
end
end

end # module TestTridiagonal