Skip to content

Commit 9ec21a3

Browse files
committed
diag with a Val band index
1 parent ed53855 commit 9ec21a3

File tree

4 files changed

+17
-0
lines changed

4 files changed

+17
-0
lines changed

src/bidiag.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,9 @@ function triu!(M::Bidiagonal{T}, k::Integer=0) where T
418418
return M
419419
end
420420

421+
diag(M::Bidiagonal, ::Val{0}) = M.dv
422+
diag(M::Bidiagonal, ::Val{1}) = M.uplo == 'U' ? M.ev : zero(M.ev)
423+
diag(M::Bidiagonal, ::Val{-1}) = M.uplo == 'L' ? M.ev : zero(M.ev)
421424
function diag(M::Bidiagonal, n::Integer=0)
422425
# every branch call similar(..., ::Int) to make sure the
423426
# same vector type is returned independent of n

src/dense.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,16 @@ julia> diag(A,1)
307307
"""
308308
diag(A::AbstractMatrix, k::Integer=0) = A[diagind(A, k, IndexStyle(A))]
309309

310+
"""
311+
diag(M, ::Val{k}) where {k}
312+
313+
Return the `k`th diagonal of a matrix as a vector.
314+
For structured matrices such as `Diagonal`, this may return the underlying
315+
band instead of making a copy if `k` lies within the bandwidth of the matrix.
316+
This means that the type of the result may vary depending on the values of `k`.
317+
"""
318+
diag(A::AbstractMatrix, ::Val{N}) where {N} = diag(A, N)
319+
310320
"""
311321
diagview(M, k::Integer=0)
312322

src/diagonal.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,7 @@ adjoint(D::Diagonal) = Diagonal(_vecadjoint(D.diag))
926926
permutedims(D::Diagonal) = D
927927
permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D)
928928

929+
diag(D::Diagonal, ::Val{0}) = D.diag
929930
function diag(D::Diagonal, k::Integer=0)
930931
# every branch call similar(..., ::Int) to make sure the
931932
# same vector type is returned independent of k

src/tridiag.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,9 @@ issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y)
700700

701701
\(A::Adjoint{<:Any,<:Tridiagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = copy(A) \ B
702702

703+
diag(M::Tridiagonal, ::Val{0}) = M.d
704+
diag(M::Tridiagonal, ::Val{1}) = M.du
705+
diag(M::Tridiagonal, ::Val{-1}) = M.dl
703706
function diag(M::Tridiagonal, n::Integer=0)
704707
# every branch call similar(..., ::Int) to make sure the
705708
# same vector type is returned independent of n

0 commit comments

Comments
 (0)