diff --git a/src/adjtrans.jl b/src/adjtrans.jl index 96db07d5..29f8d064 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -577,6 +577,8 @@ _vecadjoint(A::Base.ReshapedArray{<:Any,1,<:AdjointAbsVec}) = adjoint(parent(A)) diagview(A::Transpose, k::Integer = 0) = _vectranspose(diagview(parent(A), -k)) diagview(A::Adjoint, k::Integer = 0) = _vecadjoint(diagview(parent(A), -k)) +diag(A::Transpose, ::Val{k}) where {k} = _vectranspose(diag(parent(A), Val(-k))) +diag(A::Adjoint, ::Val{k}) where {k} = _vecadjoint(diag(parent(A), Val(-k))) # triu and tril triu!(A::AdjOrTransAbsMat, k::Integer = 0) = wrapperop(A)(tril!(parent(A), -k)) diff --git a/src/bidiag.jl b/src/bidiag.jl index cc5e6de7..0e36e7da 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -418,6 +418,9 @@ function triu!(M::Bidiagonal{T}, k::Integer=0) where T return M end +diag(M::Bidiagonal, ::Val{0}) = M.dv +diag(M::Bidiagonal, ::Val{1}) = M.uplo == 'U' ? M.ev : zero(M.ev) +diag(M::Bidiagonal, ::Val{-1}) = M.uplo == 'L' ? M.ev : zero(M.ev) function diag(M::Bidiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n diff --git a/src/dense.jl b/src/dense.jl index f6f1accb..3f6ef78e 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -307,6 +307,22 @@ julia> diag(A,1) """ diag(A::AbstractMatrix, k::Integer=0) = A[diagind(A, k, IndexStyle(A))] +""" + diag(M, ::Val{k}) where {k} + +Return the `k`th diagonal of a matrix as a vector. +For banded matrix types such as `Diagonal`, this may return the underlying +band instead of making a copy if `k` lies within the bandwidth of the matrix. + +!!! note + The type of the result may vary depending on the values of `k`. +""" +function diag(A::AbstractMatrix, ::Val{k}) where {k} + # some types might have a specialized 1-arg `diag` method, + # and we may use this if possible + k == 0 ? diag(A) : diag(A, k) +end + """ diagview(M, k::Integer=0) diff --git a/src/diagonal.jl b/src/diagonal.jl index e60fb009..ef2e7005 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -926,6 +926,7 @@ adjoint(D::Diagonal) = Diagonal(_vecadjoint(D.diag)) permutedims(D::Diagonal) = D permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D) +diag(D::Diagonal, ::Val{0}) = D.diag function diag(D::Diagonal, k::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of k diff --git a/src/hessenberg.jl b/src/hessenberg.jl index 7aab1c77..b65f7be8 100644 --- a/src/hessenberg.jl +++ b/src/hessenberg.jl @@ -667,3 +667,6 @@ function logdet(F::Hessenberg) d,s = logabsdet(F) return d + log(s) end + +diag(A::UpperHessenberg) = diag(A.data) +diag(A::UpperHessenberg, ::Val{k}) where {k} = k >= -1 ? diag(A.data, Val(k)) : diag(A, k) diff --git a/src/triangular.jl b/src/triangular.jl index d82ddd87..966e68e4 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -550,6 +550,12 @@ adjoint!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U' , tr diag(A::UpperOrLowerTriangular) = diag(A.data) diag(A::Union{UnitLowerTriangular, UnitUpperTriangular}) = fill(oneunit(eltype(A)), size(A,1)) +diag(A::UpperTriangular, ::Val{k}) where {k} = k >= 0 ? diag(A.data, Val(k)) : diag(A, k) +diag(A::LowerTriangular, ::Val{k}) where {k} = k <= 0 ? diag(A.data, Val(k)) : diag(A, k) +diag(A::UnitUpperTriangular, ::Val{0}) = diag(A) +diag(A::UnitLowerTriangular, ::Val{0}) = diag(A) +diag(A::UnitUpperTriangular, ::Val{k}) where {k} = k > 0 ? diag(A.data, Val(k)) : diag(A, k) +diag(A::UnitLowerTriangular, ::Val{k}) where {k} = k < 0 ? diag(A.data, Val(k)) : diag(A, k) # Unary operations -(A::LowerTriangular) = LowerTriangular(-A.data) @@ -2994,8 +3000,8 @@ logdet(A::UnitUpperTriangular{T}) where {T} = zero(T) logdet(A::UnitLowerTriangular{T}) where {T} = zero(T) logabsdet(A::UnitUpperTriangular{T}) where {T} = zero(T), one(T) logabsdet(A::UnitLowerTriangular{T}) where {T} = zero(T), one(T) -det(A::UpperTriangular) = prod(diag(A.data)) -det(A::LowerTriangular) = prod(diag(A.data)) +det(A::UpperTriangular) = prod(diag(A.data, Val(0))) +det(A::LowerTriangular) = prod(diag(A.data, Val(0))) function logabsdet(A::Union{UpperTriangular{T},LowerTriangular{T}}) where T sgn = one(T) abs_det = zero(real(T)) diff --git a/src/tridiag.jl b/src/tridiag.jl index 519be750..2426c096 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -111,7 +111,7 @@ function (::Type{SymTri})(A::AbstractMatrix) where {SymTri <: SymTridiagonal} checksquare(A) du = diag(A, 1) d = diag(A) - if !(_issymmetric(A) || _checksymmetric(d, du, diag(A, -1))) + if !(_issymmetric(A) || _checksymmetric(d, du, diag(A, Val(-1)))) throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal")) end return SymTri(d, du) @@ -195,6 +195,9 @@ _diagiter(M::SymTridiagonal) = (symmetric(x, :U) for x in M.dv) _eviter_transposed(M::SymTridiagonal{<:Number}) = _evview(M) _eviter_transposed(M::SymTridiagonal) = (transpose(x) for x in _evview(M)) +diag(M::SymTridiagonal{<:Number}, ::Val{0}) = M.dv +diag(M::SymTridiagonal{<:Number}, ::Val{1}) = _evview(M) +diag(M::SymTridiagonal{<:Number}, ::Val{-1}) = _evview(M) function diag(M::SymTridiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n @@ -700,6 +703,9 @@ issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y) \(A::Adjoint{<:Any,<:Tridiagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = copy(A) \ B +diag(M::Tridiagonal, ::Val{0}) = M.d +diag(M::Tridiagonal, ::Val{1}) = M.du +diag(M::Tridiagonal, ::Val{-1}) = M.dl function diag(M::Tridiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n diff --git a/test/adjtrans.jl b/test/adjtrans.jl index 075f0acf..89d7cbd9 100644 --- a/test/adjtrans.jl +++ b/test/adjtrans.jl @@ -799,6 +799,25 @@ end end end +@testset "diag with a Val index" begin + @testset "$(typeof(A))" for A in Any[rand(4, 4), rand(ComplexF64,4,4), fill([1 2; 3 4], 4, 4), + Diagonal(1:4), Bidiagonal(1:4, 1:3, :U), + Tridiagonal(1:3, 1:4, 1:3), SymTridiagonal(1:4, 1:3)] + @testset for (wrap_fn, wrap_T) in ((transpose,Transpose), (adjoint,Adjoint)) + At = wrap_fn(A) + @test diag(At, 1) == diag(At, Val(1)) + @test diag(At, 0) == diag(At, Val(0)) + @test diag(At, -1) == diag(At, Val(-1)) + if !(At isa wrap_T) + AT = wrap_T(A) + @test diag(At, Val(1)) == diag(AT, Val(1)) + @test diag(At, Val(0)) == diag(AT, Val(0)) + @test diag(At, Val(-1)) == diag(AT, Val(-1)) + end + end + end +end + @testset "triu!/tril!" begin @testset for sz in ((4,4), (3,4), (4,3)) A = rand(sz...) diff --git a/test/bidiag.jl b/test/bidiag.jl index 6c20a4b8..d9051511 100644 --- a/test/bidiag.jl +++ b/test/bidiag.jl @@ -1298,4 +1298,13 @@ end end end +@testset "diag with a Val index" begin + B = Bidiagonal(1:4, 1:3, :U) + @test diag(B, Val(0)) === 1:4 + @test diag(B, Val(1)) === 1:3 + B = Bidiagonal(1:4, 1:3, :L) + @test diag(B, Val(0)) === 1:4 + @test diag(B, Val(-1)) === 1:3 +end + end # module TestBidiagonal diff --git a/test/diagonal.jl b/test/diagonal.jl index 712f426c..ce331483 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -1576,4 +1576,9 @@ end @test D == D2 end +@testset "diag with a Val index" begin + D = Diagonal(1:4) + @test diag(D, Val(0)) === 1:4 +end + end # module TestDiagonal diff --git a/test/hessenberg.jl b/test/hessenberg.jl index 5b61c798..81e6655a 100644 --- a/test/hessenberg.jl +++ b/test/hessenberg.jl @@ -320,4 +320,14 @@ end @test U == U2 end +@testset "diag with a Val index" begin + H = UpperHessenberg(Tridiagonal(1:3, 1:4, 1:3)) + @test diag(H, Val(0)) === 1:4 + @test diag(H, Val(1)) === 1:3 + @test diag(H, Val(-1)) === 1:3 + @test diag(H, Val(0)) == diag(H) == diag(H, 0) + @test diag(H, Val(2)) == diag(H, 2) + @test diag(H, Val(-2)) == diag(H, -2) +end + end # module TestHessenberg diff --git a/test/triangular.jl b/test/triangular.jl index e823f698..61a39950 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -1110,4 +1110,24 @@ end end end +@testset "diag with a Val index" begin + U = UpperTriangular(Tridiagonal(2:4, 1:4, 1:3)) + @test diag(U, Val(0)) === 1:4 + @test diag(U, Val(1)) === 1:3 + @test diag(U, Val(-1)) == diag(U, -1) == zeros(3) + L = LowerTriangular(Tridiagonal(2:4, 1:4, 1:3)) + @test diag(L, Val(0)) === 1:4 + @test diag(L, Val(-1)) === 2:4 + @test diag(L, Val(1)) == diag(L, 1) == zeros(3) + + U = UnitUpperTriangular(Tridiagonal(2:4, 1:4, 1:3)) + @test diag(U, Val(1)) === 1:3 + @test diag(U, Val(0)) == diag(U, 0) == diag(U) == ones(4) + @test diag(U, Val(-1)) == diag(U, -1) == zeros(3) + L = UnitLowerTriangular(Tridiagonal(2:4, 1:4, 1:3)) + @test diag(L, Val(-1)) === 2:4 + @test diag(L, Val(0)) == diag(L, 0) == diag(L) == ones(4) + @test diag(L, Val(1)) == diag(L, 1) == zeros(3) +end + end # module TestTriangular diff --git a/test/tridiag.jl b/test/tridiag.jl index e955a37e..888557ac 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1255,4 +1255,15 @@ end end end +@testset "diag with a Val index" begin + T = Tridiagonal(2:4, 1:4, 1:3) + @test diag(T, Val(0)) === 1:4 + @test diag(T, Val(1)) === 1:3 + @test diag(T, Val(-1)) === 2:4 + ST = SymTridiagonal(1:4, 1:3) + @test diag(ST, Val(0)) === 1:4 + @test diag(ST, Val(1)) === 1:3 + @test diag(ST, Val(-1)) === 1:3 +end + end # module TestTridiagonal