diff --git a/Project.toml b/Project.toml index d37916b..1d55fd5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.18" +version = "0.3.19" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index 7e0c7cd..f0c85ac 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -68,32 +68,57 @@ end # This helps to support diagonals where the elements are known # from the types, for example diagonals that are `Zeros` and `Ones`. -function DiagonalArray{T,N,D}( - init::ShapeInitializer, unstored::Unstored -) where {T,N,D<:AbstractVector{T}} - return DiagonalArray{T,N,D}( +function DiagonalArray{T,N,D,U}( + init::ShapeInitializer, unstored::Unstored{T,N,U} +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray{T,N,D,U}( construct(D, init, diaglength_from_shape(axes(unstored))), unstored ) end +function DiagonalArray{T,N,D}( + init::ShapeInitializer, unstored::Unstored{T,N,U} +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray{T,N,D,U}(init, unstored) +end # This helps to support diagonals where the elements are known # from the types, for example diagonals that are `Zeros` and `Ones`. # These versions use the default unstored type `Zeros{T,N}`. +function DiagonalArray{T,N,D,U}( + init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray{T,N,D,U}(init, Unstored(U(ax))) +end function DiagonalArray{T,N,D}( init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} ) where {T,N,D<:AbstractVector{T}} return DiagonalArray{T,N,D}(init, Unstored(Zeros{T,N}(ax))) end +function DiagonalArray{T,N,D,U}( + init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray{T,N,D,U}(init, ax) +end function DiagonalArray{T,N,D}( init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... ) where {T,N,D<:AbstractVector{T}} return DiagonalArray{T,N,D}(init, ax) end +function DiagonalArray{T,N,D,U}( + init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}} +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray{T,N,D,U}(init, Base.OneTo.(sz)) +end function DiagonalArray{T,N,D}( init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}} ) where {T,N,D<:AbstractVector{T}} return DiagonalArray{T,N,D}(init, Base.OneTo.(sz)) end +function DiagonalArray{T,N,D,U}( + init::ShapeInitializer, sz1::Integer, sz_rest::Integer... +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray{T,N,D,U}(init, (sz1, sz_rest...)) +end function DiagonalArray{T,N,D}( init::ShapeInitializer, sz1::Integer, sz_rest::Integer... ) where {T,N,D<:AbstractVector{T}} diff --git a/test/test_basics.jl b/test/test_basics.jl index de1dce7..c40a913 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -112,12 +112,18 @@ using Test: @test, @test_throws, @testset, @test_broken, @inferred # Special constructors for immutable diagonal. init = ShapeInitializer() + U = Zeros{UInt32,2,Tuple{Base.OneTo{Int},Base.OneTo{Int}}} @test DiagonalMatrix(Base.OneTo(UInt32(2))) ≡ DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))) ≡ DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, (2, 2)) ≡ DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, 2, 2) ≡ - DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2))) + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2))) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, Base.OneTo.((2, 2))) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, Base.OneTo.((2, 2))...) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, (2, 2)) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, 2, 2) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, Unstored(Zeros{UInt32}(2, 2))) init = ShapeInitializer() @test DiagonalMatrix(Ones{elt}(2)) ≡