Skip to content
Open
2 changes: 2 additions & 0 deletions src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,8 @@ _vecadjoint(A::Base.ReshapedArray{<:Any,1,<:AdjointAbsVec}) = adjoint(parent(A))

diagview(A::Transpose, k::Integer = 0) = _vectranspose(diagview(parent(A), -k))
diagview(A::Adjoint, k::Integer = 0) = _vecadjoint(diagview(parent(A), -k))
diag(A::Transpose, ::Val{k}) where {k} = _vectranspose(diag(parent(A), Val(-k)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
diag(A::Transpose, ::Val{k}) where {k} = _vectranspose(diag(parent(A), Val(-k)))
diag(A::Transpose, ::Val{k}) where {k} = _vectranspose(diag(parent(A), Val(-k::Int)))

diag(A::Adjoint, ::Val{k}) where {k} = _vecadjoint(diag(parent(A), Val(-k)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
diag(A::Adjoint, ::Val{k}) where {k} = _vecadjoint(diag(parent(A), Val(-k)))
diag(A::Adjoint, ::Val{k}) where {k} = _vecadjoint(diag(parent(A), Val(-k::Int)))


# triu and tril
triu!(A::AdjOrTransAbsMat, k::Integer = 0) = wrapperop(A)(tril!(parent(A), -k))
Expand Down
3 changes: 3 additions & 0 deletions src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ function triu!(M::Bidiagonal{T}, k::Integer=0) where T
return M
end

diag(M::Bidiagonal, ::Val{0}) = M.dv
diag(M::Bidiagonal, ::Val{1}) = M.uplo == 'U' ? M.ev : zero(M.ev)
diag(M::Bidiagonal, ::Val{-1}) = M.uplo == 'L' ? M.ev : zero(M.ev)
function diag(M::Bidiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
Expand Down
16 changes: 16 additions & 0 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,22 @@ julia> diag(A,1)
"""
diag(A::AbstractMatrix, k::Integer=0) = A[diagind(A, k, IndexStyle(A))]

"""
diag(M, ::Val{k}) where {k}

Return the `k`th diagonal of a matrix as a vector.
For banded matrix types such as `Diagonal`, this may return the underlying
band instead of making a copy if `k` lies within the bandwidth of the matrix.

!!! note
The type of the result may vary depending on the values of `k`.
"""
function diag(A::AbstractMatrix, ::Val{k}) where {k}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function diag(A::AbstractMatrix, ::Val{k}) where {k}
function diag(A::AbstractMatrix, ::Val{K}) where {K}
k = K::Int

# some types might have a specialized 1-arg `diag` method,
# and we may use this if possible
k == 0 ? diag(A) : diag(A, k)
end

"""
diagview(M, k::Integer=0)

Expand Down
1 change: 1 addition & 0 deletions src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,7 @@ adjoint(D::Diagonal) = Diagonal(_vecadjoint(D.diag))
permutedims(D::Diagonal) = D
permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D)

diag(D::Diagonal, ::Val{0}) = D.diag
function diag(D::Diagonal, k::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of k
Expand Down
3 changes: 3 additions & 0 deletions src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -667,3 +667,6 @@ function logdet(F::Hessenberg)
d,s = logabsdet(F)
return d + log(s)
end

diag(A::UpperHessenberg) = diag(A.data)
diag(A::UpperHessenberg, ::Val{k}) where {k} = k >= -1 ? diag(A.data, Val(k)) : diag(A, k)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
diag(A::UpperHessenberg, ::Val{k}) where {k} = k >= -1 ? diag(A.data, Val(k)) : diag(A, k)
function diag(A::UpperHessenberg, ::Val{K}) where {K}
k = K::Int
k >= -1 ? diag(A.data, Val(k)) : diag(A, k)
end

10 changes: 8 additions & 2 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,12 @@ adjoint!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U' , tr

diag(A::UpperOrLowerTriangular) = diag(A.data)
diag(A::Union{UnitLowerTriangular, UnitUpperTriangular}) = fill(oneunit(eltype(A)), size(A,1))
diag(A::UpperTriangular, ::Val{k}) where {k} = k >= 0 ? diag(A.data, Val(k)) : diag(A, k)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
diag(A::UpperTriangular, ::Val{k}) where {k} = k >= 0 ? diag(A.data, Val(k)) : diag(A, k)
function diag(A::UpperTriangular, ::Val{K}) where {K}
k = K::Int
k >= 0 ? diag(A.data, Val(k)) : diag(A, k)
end

diag(A::LowerTriangular, ::Val{k}) where {k} = k <= 0 ? diag(A.data, Val(k)) : diag(A, k)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
diag(A::LowerTriangular, ::Val{k}) where {k} = k <= 0 ? diag(A.data, Val(k)) : diag(A, k)
function diag(A::LowerTriangular, ::Val{K}) where {K}
k = K::Int
k <= 0 ? diag(A.data, Val(k)) : diag(A, k)
end

diag(A::UnitUpperTriangular, ::Val{0}) = diag(A)
diag(A::UnitLowerTriangular, ::Val{0}) = diag(A)
diag(A::UnitUpperTriangular, ::Val{k}) where {k} = k > 0 ? diag(A.data, Val(k)) : diag(A, k)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
diag(A::UnitUpperTriangular, ::Val{k}) where {k} = k > 0 ? diag(A.data, Val(k)) : diag(A, k)
function diag(A::UnitUpperTriangular, ::Val{K}) where {K}
k = K::Int
k > 0 ? diag(A.data, Val(k)) : diag(A, k)
end

diag(A::UnitLowerTriangular, ::Val{k}) where {k} = k < 0 ? diag(A.data, Val(k)) : diag(A, k)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
diag(A::UnitLowerTriangular, ::Val{k}) where {k} = k < 0 ? diag(A.data, Val(k)) : diag(A, k)
function diag(A::UnitLowerTriangular, ::Val{K}) where {K}
k = K::Int
k < 0 ? diag(A.data, Val(k)) : diag(A, k)
end


# Unary operations
-(A::LowerTriangular) = LowerTriangular(-A.data)
Expand Down Expand Up @@ -2994,8 +3000,8 @@ logdet(A::UnitUpperTriangular{T}) where {T} = zero(T)
logdet(A::UnitLowerTriangular{T}) where {T} = zero(T)
logabsdet(A::UnitUpperTriangular{T}) where {T} = zero(T), one(T)
logabsdet(A::UnitLowerTriangular{T}) where {T} = zero(T), one(T)
det(A::UpperTriangular) = prod(diag(A.data))
det(A::LowerTriangular) = prod(diag(A.data))
det(A::UpperTriangular) = prod(diag(A.data, Val(0)))
det(A::LowerTriangular) = prod(diag(A.data, Val(0)))
function logabsdet(A::Union{UpperTriangular{T},LowerTriangular{T}}) where T
sgn = one(T)
abs_det = zero(real(T))
Expand Down
8 changes: 7 additions & 1 deletion src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ function (::Type{SymTri})(A::AbstractMatrix) where {SymTri <: SymTridiagonal}
checksquare(A)
du = diag(A, 1)
d = diag(A)
if !(_issymmetric(A) || _checksymmetric(d, du, diag(A, -1)))
if !(_issymmetric(A) || _checksymmetric(d, du, diag(A, Val(-1))))
throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal"))
end
return SymTri(d, du)
Expand Down Expand Up @@ -195,6 +195,9 @@ _diagiter(M::SymTridiagonal) = (symmetric(x, :U) for x in M.dv)
_eviter_transposed(M::SymTridiagonal{<:Number}) = _evview(M)
_eviter_transposed(M::SymTridiagonal) = (transpose(x) for x in _evview(M))

diag(M::SymTridiagonal{<:Number}, ::Val{0}) = M.dv
diag(M::SymTridiagonal{<:Number}, ::Val{1}) = _evview(M)
diag(M::SymTridiagonal{<:Number}, ::Val{-1}) = _evview(M)
function diag(M::SymTridiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
Expand Down Expand Up @@ -700,6 +703,9 @@ issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y)

\(A::Adjoint{<:Any,<:Tridiagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = copy(A) \ B

diag(M::Tridiagonal, ::Val{0}) = M.d
diag(M::Tridiagonal, ::Val{1}) = M.du
diag(M::Tridiagonal, ::Val{-1}) = M.dl
function diag(M::Tridiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
Expand Down
19 changes: 19 additions & 0 deletions test/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,25 @@ end
end
end

@testset "diag with a Val index" begin
@testset "$(typeof(A))" for A in Any[rand(4, 4), rand(ComplexF64,4,4), fill([1 2; 3 4], 4, 4),
Diagonal(1:4), Bidiagonal(1:4, 1:3, :U),
Tridiagonal(1:3, 1:4, 1:3), SymTridiagonal(1:4, 1:3)]
@testset for (wrap_fn, wrap_T) in ((transpose,Transpose), (adjoint,Adjoint))
At = wrap_fn(A)
@test diag(At, 1) == diag(At, Val(1))
@test diag(At, 0) == diag(At, Val(0))
@test diag(At, -1) == diag(At, Val(-1))
if !(At isa wrap_T)
AT = wrap_T(A)
@test diag(At, Val(1)) == diag(AT, Val(1))
@test diag(At, Val(0)) == diag(AT, Val(0))
@test diag(At, Val(-1)) == diag(AT, Val(-1))
end
end
end
end

@testset "triu!/tril!" begin
@testset for sz in ((4,4), (3,4), (4,3))
A = rand(sz...)
Expand Down
9 changes: 9 additions & 0 deletions test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1298,4 +1298,13 @@ end
end
end

@testset "diag with a Val index" begin
B = Bidiagonal(1:4, 1:3, :U)
@test diag(B, Val(0)) === 1:4
@test diag(B, Val(1)) === 1:3
B = Bidiagonal(1:4, 1:3, :L)
@test diag(B, Val(0)) === 1:4
@test diag(B, Val(-1)) === 1:3
end

end # module TestBidiagonal
5 changes: 5 additions & 0 deletions test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1576,4 +1576,9 @@ end
@test D == D2
end

@testset "diag with a Val index" begin
D = Diagonal(1:4)
@test diag(D, Val(0)) === 1:4
end

end # module TestDiagonal
10 changes: 10 additions & 0 deletions test/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,4 +320,14 @@ end
@test U == U2
end

@testset "diag with a Val index" begin
H = UpperHessenberg(Tridiagonal(1:3, 1:4, 1:3))
@test diag(H, Val(0)) === 1:4
@test diag(H, Val(1)) === 1:3
@test diag(H, Val(-1)) === 1:3
@test diag(H, Val(0)) == diag(H) == diag(H, 0)
@test diag(H, Val(2)) == diag(H, 2)
@test diag(H, Val(-2)) == diag(H, -2)
end

end # module TestHessenberg
20 changes: 20 additions & 0 deletions test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1110,4 +1110,24 @@ end
end
end

@testset "diag with a Val index" begin
U = UpperTriangular(Tridiagonal(2:4, 1:4, 1:3))
@test diag(U, Val(0)) === 1:4
@test diag(U, Val(1)) === 1:3
@test diag(U, Val(-1)) == diag(U, -1) == zeros(3)
L = LowerTriangular(Tridiagonal(2:4, 1:4, 1:3))
@test diag(L, Val(0)) === 1:4
@test diag(L, Val(-1)) === 2:4
@test diag(L, Val(1)) == diag(L, 1) == zeros(3)

U = UnitUpperTriangular(Tridiagonal(2:4, 1:4, 1:3))
@test diag(U, Val(1)) === 1:3
@test diag(U, Val(0)) == diag(U, 0) == diag(U) == ones(4)
@test diag(U, Val(-1)) == diag(U, -1) == zeros(3)
L = UnitLowerTriangular(Tridiagonal(2:4, 1:4, 1:3))
@test diag(L, Val(-1)) === 2:4
@test diag(L, Val(0)) == diag(L, 0) == diag(L) == ones(4)
@test diag(L, Val(1)) == diag(L, 1) == zeros(3)
end

end # module TestTriangular
11 changes: 11 additions & 0 deletions test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1255,4 +1255,15 @@ end
end
end

@testset "diag with a Val index" begin
T = Tridiagonal(2:4, 1:4, 1:3)
@test diag(T, Val(0)) === 1:4
@test diag(T, Val(1)) === 1:3
@test diag(T, Val(-1)) === 2:4
ST = SymTridiagonal(1:4, 1:3)
@test diag(ST, Val(0)) === 1:4
@test diag(ST, Val(1)) === 1:3
@test diag(ST, Val(-1)) === 1:3
end

end # module TestTridiagonal