Skip to content

Commit cb23366

Browse files
authored
Construct DiagonalArray and delta from axes (#32)
1 parent 1de5527 commit cb23366

File tree

4 files changed

+149
-7
lines changed

4 files changed

+149
-7
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.9"
4+
version = "0.3.10"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -15,5 +15,5 @@ ArrayLayouts = "1.10.4"
1515
DerivableInterfaces = "0.5"
1616
FillArrays = "1.13.0"
1717
LinearAlgebra = "1.10.0"
18-
SparseArraysBase = "0.7.1"
18+
SparseArraysBase = "0.7.2"
1919
julia = "1.10"

src/diagonalarray/delta.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,39 @@
11
using FillArrays: Ones
22

3+
function delta(
4+
elt::Type, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
5+
)
6+
return DiagonalArray(Ones{elt}(minimum(length, ax)), ax)
7+
end
8+
function δ(
9+
elt::Type, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
10+
)
11+
return delta(elt, ax)
12+
end
13+
function delta(ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}})
14+
return delta(Float64, ax)
15+
end
16+
function δ(ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}})
17+
return delta(Float64, ax)
18+
end
19+
20+
function delta(
21+
elt::Type, ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}...
22+
)
23+
return delta(elt, (ax1, axs...))
24+
end
25+
function δ(
26+
elt::Type, ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}...
27+
)
28+
return delta(elt, (ax1, axs...))
29+
end
30+
function delta(ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}...)
31+
return delta(Float64, (ax1, axs...))
32+
end
33+
function δ(ax1::AbstractUnitRange{<:Integer}, axs::AbstractUnitRange{<:Integer}...)
34+
return delta(Float64, (ax1, axs...))
35+
end
36+
337
function delta(elt::Type, size::Tuple{Vararg{Int}})
438
return DiagonalArray(Ones{elt}(minimum(size)), size)
539
end

src/diagonalarray/diagonalarray.jl

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,79 @@ Base.size(a::DiagonalArray) = size(unstored(a))
2121
Base.axes(a::DiagonalArray) = axes(unstored(a))
2222

2323
function DiagonalArray(::UndefInitializer, unstored::Unstored)
24-
return _DiagonalArray(Vector{eltype(unstored)}(undef, ndims(unstored)), parent(unstored))
25-
end
26-
27-
function DiagonalArray{T,N}(diag::AbstractVector, unstored::AbstractArray) where {T,N}
28-
return _DiagonalArray(convert(AbstractVector{T}, diag), dims, getunstored)
24+
return _DiagonalArray(
25+
Vector{eltype(unstored)}(undef, minimum(size(unstored))), parent(unstored)
26+
)
27+
end
28+
29+
# Constructors accepting axes.
30+
function DiagonalArray{T,N}(
31+
diag::AbstractVector,
32+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
33+
) where {T,N}
34+
N == length(ax) || throw(ArgumentError("Wrong number of axes"))
35+
return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(ax))
36+
end
37+
function DiagonalArray{T,N}(
38+
diag::AbstractVector,
39+
ax1::AbstractUnitRange{<:Integer},
40+
axs::AbstractUnitRange{<:Integer}...,
41+
) where {T,N}
42+
return DiagonalArray{T,N}(diag, (ax1, axs...))
43+
end
44+
function DiagonalArray{T}(
45+
diag::AbstractVector,
46+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
47+
) where {T}
48+
return DiagonalArray{T,length(ax)}(diag, ax)
49+
end
50+
function DiagonalArray{T}(
51+
diag::AbstractVector,
52+
ax1::AbstractUnitRange{<:Integer},
53+
axs::AbstractUnitRange{<:Integer}...,
54+
) where {T}
55+
return DiagonalArray{T}(diag, (ax1, axs...))
56+
end
57+
function DiagonalArray(
58+
diag::AbstractVector{T},
59+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
60+
) where {T}
61+
return DiagonalArray{T,length(ax)}(diag, ax)
62+
end
63+
function DiagonalArray(
64+
diag::AbstractVector,
65+
ax1::AbstractUnitRange{<:Integer},
66+
axs::AbstractUnitRange{<:Integer}...,
67+
)
68+
return DiagonalArray(diag, (ax1, axs...))
69+
end
70+
71+
# undef constructors accepting axes.
72+
function DiagonalArray{T,N}(
73+
::UndefInitializer,
74+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
75+
) where {T,N}
76+
return DiagonalArray{T,N}(Vector{T}(undef, minimum(length, ax)), ax)
77+
end
78+
function DiagonalArray{T,N}(
79+
::UndefInitializer,
80+
ax1::AbstractUnitRange{<:Integer},
81+
axs::AbstractUnitRange{<:Integer}...,
82+
) where {T,N}
83+
return DiagonalArray{T,N}(undef, (ax1, axs...))
84+
end
85+
function DiagonalArray{T}(
86+
::UndefInitializer,
87+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
88+
) where {T}
89+
return DiagonalArray{T,length(ax)}(undef, ax)
90+
end
91+
function DiagonalArray{T}(
92+
::UndefInitializer,
93+
ax1::AbstractUnitRange{<:Integer},
94+
axs::AbstractUnitRange{<:Integer}...,
95+
) where {T}
96+
return DiagonalArray{T}(undef, (ax1, axs...))
2997
end
3098

3199
function DiagonalArray{T,N}(diag::AbstractVector, dims::Dims{N}) where {T,N}

test/test_basics.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,38 @@ using LinearAlgebra: Diagonal
6969
@test diagindices(IndexCartesian(), a) ==
7070
CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3))
7171
end
72+
@testset "DiagonalArray constructors" begin
73+
v = randn(elt, 2)
74+
@test DiagonalArray(v, 2, 2)
75+
DiagonalArray(v, (2, 2))
76+
DiagonalArray(v, Base.OneTo(2), Base.OneTo(2))
77+
DiagonalArray(v, (Base.OneTo(2), Base.OneTo(2)))
78+
DiagonalArray{elt}(v, 2, 2)
79+
DiagonalArray{elt}(v, (2, 2))
80+
DiagonalArray{elt}(v, Base.OneTo(2), Base.OneTo(2))
81+
DiagonalArray{elt}(v, (Base.OneTo(2), Base.OneTo(2)))
82+
DiagonalArray{elt,2}(v, 2, 2)
83+
DiagonalArray{elt,2}(v, (2, 2))
84+
DiagonalArray{elt,2}(v, Base.OneTo(2), Base.OneTo(2))
85+
DiagonalArray{elt,2}(v, (Base.OneTo(2), Base.OneTo(2)))
86+
@test size(DiagonalArray{elt}(undef, 2, 2))
87+
size(DiagonalArray{elt}(undef, (2, 2)))
88+
size(DiagonalArray{elt}(undef, Base.OneTo(2), Base.OneTo(2)))
89+
size(DiagonalArray{elt}(undef, (Base.OneTo(2), Base.OneTo(2))))
90+
size(DiagonalArray{elt,2}(undef, 2, 2))
91+
size(DiagonalArray{elt,2}(undef, (2, 2)))
92+
size(DiagonalArray{elt,2}(undef, Base.OneTo(2), Base.OneTo(2)))
93+
size(DiagonalArray{elt,2}(undef, (Base.OneTo(2), Base.OneTo(2))))
94+
@test elt
95+
eltype(DiagonalArray{elt}(undef, 2, 2))
96+
eltype(DiagonalArray{elt}(undef, (2, 2)))
97+
eltype(DiagonalArray{elt}(undef, Base.OneTo(2), Base.OneTo(2)))
98+
eltype(DiagonalArray{elt}(undef, (Base.OneTo(2), Base.OneTo(2))))
99+
eltype(DiagonalArray{elt,2}(undef, 2, 2))
100+
eltype(DiagonalArray{elt,2}(undef, (2, 2)))
101+
eltype(DiagonalArray{elt,2}(undef, Base.OneTo(2), Base.OneTo(2)))
102+
eltype(DiagonalArray{elt,2}(undef, (Base.OneTo(2), Base.OneTo(2))))
103+
end
72104
@testset "Matrix multiplication" begin
73105
a1 = DiagonalArray{elt}(undef, (2, 3))
74106
a1[1, 1] = 11
@@ -120,13 +152,21 @@ using LinearAlgebra: Diagonal
120152
@testset "delta" begin
121153
for (a, elt′) in (
122154
(delta(2, 3), Float64),
155+
(delta(Base.OneTo(2), Base.OneTo(3)), Float64),
123156
(δ(2, 3), Float64),
157+
(δ(Base.OneTo(2), Base.OneTo(3)), Float64),
124158
(delta((2, 3)), Float64),
159+
(delta(Base.OneTo.((2, 3))), Float64),
125160
(δ((2, 3)), Float64),
161+
(δ(Base.OneTo.((2, 3))), Float64),
126162
(delta(Bool, 2, 3), Bool),
163+
(delta(Bool, Base.OneTo(2), Base.OneTo(3)), Bool),
127164
(δ(Bool, 2, 3), Bool),
165+
(δ(Bool, Base.OneTo(2), Base.OneTo(3)), Bool),
128166
(delta(Bool, (2, 3)), Bool),
167+
(delta(Bool, Base.OneTo.((2, 3))), Bool),
129168
(δ(Bool, (2, 3)), Bool),
169+
(δ(Bool, Base.OneTo.((2, 3))), Bool),
130170
)
131171
@test eltype(a) === elt′
132172
@test diaglength(a) == 2

0 commit comments

Comments
 (0)