Skip to content

Commit 1b9dfe8

Browse files
committed
Define fallback for sparse interface for dense arrays
1 parent c99a04a commit 1b9dfe8

File tree

3 files changed

+84
-17
lines changed

3 files changed

+84
-17
lines changed

src/abstractsparsearrayinterface.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
11
# Minimal interface for `SparseArrayInterface`.
2-
# TODO: Define default definitions for these based
3-
# on the dense case.
4-
# TODO: Define as `MethodError`.
5-
## isstored(a::AbstractArray, I::Int...) = true
6-
isstored(a::AbstractArray, I::Int...) = error("Not implemented.")
7-
## eachstoredindex(a::AbstractArray) = eachindex(a)
8-
eachstoredindex(a::AbstractArray) = error("Not implemented.")
9-
## getstoredindex(a::AbstractArray, I::Int...) = getindex(a, I...)
10-
getstoredindex(a::AbstractArray, I::Int...) = error("Not implemented.")
11-
## setstoredindex!(a::AbstractArray, value, I::Int...) = setindex!(a, value, I...)
12-
setstoredindex!(a::AbstractArray, value, I::Int...) = error("Not implemented.")
13-
## setunstoredindex!(a::AbstractArray, value, I::Int...) = setindex!(a, value, I...)
14-
setunstoredindex!(a::AbstractArray, value, I::Int...) = error("Not implemented.")
2+
isstored(a::AbstractArray, I::Int...) = true
3+
eachstoredindex(a::AbstractArray) = eachindex(a)
4+
getstoredindex(a::AbstractArray, I::Int...) = getindex(a, I...)
5+
function setstoredindex!(a::AbstractArray, value, I::Int...)
6+
setindex!(a, value, I...)
7+
return a
8+
end
9+
# TODO: Should this error by default if the value at the index
10+
# is stored? It could be disabled with something analogous
11+
# to `checkbounds`, like `checkstored`/`checkunstored`.
12+
function setunstoredindex!(a::AbstractArray, value, I::Int...)
13+
setindex!(a, value, I...)
14+
return a
15+
end
1516

1617
# TODO: Use `Base.to_indices`?
1718
isstored(a::AbstractArray, I::CartesianIndex) = isstored(a, Tuple(I)...)
19+
# TODO: Use `Base.to_indices`?
1820
getstoredindex(a::AbstractArray, I::CartesianIndex) = getstoredindex(a, Tuple(I)...)
21+
# TODO: Use `Base.to_indices`?
1922
getunstoredindex(a::AbstractArray, I::CartesianIndex) = getunstoredindex(a, Tuple(I)...)
23+
# TODO: Use `Base.to_indices`?
2024
function setstoredindex!(a::AbstractArray, value, I::CartesianIndex)
2125
return setstoredindex!(a, value, Tuple(I)...)
2226
end
27+
# TODO: Use `Base.to_indices`?
2328
function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex)
2429
return setunstoredindex!(a, value, Tuple(I)...)
2530
end
@@ -33,6 +38,9 @@ getunstoredindex(a::AbstractArray, I::Int...) = zero(eltype(a))
3338
storedlength(a::AbstractArray) = length(storedvalues(a))
3439
storedpairs(a::AbstractArray) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))
3540

41+
to_vec(x) = vec(collect(x))
42+
to_vec(x::AbstractArray) = vec(x)
43+
3644
# A view of the stored values of an array.
3745
# Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue
3846
# with that is it returns a `SubArray` wrapping a sparse array, which
@@ -47,7 +55,7 @@ struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T}
4755
array::A
4856
storedindices::I
4957
end
50-
StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a)))
58+
StoredValues(a::AbstractArray) = StoredValues(a, to_vec(eachstoredindex(a)))
5159
Base.size(a::StoredValues) = size(a.storedindices)
5260
Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I])
5361
function Base.setindex!(a::StoredValues, value, I::Int)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
[deps]
2+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
23
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
34
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
45
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
6+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
57
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
68
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
79
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"

test/basics/test_basics.jl

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,63 @@
1-
using SparseArraysBase: SparseArraysBase
1+
using Adapt: adapt
2+
using JLArrays: JLArray, @allowscalar
3+
using SparseArraysBase:
4+
SparseArraysBase,
5+
eachstoredindex,
6+
getstoredindex,
7+
getunstoredindex,
8+
isstored,
9+
setstoredindex!,
10+
setunstoredindex!,
11+
storedlength,
12+
storedpairs,
13+
storedvalues
214
using Test: @test, @testset
315

4-
@testset "SparseArraysBase" begin
5-
# Tests go here.
16+
elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
17+
arrayts = (Array, JLArray)
18+
@testset "SparseArraysBase (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts,
19+
elt in elts
20+
21+
dev(x) = adapt(arrayt, x)
22+
23+
n = 2
24+
a = dev(randn(elt, n, n))
25+
@test storedlength(a) == length(a)
26+
for indexstyle in (IndexLinear(), IndexCartesian())
27+
for I in eachindex(indexstyle, a)
28+
@test isstored(a, I)
29+
end
30+
end
31+
@test eachstoredindex(a) == eachindex(a)
32+
# TODO: We should be specializing these for dense/strided arrays,
33+
# probably we can have a trait for that. It could be based
34+
# on the `ArrayLayouts.MemoryLayout`.
35+
@allowscalar @test storedvalues(a) == vec(a)
36+
@allowscalar @test storedpairs(a) == collect(pairs(vec(a)))
37+
@allowscalar for I in eachindex(a)
38+
@test getstoredindex(a, I) == a[I]
39+
@test iszero(getunstoredindex(a, I))
40+
end
41+
@allowscalar for I in eachindex(IndexCartesian(), a)
42+
@test getstoredindex(a, I) == a[I]
43+
@test iszero(getunstoredindex(a, I))
44+
end
45+
46+
a = dev(randn(elt, n, n))
47+
for I in ((1, 2), (CartesianIndex(1, 2),))
48+
b = copy(a)
49+
value = randn(elt)
50+
@allowscalar setstoredindex!(b, value, I...)
51+
@allowscalar b[I...] == value
52+
end
53+
54+
# TODO: Should `setunstoredindex!` error by default
55+
# if the value at that index is already stored?
56+
a = dev(randn(elt, n, n))
57+
for I in ((1, 2), (CartesianIndex(1, 2),))
58+
b = copy(a)
59+
value = randn(elt)
60+
@allowscalar setunstoredindex!(b, value, I...)
61+
@allowscalar b[I...] == value
62+
end
663
end

0 commit comments

Comments
 (0)