diff --git a/src/symmetric.jl b/src/symmetric.jl index 98f3429d..09819325 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -441,6 +441,10 @@ issymmetric(A::Hermitian{<:Real}) = true issymmetric(A::Hermitian{<:Complex}) = isreal(A) issymmetric(A::Symmetric) = true +# check if the symmetry is known from the type +_issymmetric(::Union{SymSymTri, Hermitian{<:Real}}) = true +_issymmetric(::Any) = false + adjoint(A::Hermitian) = A transpose(A::Symmetric) = A adjoint(A::Symmetric{<:Real}) = A diff --git a/src/tridiag.jl b/src/tridiag.jl index 10c43a2b..226b9417 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -111,16 +111,14 @@ function (::Type{SymTri})(A::AbstractMatrix) where {SymTri <: SymTridiagonal} checksquare(A) du = diag(A, 1) d = diag(A) - dl = diag(A, -1) - if _checksymmetric(d, du, dl) - SymTri(d, du) - else + if !(_issymmetric(A) || _checksymmetric(d, du, diag(A, -1))) throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal")) end + return SymTri(d, du) end _checksymmetric(d, du, dl) = all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d) -_checksymmetric(A::AbstractMatrix) = _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1)) +_checksymmetric(A::AbstractMatrix) = _issymmetric(A) || _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1)) SymTridiagonal{T,V}(S::SymTridiagonal{T,V}) where {T,V<:AbstractVector{T}} = S SymTridiagonal{T,V}(S::SymTridiagonal) where {T,V<:AbstractVector{T}} = diff --git a/test/tridiag.jl b/test/tridiag.jl index f44df3f7..849dfa17 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1172,4 +1172,16 @@ end end end +@testset "SymTridiagonal from Symmetric" begin + S = Symmetric(reshape(1:9, 3, 3)) + @testset "helper functions" begin + @test LinearAlgebra._issymmetric(S) + @test !LinearAlgebra._issymmetric(Array(S)) + end + ST = SymTridiagonal(S) + @test ST == SymTridiagonal(diag(S), diag(S,1)) + S = Symmetric(Tridiagonal(1:3, 1:4, 1:3)) + @test convert(SymTridiagonal, S) == S +end + end # module TestTridiagonal