Skip to content

Commit f44f3f2

Browse files
authored
Introduce OneElementArray (#26)
1 parent b41609b commit f44f3f2

9 files changed

+408
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.10"
4+
version = "0.2.11"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
88
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
99
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
1010
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
11+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1314

@@ -17,6 +18,7 @@ Aqua = "0.8.9"
1718
ArrayLayouts = "1.11.0"
1819
DerivableInterfaces = "0.3.7"
1920
Dictionaries = "0.4.3"
21+
FillArrays = "1.13.0"
2022
LinearAlgebra = "1.10"
2123
MapBroadcast = "0.1.5"
2224
SafeTestsets = "0.1"

src/SparseArraysBase.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@ module SparseArraysBase
33
export SparseArrayDOK,
44
SparseMatrixDOK,
55
SparseVectorDOK,
6+
OneElementArray,
7+
OneElementMatrix,
8+
OneElementVector,
69
eachstoredindex,
710
isstored,
11+
oneelementarray,
812
storedlength,
913
storedpairs,
1014
storedvalues
@@ -14,5 +18,6 @@ include("sparsearrayinterface.jl")
1418
include("wrappers.jl")
1519
include("abstractsparsearray.jl")
1620
include("sparsearraydok.jl")
21+
include("oneelementarray.jl")
1722

1823
end

src/oneelementarray.jl

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
using FillArrays: Fill
2+
3+
# Like [`FillArrays.OneElement`](https://github.com/JuliaArrays/FillArrays.jl)
4+
# and [`OneHotArrays.OneHotArray`](https://github.com/FluxML/OneHotArrays.jl).
5+
struct OneElementArray{T,N,I,A,F} <: AbstractSparseArray{T,N}
6+
value::T
7+
index::I
8+
axes::A
9+
getunstoredindex::F
10+
end
11+
12+
using DerivableInterfaces: @array_aliases
13+
# Define `OneElementMatrix`, `AnyOneElementArray`, etc.
14+
@array_aliases OneElementArray
15+
16+
function OneElementArray{T,N}(
17+
value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}, getunstoredindex
18+
) where {T,N}
19+
return OneElementArray{T,N,typeof(index),typeof(axes),typeof(getunstoredindex)}(
20+
value, index, axes, getunstoredindex
21+
)
22+
end
23+
24+
function OneElementArray{T,N}(
25+
value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
26+
) where {T,N}
27+
return OneElementArray{T,N}(value, index, axes, default_getunstoredindex)
28+
end
29+
function OneElementArray{<:Any,N}(
30+
value::T, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
31+
) where {T,N}
32+
return OneElementArray{T,N}(value, index, axes)
33+
end
34+
function OneElementArray(
35+
value::T, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
36+
) where {T,N}
37+
return OneElementArray{T,N}(value, index, axes)
38+
end
39+
40+
function OneElementArray{T,N}(
41+
index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
42+
) where {T,N}
43+
return OneElementArray{T,N}(one(T), index, axes)
44+
end
45+
function OneElementArray{<:Any,N}(
46+
index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
47+
) where {N}
48+
return OneElementArray{Bool,N}(index, axes)
49+
end
50+
function OneElementArray{T}(
51+
index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
52+
) where {T,N}
53+
return OneElementArray{T,N}(index, axes)
54+
end
55+
function OneElementArray(index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}) where {N}
56+
return OneElementArray{Bool,N}(index, axes)
57+
end
58+
59+
function OneElementArray{T,N}(
60+
value, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
61+
) where {T,N}
62+
return OneElementArray{T,N}(value, last.(ax_ind), first.(ax_ind))
63+
end
64+
function OneElementArray{<:Any,N}(
65+
value::T, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
66+
) where {T,N}
67+
return OneElementArray{T,N}(value, ax_ind...)
68+
end
69+
function OneElementArray{T}(
70+
value, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
71+
) where {T,N}
72+
return OneElementArray{T,N}(value, ax_ind...)
73+
end
74+
function OneElementArray(
75+
value::T, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}
76+
) where {T,N}
77+
return OneElementArray{T,N}(value, ax_ind...)
78+
end
79+
80+
function OneElementArray{T,N}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {T,N}
81+
return OneElementArray{T,N}(last.(ax_ind), first.(ax_ind))
82+
end
83+
function OneElementArray{<:Any,N}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {N}
84+
return OneElementArray{Bool,N}(ax_ind...)
85+
end
86+
function OneElementArray{T}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {T,N}
87+
return OneElementArray{T,N}(ax_ind...)
88+
end
89+
function OneElementArray(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {N}
90+
return OneElementArray{Bool,N}(ax_ind...)
91+
end
92+
93+
# Fix ambiguity errors.
94+
function OneElementArray{T,0}(value, index::Tuple{}, axes::Tuple{}) where {T}
95+
return OneElementArray{T,0}(value, index, axes, default_getunstoredindex)
96+
end
97+
function OneElementArray{<:Any,0}(value::T, index::Tuple{}, axes::Tuple{}) where {T}
98+
return OneElementArray{T,0}(value, index, axes)
99+
end
100+
function OneElementArray{T}(value, index::Tuple{}, axes::Tuple{}) where {T}
101+
return OneElementArray{T,0}(value, index, axes)
102+
end
103+
function OneElementArray(value::T, index::Tuple{}, axes::Tuple{}) where {T}
104+
return OneElementArray{T,0}(value, index, axes)
105+
end
106+
107+
# Fix ambiguity errors.
108+
function OneElementArray{T,0}(index::Tuple{}, axes::Tuple{}) where {T}
109+
return OneElementArray{T,0}(one(T), index, axes)
110+
end
111+
function OneElementArray{<:Any,0}(index::Tuple{}, axes::Tuple{})
112+
return OneElementArray{Bool,0}(index, axes)
113+
end
114+
function OneElementArray{T}(index::Tuple{}, axes::Tuple{}) where {T}
115+
return OneElementArray{T,0}(index, axes)
116+
end
117+
function OneElementArray(index::Tuple{}, axes::Tuple{})
118+
return OneElementArray{Bool,0}(value, index, axes)
119+
end
120+
121+
function OneElementArray{T,0}(value) where {T}
122+
return OneElementArray{T,0}(value, (), ())
123+
end
124+
function OneElementArray{<:Any,0}(value::T) where {T}
125+
return OneElementArray{T,0}(value)
126+
end
127+
function OneElementArray{T}(value) where {T}
128+
return OneElementArray{T,0}(value)
129+
end
130+
function OneElementArray(value::T) where {T}
131+
return OneElementArray{T}(value)
132+
end
133+
134+
function OneElementArray{T,0}() where {T}
135+
return OneElementArray{T,0}((), ())
136+
end
137+
function OneElementArray{<:Any,0}()
138+
return OneElementArray{Bool,0}(value)
139+
end
140+
function OneElementArray{T}() where {T}
141+
return OneElementArray{T,0}()
142+
end
143+
function OneElementArray()
144+
return OneElementArray{Bool}()
145+
end
146+
147+
function OneElementArray{T,N}(
148+
value, index::NTuple{N,Int}, size::NTuple{N,Integer}
149+
) where {T,N}
150+
return OneElementArray{T,N}(value, index, Base.oneto.(size))
151+
end
152+
function OneElementArray{<:Any,N}(
153+
value::T, index::NTuple{N,Int}, size::NTuple{N,Integer}
154+
) where {T,N}
155+
return OneElementArray{T,N}(value, index, size)
156+
end
157+
function OneElementArray{T}(
158+
value, index::NTuple{N,Int}, size::NTuple{N,Integer}
159+
) where {T,N}
160+
return OneElementArray{T,N}(value, index, size)
161+
end
162+
function OneElementArray(
163+
value::T, index::NTuple{N,Int}, size::NTuple{N,Integer}
164+
) where {T,N}
165+
return OneElementArray{T,N}(value, index, Base.oneto.(size))
166+
end
167+
168+
function OneElementArray{T,N}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {T,N}
169+
return OneElementArray{T,N}(one(T), index, size)
170+
end
171+
function OneElementArray{<:Any,N}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
172+
return OneElementArray{Bool,N}(index, size)
173+
end
174+
function OneElementArray{T}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {T,N}
175+
return OneElementArray{T,N}(index, size)
176+
end
177+
function OneElementArray(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
178+
return OneElementArray{Bool,N}(index, size)
179+
end
180+
181+
function OneElementVector{T}(value, index::Int, length::Integer) where {T}
182+
return OneElementVector{T}(value, (index,), (length,))
183+
end
184+
function OneElementVector(value::T, index::Int, length::Integer) where {T}
185+
return OneElementVector{T}(value, index, length)
186+
end
187+
function OneElementArray{T}(value, index::Int, length::Integer) where {T}
188+
return OneElementVector{T}(value, index, length)
189+
end
190+
function OneElementArray(value::T, index::Int, length::Integer) where {T}
191+
return OneElementVector{T}(value, index, length)
192+
end
193+
194+
function OneElementVector{T}(index::Int, size::Integer) where {T}
195+
return OneElementVector{T}((index,), (size,))
196+
end
197+
function OneElementVector(index::Int, length::Integer)
198+
return OneElementVector{Bool}(index, length)
199+
end
200+
function OneElementArray{T}(index::Int, size::Integer) where {T}
201+
return OneElementVector{T}(index, size)
202+
end
203+
OneElementArray(index::Int, size::Integer) = OneElementVector{Bool}(index, size)
204+
205+
# Interface to overload for constructing arrays like `OneElementArray`,
206+
# that may not be `OneElementArray` (i.e. wrapped versions).
207+
function oneelement(
208+
value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
209+
) where {N}
210+
return OneElementArray(value, index, axes)
211+
end
212+
function oneelement(
213+
eltype::Type, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}
214+
) where {N}
215+
return oneelement(one(eltype), index, axes)
216+
end
217+
function oneelement(index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}) where {N}
218+
return oneelement(Bool, index, axes)
219+
end
220+
221+
function oneelement(value, index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
222+
return oneelement(value, index, Base.oneto.(size))
223+
end
224+
function oneelement(eltype::Type, index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
225+
return oneelement(one(eltype), index, size)
226+
end
227+
function oneelement(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N}
228+
return oneelement(Bool, index, size)
229+
end
230+
231+
function oneelement(value, ax_ind::Pair{<:AbstractUnitRange,Int}...)
232+
return oneelement(value, last.(ax_ind), first.(ax_ind))
233+
end
234+
function oneelement(eltype::Type, ax_ind::Pair{<:AbstractUnitRange,Int}...)
235+
return oneelement(one(eltype), ax_ind...)
236+
end
237+
function oneelement(ax_ind::Pair{<:AbstractUnitRange,Int}...)
238+
return oneelement(Bool, ax_ind...)
239+
end
240+
241+
function oneelement(value)
242+
return oneelement(value, (), ())
243+
end
244+
function oneelement(eltype::Type)
245+
return oneelement(one(eltype))
246+
end
247+
function oneelement()
248+
return oneelement(Bool)
249+
end
250+
251+
Base.axes(a::OneElementArray) = getfield(a, :axes)
252+
Base.size(a::OneElementArray) = length.(axes(a))
253+
storedvalue(a::OneElementArray) = getfield(a, :value)
254+
storedvalues(a::OneElementArray) = Fill(storedvalue(a), 1)
255+
256+
storedindex(a::OneElementArray) = getfield(a, :index)
257+
function isstored(a::OneElementArray, I::Int...)
258+
return I == storedindex(a)
259+
end
260+
function eachstoredindex(a::OneElementArray)
261+
return Fill(CartesianIndex(storedindex(a)), 1)
262+
end
263+
264+
function getstoredindex(a::OneElementArray, I::Int...)
265+
return storedvalue(a)
266+
end
267+
function getunstoredindex(a::OneElementArray, I::Int...)
268+
return a.getunstoredindex(a, I...)
269+
end
270+
function setstoredindex!(a::OneElementArray, value, I::Int...)
271+
return error("`OneElementArray` is immutable, you can't set elements.")
272+
end
273+
function setunstoredindex!(a::OneElementArray, value, I::Int...)
274+
return error("`OneElementArray` is immutable, you can't set elements.")
275+
end
File renamed without changes.

test/basics/test_diagonal.jl renamed to test/test_diagonal.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,22 @@ using SparseArraysBase:
1111

1212
using Test: @test, @testset
1313

14+
# compat with LTS:
15+
@static if VERSION v"1.11"
16+
_diagind = diagind
17+
else
18+
function _diagind(x::Diagonal, ::IndexCartesian)
19+
return view(CartesianIndices(x), diagind(x))
20+
end
21+
end
22+
1423
elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
1524

1625
@testset "Diagonal{$T}" for T in elts
1726
L = 4
1827
D = Diagonal(rand(T, 4))
1928
@test storedlength(D) == 4
20-
@test eachstoredindex(D) == diagind(D, IndexCartesian())
29+
@test eachstoredindex(D) == _diagind(D, IndexCartesian())
2130
@test isstored(D, 2, 2)
2231
@test getstoredindex(D, 2, 2) == D[2, 2]
2332
@test !isstored(D, 2, 1)

test/basics/test_exports.jl renamed to test/test_exports.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@ using Test: @test, @testset
66
:SparseArrayDOK,
77
:SparseMatrixDOK,
88
:SparseVectorDOK,
9+
:OneElementArray,
10+
:OneElementMatrix,
11+
:OneElementVector,
912
:eachstoredindex,
1013
:isstored,
14+
:oneelementarray,
1115
:storedlength,
1216
:storedpairs,
1317
:storedvalues,
File renamed without changes.

0 commit comments

Comments
 (0)