diff --git a/Project.toml b/Project.toml index 8ffc6db..50a5244 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.5" +version = "0.3.6" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/diaginterface/diaginterface.jl b/src/diaginterface/diaginterface.jl index 7baad7b..ab572c9 100644 --- a/src/diaginterface/diaginterface.jl +++ b/src/diaginterface/diaginterface.jl @@ -84,6 +84,9 @@ function diagview(a::AbstractArray) return @view a[diagindices(a)] end +using LinearAlgebra: Diagonal +diagview(a::Diagonal) = a.diag + function getdiagindex(a::AbstractArray, i::Integer) return diagview(a)[i] end @@ -110,7 +113,36 @@ end diagonal(v::AbstractVector) -> AbstractMatrix Return a diagonal matrix from a vector `v`. -This is an extension of `LinearAlgebra.Diagonal`, designed to avoid the implication of the output type. +This is an extension of `LinearAlgebra.Diagonal`, designed to avoid +the implication of the output type. Defaults to `Diagonal(v)`. """ diagonal(v::AbstractVector) = LinearAlgebra.Diagonal(v) + +""" + diagonal(m::AbstractMatrix) -> AbstractMatrix + +Return a diagonal matrix from a matrix `m` where the diagonal +values are copied from the diagonal of `m`. +This is an extension of `LinearAlgebra.Diagonal`, designed to avoid +the implication of the output type. +Defaults to `diagonal(copy(diagview(m)))`, which in general is +equivalent to `Diagonal(m)`. +""" +diagonal(m::AbstractMatrix) = diagonal(copy(diagview(m))) + +""" + diagonaltype(::AbstractVector) -> Type{<:AbstractMatrix} + diagonaltype(::Type{<:AbstractVector}) -> Type{<:AbstractMatrix} + diagonaltype(::AbstractMatrix) -> Type{<:AbstractMatrix} + diagonaltype(::Type{<:AbstractMatrix}) -> Type{<:AbstractMatrix} + +Return the type of diagonal matrix that would be created from a vector or matrix +using the [`diagonal`](@ref) function. +""" +diagonaltype + +diagonaltype(v::AbstractVector) = diagonaltype(typeof(v)) +diagonaltype(V::Type{<:AbstractVector}) = Base.promote_op(diagonal, V) +diagonaltype(m::AbstractMatrix) = diagonaltype(typeof(m)) +diagonaltype(M::Type{<:AbstractMatrix}) = Base.promote_op(diagonal, M) diff --git a/test/test_basics.jl b/test/test_basics.jl index 922958b..3660f92 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -8,6 +8,7 @@ using DiagonalArrays: diagindices, diaglength, diagonal, + diagonaltype, diagview using FillArrays: Fill, Ones using SparseArraysBase: SparseArrayDOK, sparsezeros, storedlength @@ -104,8 +105,17 @@ using LinearAlgebra: Diagonal @test a_dest isa SparseArrayDOK{elt,2} end @testset "diagonal" begin - @test @inferred(diagonal(rand(2))) isa AbstractMatrix - @test diagonal(zeros(Int, 2)) isa Diagonal + v = randn(2) + d = @inferred diagonal(v) + @test d isa Diagonal{eltype(v)} + @test diagview(d) === v + @test diagonaltype(v) === typeof(d) + + a = randn(2, 2) + d = @inferred diagonal(a) + @test d isa Diagonal{eltype(v)} + @test diagview(d) == diagview(a) + @test diagonaltype(a) === typeof(d) end @testset "delta" begin for (a, elt′) in (