Skip to content

Commit 3e2b7f9

Browse files
authored
Add some missing constructors (#44)
1 parent 2c04cbb commit 3e2b7f9

File tree

3 files changed

+37
-6
lines changed

3 files changed

+37
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.18"
4+
version = "0.3.19"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/diagonalarray/diagonalarray.jl

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,32 +68,57 @@ end
6868

6969
# This helps to support diagonals where the elements are known
7070
# from the types, for example diagonals that are `Zeros` and `Ones`.
71-
function DiagonalArray{T,N,D}(
72-
init::ShapeInitializer, unstored::Unstored
73-
) where {T,N,D<:AbstractVector{T}}
74-
return DiagonalArray{T,N,D}(
71+
function DiagonalArray{T,N,D,U}(
72+
init::ShapeInitializer, unstored::Unstored{T,N,U}
73+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
74+
return DiagonalArray{T,N,D,U}(
7575
construct(D, init, diaglength_from_shape(axes(unstored))), unstored
7676
)
7777
end
78+
function DiagonalArray{T,N,D}(
79+
init::ShapeInitializer, unstored::Unstored{T,N,U}
80+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
81+
return DiagonalArray{T,N,D,U}(init, unstored)
82+
end
7883

7984
# This helps to support diagonals where the elements are known
8085
# from the types, for example diagonals that are `Zeros` and `Ones`.
8186
# These versions use the default unstored type `Zeros{T,N}`.
87+
function DiagonalArray{T,N,D,U}(
88+
init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
89+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
90+
return DiagonalArray{T,N,D,U}(init, Unstored(U(ax)))
91+
end
8292
function DiagonalArray{T,N,D}(
8393
init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
8494
) where {T,N,D<:AbstractVector{T}}
8595
return DiagonalArray{T,N,D}(init, Unstored(Zeros{T,N}(ax)))
8696
end
97+
function DiagonalArray{T,N,D,U}(
98+
init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}...
99+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
100+
return DiagonalArray{T,N,D,U}(init, ax)
101+
end
87102
function DiagonalArray{T,N,D}(
88103
init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}...
89104
) where {T,N,D<:AbstractVector{T}}
90105
return DiagonalArray{T,N,D}(init, ax)
91106
end
107+
function DiagonalArray{T,N,D,U}(
108+
init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}}
109+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
110+
return DiagonalArray{T,N,D,U}(init, Base.OneTo.(sz))
111+
end
92112
function DiagonalArray{T,N,D}(
93113
init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}}
94114
) where {T,N,D<:AbstractVector{T}}
95115
return DiagonalArray{T,N,D}(init, Base.OneTo.(sz))
96116
end
117+
function DiagonalArray{T,N,D,U}(
118+
init::ShapeInitializer, sz1::Integer, sz_rest::Integer...
119+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
120+
return DiagonalArray{T,N,D,U}(init, (sz1, sz_rest...))
121+
end
97122
function DiagonalArray{T,N,D}(
98123
init::ShapeInitializer, sz1::Integer, sz_rest::Integer...
99124
) where {T,N,D<:AbstractVector{T}}

test/test_basics.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,18 @@ using Test: @test, @test_throws, @testset, @test_broken, @inferred
112112

113113
# Special constructors for immutable diagonal.
114114
init = ShapeInitializer()
115+
U = Zeros{UInt32,2,Tuple{Base.OneTo{Int},Base.OneTo{Int}}}
115116
@test DiagonalMatrix(Base.OneTo(UInt32(2)))
116117
DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2)))
117118
DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...)
118119
DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, (2, 2))
119120
DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, 2, 2)
120-
DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2)))
121+
DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2)))
122+
DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, Base.OneTo.((2, 2)))
123+
DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, Base.OneTo.((2, 2))...)
124+
DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, (2, 2))
125+
DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, 2, 2)
126+
DiagonalMatrix{UInt32,Base.OneTo{UInt32},U}(init, Unstored(Zeros{UInt32}(2, 2)))
121127

122128
init = ShapeInitializer()
123129
@test DiagonalMatrix(Ones{elt}(2))

0 commit comments

Comments
 (0)