Skip to content

Commit 5ded08e

Browse files
committed
Add convert for banded matrix types
1 parent b464203 commit 5ded08e

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

src/bidiag.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,13 @@ Bidiagonal(A::Bidiagonal) = A
118118
Bidiagonal{T}(A::Bidiagonal{T}) where {T} = A
119119
Bidiagonal{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A.dv, A.ev, A.uplo)
120120

121+
function convert(::Type{T}, A::AbstractMatrix) where T<:Bidiagonal
122+
checksquare(A)
123+
isbanded(A, -1, 1) || throw(InexactError(:convert, T, A))
124+
iszero(diagview(A, 1)) ? T(A, :L) :
125+
iszero(diagview(A, -1)) ? T(A, :U) : throw(InexactError(:convert, T, A))
126+
end
127+
121128
_offdiagind(uplo) = uplo == 'U' ? 1 : -1
122129

123130
@inline function Base.isassigned(A::Bidiagonal, i::Int, j::Int)

src/tridiag.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,18 @@ julia> SymTridiagonal(B)
109109
"""
110110
function SymTridiagonal(A::AbstractMatrix)
111111
checksquare(A)
112-
du = diag(A, 1)
113-
d = diag(A)
114-
dl = diag(A, -1)
115-
if all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d)
112+
if _checksymmetric(A)
113+
du = diag(A, 1)
114+
d = diag(A)
116115
SymTridiagonal(d, du)
117116
else
118117
throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal"))
119118
end
120119
end
121120

121+
_checksymmetric(d, du, dl) = all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d)
122+
_checksymmetric(A::AbstractMatrix) = _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1))
123+
122124
SymTridiagonal{T,V}(S::SymTridiagonal{T,V}) where {T,V<:AbstractVector{T}} = S
123125
SymTridiagonal{T,V}(S::SymTridiagonal) where {T,V<:AbstractVector{T}} =
124126
SymTridiagonal(convert(V, S.dv)::V, convert(V, S.ev)::V)
@@ -128,6 +130,11 @@ SymTridiagonal{T}(S::SymTridiagonal) where {T} =
128130
convert(AbstractVector{T}, S.ev)::AbstractVector{T})
129131
SymTridiagonal(S::SymTridiagonal) = S
130132

133+
function convert(::Type{T}, A::AbstractMatrix) where T<:SymTridiagonal
134+
checksquare(A)
135+
_checksymmetric(A) && isbanded(A, -1, 1) ? T(A) : throw(InexactError(:convert, T, A))
136+
end
137+
131138
AbstractMatrix{T}(S::SymTridiagonal) where {T} = SymTridiagonal{T}(S)
132139
AbstractMatrix{T}(S::SymTridiagonal{T}) where {T} = copy(S)
133140

@@ -605,6 +612,11 @@ function Tridiagonal{T,V}(A::Tridiagonal) where {T,V<:AbstractVector{T}}
605612
end
606613
end
607614

615+
function convert(::Type{T}, A::AbstractMatrix) where T<:Tridiagonal
616+
checksquare(A)
617+
isbanded(A, -1, 1) ? T(A) : throw(InexactError(:convert, T, A))
618+
end
619+
608620
size(M::Tridiagonal) = (n = length(M.d); (n, n))
609621
axes(M::Tridiagonal) = (ax = axes(M.d,1); (ax, ax))
610622

0 commit comments

Comments
 (0)