Skip to content

Commit 20ffcb5

Browse files
[SparseArraysDOKs] Add setindex_maybe_grow! and macro @maybe_grow (#1434)
1 parent f2791ba commit 20ffcb5

File tree

3 files changed

+81
-5
lines changed

3 files changed

+81
-5
lines changed

NDTensors/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1818
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
1919
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
2020
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
21+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2122
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
2223
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
2324
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -34,20 +35,20 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
3435
[weakdeps]
3536
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3637
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
37-
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
3838
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
3939
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4040
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
4141
TBLIS = "48530278-0828-4a49-9772-0f3830dfa1e9"
42+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
4243

4344
[extensions]
4445
NDTensorsAMDGPUExt = "AMDGPU"
4546
NDTensorsCUDAExt = "CUDA"
46-
NDTensorscuTENSORExt = "cuTENSOR"
4747
NDTensorsHDF5Ext = "HDF5"
4848
NDTensorsMetalExt = "Metal"
4949
NDTensorsOctavianExt = "Octavian"
5050
NDTensorsTBLISExt = "TBLIS"
51+
NDTensorscuTENSORExt = "cuTENSOR"
5152

5253
[compat]
5354
Accessors = "0.1.33"
@@ -65,6 +66,7 @@ HDF5 = "0.14, 0.15, 0.16, 0.17"
6566
HalfIntegers = "1"
6667
InlineStrings = "1"
6768
LinearAlgebra = "1.6"
69+
MacroTools = "0.5"
6870
MappedArrays = "0.4"
6971
PackageExtensionCompat = "1"
7072
Random = "1.6"

NDTensors/src/lib/SparseArrayDOKs/src/sparsearraydok.jl

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
11
using Accessors: @set
22
using Dictionaries: Dictionary, set!
3+
using MacroTools: @capture
34
using ..SparseArrayInterface:
45
SparseArrayInterface, AbstractSparseArray, getindex_zero_function
56

67
# TODO: Parametrize by `data`?
78
struct SparseArrayDOK{T,N,Zero} <: AbstractSparseArray{T,N}
89
data::Dictionary{CartesianIndex{N},T}
9-
dims::NTuple{N,Int}
10+
dims::Ref{NTuple{N,Int}}
1011
zero::Zero
12+
function SparseArrayDOK{T,N,Zero}(data, dims::NTuple{N,Int}, zero) where {T,N,Zero}
13+
return new{T,N,Zero}(data, Ref(dims), zero)
14+
end
1115
end
1216

1317
# Constructors
18+
function SparseArrayDOK(data, dims::Tuple{Vararg{Int}}, zero)
19+
return SparseArrayDOK{eltype(data),length(dims),typeof(zero)}(data, dims, zero)
20+
end
21+
1422
function SparseArrayDOK{T,N,Zero}(dims::Tuple{Vararg{Int}}, zero) where {T,N,Zero}
1523
return SparseArrayDOK{T,N,Zero}(default_data(T, N), dims, zero)
1624
end
@@ -72,7 +80,7 @@ function SparseArrayDOK{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}, zero) w
7280
end
7381

7482
# Base `AbstractArray` interface
75-
Base.size(a::SparseArrayDOK) = a.dims
83+
Base.size(a::SparseArrayDOK) = a.dims[]
7684

7785
SparseArrayInterface.getindex_zero_function(a::SparseArrayDOK) = a.zero
7886
function SparseArrayInterface.set_getindex_zero_function(a::SparseArrayDOK, f)
@@ -104,3 +112,37 @@ SparseArrayDOK{T}(a::AbstractArray) where {T} = SparseArrayDOK{T,ndims(a)}(a)
104112
function SparseArrayDOK{T,N}(a::AbstractArray) where {T,N}
105113
return SparseArrayInterface.sparse_convert(SparseArrayDOK{T,N}, a)
106114
end
115+
116+
function Base.resize!(a::SparseArrayDOK{<:Any,N}, new_size::NTuple{N,Integer}) where {N}
117+
a.dims[] = new_size
118+
return a
119+
end
120+
121+
function setindex_maybe_grow!(a::SparseArrayDOK{<:Any,N}, value, I::Vararg{Int,N}) where {N}
122+
if any(I .> size(a))
123+
resize!(a, max.(I, size(a)))
124+
end
125+
a[I...] = value
126+
return a
127+
end
128+
129+
function is_setindex!_expr(expr::Expr)
130+
return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
131+
end
132+
is_setindex!_expr(x) = false
133+
134+
is_getindex_expr(expr::Expr) = (expr.head === :ref)
135+
is_getindex_expr(x) = false
136+
137+
is_assignment_expr(expr::Expr) = (expr.head === :(=))
138+
is_assignment_expr(expr) = false
139+
140+
macro maybe_grow(expr)
141+
if !is_setindex!_expr(expr)
142+
error(
143+
"@maybe_grow must be used with setindex! syntax (as @maybe_grow a[i,j,...] = value)"
144+
)
145+
end
146+
@capture(expr, array_[indices__] = value_)
147+
return :(setindex_maybe_grow!($(esc(array)), $value, $indices...))
148+
end

NDTensors/src/lib/SparseArrayDOKs/test/runtests.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
# Custom zero type
66
# Slicing
77

8+
using Dictionaries: Dictionary
89
using Test: @test, @testset, @test_broken
9-
using NDTensors.SparseArrayDOKs: SparseArrayDOK, SparseMatrixDOK
10+
using NDTensors.SparseArrayDOKs:
11+
SparseArrayDOKs, SparseArrayDOK, SparseMatrixDOK, @maybe_grow
1012
using NDTensors.SparseArrayInterface: storage_indices, nstored
1113
using SparseArrays: SparseMatrixCSC, nnz
1214
@testset "SparseArrayDOK (eltype=$elt)" for elt in
@@ -94,5 +96,35 @@ using SparseArrays: SparseMatrixCSC, nnz
9496
end
9597
end
9698
end
99+
@testset "Maybe Grow Feature" begin
100+
a = SparseArrayDOK{elt,2}((0, 0))
101+
SparseArrayDOKs.setindex_maybe_grow!(a, 230, 2, 3)
102+
@test size(a) == (2, 3)
103+
@test a[2, 3] == 230
104+
# Test @maybe_grow macro
105+
@maybe_grow a[5, 5] = 550
106+
@test size(a) == (5, 5)
107+
@test a[2, 3] == 230
108+
@test a[5, 5] == 550
109+
# Test that size remains same
110+
# if we set at an index smaller than
111+
# the maximum size:
112+
@maybe_grow a[3, 4] = 340
113+
@test size(a) == (5, 5)
114+
@test a[2, 3] == 230
115+
@test a[5, 5] == 550
116+
@test a[3, 4] == 340
117+
# Test vector case
118+
v = SparseArrayDOK{elt,1}((0,))
119+
@maybe_grow v[5] = 50
120+
@test size(v) == (5,)
121+
@test v[5] == 50
122+
end
123+
@testset "Test Lower Level Constructor" begin
124+
d = Dictionary{CartesianIndex{2},elt}()
125+
a = SparseArrayDOK(d, (2, 2), zero(elt))
126+
a[1, 2] = 12.0
127+
@test a[1, 2] == 12.0
128+
end
97129
end
98130
end

0 commit comments

Comments
 (0)