Skip to content

Commit 8b96262

Browse files
authored
[SparseArraysBase] Absorb SparseArrayDOKs (#1592)
* [SparseArraysBase] Absorb `SparseArrayDOKs` * [NDTensors] Bump to v0.3.71
1 parent 4ee8aa3 commit 8b96262

File tree

9 files changed

+315
-1
lines changed

9 files changed

+315
-1
lines changed

src/SparseArraysBase.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,9 @@ include("abstractsparsearray/convert.jl")
2828
include("abstractsparsearray/cat.jl")
2929
include("abstractsparsearray/SparseArraysBaseSparseArraysExt.jl")
3030
include("abstractsparsearray/SparseArraysBaseLinearAlgebraExt.jl")
31+
include("sparsearraydok/defaults.jl")
32+
include("sparsearraydok/sparsearraydok.jl")
33+
include("sparsearraydok/sparsematrixdok.jl")
34+
include("sparsearraydok/sparsevectordok.jl")
35+
include("sparsearraydok/arraylayouts.jl")
3136
end

src/sparsearraydok/arraylayouts.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd
2+
using ..SparseArraysBase: AbstractSparseLayout, SparseLayout
3+
4+
ArrayLayouts.MemoryLayout(::Type{<:SparseArrayDOK}) = SparseLayout()
5+
6+
# Default sparse array type for `AbstractSparseLayout`.
7+
default_sparsearraytype(elt::Type) = SparseArrayDOK{elt}
8+
9+
# TODO: Preserve GPU memory! Implement `CuSparseArrayLayout`, `MtlSparseLayout`?
10+
function Base.similar(
11+
::MulAdd{<:AbstractSparseLayout,<:AbstractSparseLayout}, elt::Type, axes
12+
)
13+
return similar(default_sparsearraytype(elt), axes)
14+
end

src/sparsearraydok/defaults.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using Dictionaries: Dictionary
2+
using ..SparseArraysBase: Zero
3+
4+
default_zero() = Zero()
5+
default_data(type::Type, ndims::Int) = Dictionary{default_keytype(ndims),type}()
6+
default_keytype(ndims::Int) = CartesianIndex{ndims}
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
using Accessors: @set
2+
using Dictionaries: Dictionary, set!
3+
using MacroTools: @capture
4+
using ..SparseArraysBase: SparseArraysBase, AbstractSparseArray, getindex_zero_function
5+
6+
# TODO: Parametrize by `data`?
7+
struct SparseArrayDOK{T,N,Zero} <: AbstractSparseArray{T,N}
8+
data::Dictionary{CartesianIndex{N},T}
9+
dims::Ref{NTuple{N,Int}}
10+
zero::Zero
11+
function SparseArrayDOK{T,N,Zero}(data, dims::NTuple{N,Int}, zero) where {T,N,Zero}
12+
return new{T,N,Zero}(data, Ref(dims), zero)
13+
end
14+
end
15+
16+
# Constructors
17+
function SparseArrayDOK(data, dims::Tuple{Vararg{Int}}, zero)
18+
return SparseArrayDOK{eltype(data),length(dims),typeof(zero)}(data, dims, zero)
19+
end
20+
21+
function SparseArrayDOK{T,N,Zero}(dims::Tuple{Vararg{Int}}, zero) where {T,N,Zero}
22+
return SparseArrayDOK{T,N,Zero}(default_data(T, N), dims, zero)
23+
end
24+
25+
function SparseArrayDOK{T,N}(dims::Tuple{Vararg{Int}}, zero) where {T,N}
26+
return SparseArrayDOK{T,N,typeof(zero)}(dims, zero)
27+
end
28+
29+
function SparseArrayDOK{T,N}(dims::Tuple{Vararg{Int}}) where {T,N}
30+
return SparseArrayDOK{T,N}(dims, default_zero())
31+
end
32+
33+
function SparseArrayDOK{T}(dims::Tuple{Vararg{Int}}) where {T}
34+
return SparseArrayDOK{T,length(dims)}(dims)
35+
end
36+
37+
function SparseArrayDOK{T}(dims::Int...) where {T}
38+
return SparseArrayDOK{T}(dims)
39+
end
40+
41+
# Specify zero function
42+
function SparseArrayDOK{T}(dims::Tuple{Vararg{Int}}, zero) where {T}
43+
return SparseArrayDOK{T,length(dims)}(dims, zero)
44+
end
45+
46+
# undef
47+
function SparseArrayDOK{T,N,Zero}(
48+
::UndefInitializer, dims::Tuple{Vararg{Int}}, zero
49+
) where {T,N,Zero}
50+
return SparseArrayDOK{T,N,Zero}(dims, zero)
51+
end
52+
53+
function SparseArrayDOK{T,N}(::UndefInitializer, dims::Tuple{Vararg{Int}}, zero) where {T,N}
54+
return SparseArrayDOK{T,N}(dims, zero)
55+
end
56+
57+
function SparseArrayDOK{T,N}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T,N}
58+
return SparseArrayDOK{T,N}(dims)
59+
end
60+
61+
function SparseArrayDOK{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T}
62+
return SparseArrayDOK{T}(dims)
63+
end
64+
65+
# Axes version
66+
function SparseArrayDOK{T}(
67+
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange}}
68+
) where {T}
69+
@assert all(isone, first.(axes))
70+
return SparseArrayDOK{T}(length.(axes))
71+
end
72+
73+
function SparseArrayDOK{T}(::UndefInitializer, dims::Int...) where {T}
74+
return SparseArrayDOK{T}(dims...)
75+
end
76+
77+
function SparseArrayDOK{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}, zero) where {T}
78+
return SparseArrayDOK{T}(dims, zero)
79+
end
80+
81+
# Base `AbstractArray` interface
82+
Base.size(a::SparseArrayDOK) = a.dims[]
83+
84+
SparseArraysBase.getindex_zero_function(a::SparseArrayDOK) = a.zero
85+
function SparseArraysBase.set_getindex_zero_function(a::SparseArrayDOK, f)
86+
return @set a.zero = f
87+
end
88+
89+
function SparseArraysBase.setindex_notstored!(
90+
a::SparseArrayDOK{<:Any,N}, value, I::CartesianIndex{N}
91+
) where {N}
92+
set!(SparseArraysBase.sparse_storage(a), I, value)
93+
return a
94+
end
95+
96+
function Base.similar(a::SparseArrayDOK, elt::Type, dims::Tuple{Vararg{Int}})
97+
return SparseArrayDOK{elt}(undef, dims, getindex_zero_function(a))
98+
end
99+
100+
# `SparseArraysBase` interface
101+
SparseArraysBase.sparse_storage(a::SparseArrayDOK) = a.data
102+
103+
function SparseArraysBase.dropall!(a::SparseArrayDOK)
104+
return empty!(SparseArraysBase.sparse_storage(a))
105+
end
106+
107+
SparseArrayDOK(a::AbstractArray) = SparseArrayDOK{eltype(a)}(a)
108+
109+
SparseArrayDOK{T}(a::AbstractArray) where {T} = SparseArrayDOK{T,ndims(a)}(a)
110+
111+
function SparseArrayDOK{T,N}(a::AbstractArray) where {T,N}
112+
return SparseArraysBase.sparse_convert(SparseArrayDOK{T,N}, a)
113+
end
114+
115+
function Base.resize!(a::SparseArrayDOK{<:Any,N}, new_size::NTuple{N,Integer}) where {N}
116+
a.dims[] = new_size
117+
return a
118+
end
119+
120+
function setindex_maybe_grow!(a::SparseArrayDOK{<:Any,N}, value, I::Vararg{Int,N}) where {N}
121+
if any(I .> size(a))
122+
resize!(a, max.(I, size(a)))
123+
end
124+
a[I...] = value
125+
return a
126+
end
127+
128+
function is_setindex!_expr(expr::Expr)
129+
return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
130+
end
131+
is_setindex!_expr(x) = false
132+
133+
is_getindex_expr(expr::Expr) = (expr.head === :ref)
134+
is_getindex_expr(x) = false
135+
136+
is_assignment_expr(expr::Expr) = (expr.head === :(=))
137+
is_assignment_expr(expr) = false
138+
139+
macro maybe_grow(expr)
140+
if !is_setindex!_expr(expr)
141+
error(
142+
"@maybe_grow must be used with setindex! syntax (as @maybe_grow a[i,j,...] = value)"
143+
)
144+
end
145+
@capture(expr, array_[indices__] = value_)
146+
return :(setindex_maybe_grow!($(esc(array)), $(esc(value)), $(esc.(indices)...)))
147+
end
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
const SparseMatrixDOK{T} = SparseArrayDOK{T,2}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
const SparseVectorDOK{T} = SparseArrayDOK{T,1}

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
4+
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
45
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
56
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
67
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@eval module $(gensym())
2-
for filename in ["abstractsparsearray", "array", "diagonalarray"]
2+
for filename in ["sparsearraydok", "abstractsparsearray", "array", "diagonalarray"]
33
include("test_$filename.jl")
44
end
55
end

test/test_sparsearraydok.jl

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
@eval module $(gensym())
2+
3+
# TODO: Test:
4+
# zero (PermutedDimsArray)
5+
# Custom zero type
6+
# Slicing
7+
8+
using Dictionaries: Dictionary
9+
using Test: @test, @testset, @test_broken
10+
using NDTensors.SparseArraysBase:
11+
SparseArraysBase, SparseArrayDOK, SparseMatrixDOK, @maybe_grow
12+
using NDTensors.SparseArraysBase: storage_indices, stored_length
13+
using SparseArrays: SparseMatrixCSC, nnz
14+
@testset "SparseArrayDOK (eltype=$elt)" for elt in
15+
(Float32, ComplexF32, Float64, ComplexF64)
16+
@testset "Basics" begin
17+
a = SparseArrayDOK{elt}(3, 4)
18+
@test a == SparseArrayDOK{elt}((3, 4))
19+
@test a == SparseArrayDOK{elt}(undef, 3, 4)
20+
@test a == SparseArrayDOK{elt}(undef, (3, 4))
21+
@test iszero(a)
22+
@test iszero(nnz(a))
23+
@test stored_length(a) == nnz(a)
24+
@test size(a) == (3, 4)
25+
@test eltype(a) == elt
26+
for I in eachindex(a)
27+
@test iszero(a[I])
28+
@test a[I] isa elt
29+
end
30+
@test isempty(storage_indices(a))
31+
32+
x12 = randn(elt)
33+
x23 = randn(elt)
34+
b = copy(a)
35+
@test b isa SparseArrayDOK{elt}
36+
@test iszero(b)
37+
b[1, 2] = x12
38+
b[2, 3] = x23
39+
@test iszero(a)
40+
@test !iszero(b)
41+
@test b[1, 2] == x12
42+
@test b[2, 3] == x23
43+
@test iszero(stored_length(a))
44+
@test stored_length(b) == 2
45+
end
46+
@testset "map/broadcast" begin
47+
a = SparseArrayDOK{elt}(3, 4)
48+
a[1, 1] = 11
49+
a[3, 4] = 34
50+
@test stored_length(a) == 2
51+
b = 2 * a
52+
@test stored_length(b) == 2
53+
@test b[1, 1] == 2 * 11
54+
@test b[3, 4] == 2 * 34
55+
end
56+
@testset "reshape" begin
57+
a = SparseArrayDOK{elt}(2, 2, 2)
58+
a[1, 2, 2] = 122
59+
b = reshape(a, 2, 4)
60+
@test b[1, 4] == 122
61+
end
62+
@testset "Matrix multiplication" begin
63+
a1 = SparseArrayDOK{elt}(2, 3)
64+
a1[1, 2] = 12
65+
a1[2, 1] = 21
66+
a2 = SparseArrayDOK{elt}(3, 4)
67+
a2[1, 1] = 11
68+
a2[2, 2] = 22
69+
a2[3, 3] = 33
70+
a_dest = a1 * a2
71+
# TODO: Use `densearray` to make generic to GPU.
72+
@test Array(a_dest) Array(a1) * Array(a2)
73+
# TODO: Make this work with `ArrayLayouts`.
74+
@test stored_length(a_dest) == 2
75+
@test a_dest isa SparseMatrixDOK{elt}
76+
77+
a2 = randn(elt, (3, 4))
78+
a_dest = a1 * a2
79+
# TODO: Use `densearray` to make generic to GPU.
80+
@test Array(a_dest) Array(a1) * Array(a2)
81+
@test stored_length(a_dest) == 8
82+
@test a_dest isa Matrix{elt}
83+
end
84+
@testset "SparseMatrixCSC" begin
85+
a = SparseArrayDOK{elt}(2, 2)
86+
a[1, 2] = 12
87+
for (type, a′) in ((SparseMatrixCSC, a), (SparseArrayDOK, SparseMatrixCSC(a)))
88+
b = type(a′)
89+
@test b isa type{elt}
90+
@test b[1, 2] == 12
91+
@test isone(nnz(b))
92+
for I in eachindex(b)
93+
if I CartesianIndex(1, 2)
94+
@test iszero(b[I])
95+
end
96+
end
97+
end
98+
end
99+
@testset "Maybe Grow Feature" begin
100+
a = SparseArrayDOK{elt,2}((0, 0))
101+
SparseArraysBase.setindex_maybe_grow!(a, 230, 2, 3)
102+
@test size(a) == (2, 3)
103+
@test a[2, 3] == 230
104+
# Test @maybe_grow macro
105+
@maybe_grow a[5, 5] = 550
106+
@test size(a) == (5, 5)
107+
@test a[2, 3] == 230
108+
@test a[5, 5] == 550
109+
# Test that size remains same
110+
# if we set at an index smaller than
111+
# the maximum size:
112+
@maybe_grow a[3, 4] = 340
113+
@test size(a) == (5, 5)
114+
@test a[2, 3] == 230
115+
@test a[5, 5] == 550
116+
@test a[3, 4] == 340
117+
# Test vector case
118+
v = SparseArrayDOK{elt,1}((0,))
119+
@maybe_grow v[5] = 50
120+
@test size(v) == (5,)
121+
@test v[5] == 50
122+
# Test setting from a variable (to test macro escaping)
123+
i = 6
124+
val = 60
125+
@maybe_grow v[i] = val
126+
@test v[i] == val
127+
i, j = 1, 2
128+
val = 120
129+
@maybe_grow a[i, j] = val
130+
@test a[i, j] == val
131+
end
132+
@testset "Test Lower Level Constructor" begin
133+
d = Dictionary{CartesianIndex{2},elt}()
134+
a = SparseArrayDOK(d, (2, 2), zero(elt))
135+
a[1, 2] = 12.0
136+
@test a[1, 2] == 12.0
137+
end
138+
end
139+
end

0 commit comments

Comments
 (0)