Skip to content
3 changes: 3 additions & 0 deletions src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ function triu!(M::Bidiagonal{T}, k::Integer=0) where T
return M
end

diag(M::Bidiagonal, ::Val{0}) = M.dv
diag(M::Bidiagonal, ::Val{1}) = M.uplo == 'U' ? M.ev : zero(M.ev)
diag(M::Bidiagonal, ::Val{-1}) = M.uplo == 'L' ? M.ev : zero(M.ev)
function diag(M::Bidiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
Expand Down
16 changes: 16 additions & 0 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,22 @@ julia> diag(A,1)
"""
diag(A::AbstractMatrix, k::Integer=0) = A[diagind(A, k, IndexStyle(A))]

"""
diag(M, ::Val{k}) where {k}

Return the `k`th diagonal of a matrix as a vector.
For banded matrix types such as `Diagonal`, this may return the underlying
band instead of making a copy if `k` lies within the bandwidth of the matrix.

!!! note
The type of the result may vary depending on the values of `k`.
"""
function diag(A::AbstractMatrix, ::Val{k}) where {k}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function diag(A::AbstractMatrix, ::Val{k}) where {k}
function diag(A::AbstractMatrix, ::Val{K}) where {K}
k = K::Int

# some types might have a specialized 1-arg `diag` method,
# and we may use this if possible
k == 0 ? diag(A) : diag(A, k)
end

"""
diagview(M, k::Integer=0)

Expand Down
1 change: 1 addition & 0 deletions src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,7 @@ adjoint(D::Diagonal) = Diagonal(_vecadjoint(D.diag))
permutedims(D::Diagonal) = D
permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D)

diag(D::Diagonal, ::Val{0}) = D.diag
function diag(D::Diagonal, k::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of k
Expand Down
6 changes: 6 additions & 0 deletions src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ _diagiter(M::SymTridiagonal) = (symmetric(x, :U) for x in M.dv)
_eviter_transposed(M::SymTridiagonal{<:Number}) = _evview(M)
_eviter_transposed(M::SymTridiagonal) = (transpose(x) for x in _evview(M))

diag(M::SymTridiagonal{<:Number}, ::Val{0})= M.dv
diag(M::SymTridiagonal{<:Number}, ::Val{1})= _evview(M)
diag(M::SymTridiagonal{<:Number}, ::Val{-1})= _evview(M)
function diag(M::SymTridiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
Expand Down Expand Up @@ -700,6 +703,9 @@ issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y)

\(A::Adjoint{<:Any,<:Tridiagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = copy(A) \ B

diag(M::Tridiagonal, ::Val{0}) = M.d
diag(M::Tridiagonal, ::Val{1}) = M.du
diag(M::Tridiagonal, ::Val{-1}) = M.dl
function diag(M::Tridiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
Expand Down