Skip to content

Commit 5d3ef46

Browse files
authored
Skip symmetry check in converting Symmetric to SymTridiagonal (#1269)
Since the symmetry check in the `SymTridiagonal` constructor is unnecessary if the argument is known to be symmetric from its type, we may skip it (as well as the extra allocation of the lower diagonal that is necessary for the check).
1 parent ae5385b commit 5d3ef46

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

src/symmetric.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,10 @@ issymmetric(A::Hermitian{<:Real}) = true
441441
issymmetric(A::Hermitian{<:Complex}) = isreal(A)
442442
issymmetric(A::Symmetric) = true
443443

444+
# check if the symmetry is known from the type
445+
_issymmetric(::Union{SymSymTri, Hermitian{<:Real}}) = true
446+
_issymmetric(::Any) = false
447+
444448
adjoint(A::Hermitian) = A
445449
transpose(A::Symmetric) = A
446450
adjoint(A::Symmetric{<:Real}) = A

src/tridiag.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,14 @@ function (::Type{SymTri})(A::AbstractMatrix) where {SymTri <: SymTridiagonal}
111111
checksquare(A)
112112
du = diag(A, 1)
113113
d = diag(A)
114-
dl = diag(A, -1)
115-
if _checksymmetric(d, du, dl)
116-
SymTri(d, du)
117-
else
114+
if !(_issymmetric(A) || _checksymmetric(d, du, diag(A, -1)))
118115
throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal"))
119116
end
117+
return SymTri(d, du)
120118
end
121119

122120
_checksymmetric(d, du, dl) = all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d)
123-
_checksymmetric(A::AbstractMatrix) = _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1))
121+
_checksymmetric(A::AbstractMatrix) = _issymmetric(A) || _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1))
124122

125123
SymTridiagonal{T,V}(S::SymTridiagonal{T,V}) where {T,V<:AbstractVector{T}} = S
126124
SymTridiagonal{T,V}(S::SymTridiagonal) where {T,V<:AbstractVector{T}} =

test/tridiag.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,4 +1172,16 @@ end
11721172
end
11731173
end
11741174

1175+
@testset "SymTridiagonal from Symmetric" begin
1176+
S = Symmetric(reshape(1:9, 3, 3))
1177+
@testset "helper functions" begin
1178+
@test LinearAlgebra._issymmetric(S)
1179+
@test !LinearAlgebra._issymmetric(Array(S))
1180+
end
1181+
ST = SymTridiagonal(S)
1182+
@test ST == SymTridiagonal(diag(S), diag(S,1))
1183+
S = Symmetric(Tridiagonal(1:3, 1:4, 1:3))
1184+
@test convert(SymTridiagonal, S) == S
1185+
end
1186+
11751187
end # module TestTridiagonal

0 commit comments

Comments
 (0)