Skip to content

Commit 33b3d6c

Browse files
committed
Reductions
1 parent 9985386 commit 33b3d6c

File tree

4 files changed

+73
-4
lines changed

4 files changed

+73
-4
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ a = SparseArrayDOK{Float64}(2, 2)
5353
AbstractArray interface:
5454

5555
````julia
56+
@test iszero(a)
57+
@test iszero(sum(a))
58+
@test iszero(storedlength(a))
59+
5660
a[1, 2] = 12
5761
@test a == [0 12; 0 0]
5862
@test a[1, 1] == 0
@@ -78,6 +82,10 @@ using Dictionaries: IndexError
7882
@test storedlength(a) == 1
7983
@test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12])
8084
@test issetequal(storedvalues(a), [12])
85+
@test sum(a) == 12
86+
@test isreal(a)
87+
@test !iszero(a)
88+
@test mapreduce(x -> 2x, +, a) == 24
8189
````
8290

8391
AbstractArray functionality:
@@ -87,6 +95,10 @@ b = a .+ 2 .* a'
8795
@test b isa SparseMatrixDOK{Float64}
8896
@test b == [0 12; 24 0]
8997
@test storedlength(b) == 2
98+
@test sum(b) == 36
99+
@test isreal(b)
100+
@test !iszero(b)
101+
@test mapreduce(x -> 2x, +, b) == 72
90102

91103
b = permutedims(a, (2, 1))
92104
@test b isa SparseMatrixDOK{Float64}

examples/README.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ a = SparseArrayDOK{Float64}(2, 2)
5656

5757
# AbstractArray interface:
5858

59+
@test iszero(a)
60+
@test iszero(sum(a))
61+
@test iszero(storedlength(a))
62+
5963
a[1, 2] = 12
6064
@test a == [0 12; 0 0]
6165
@test a[1, 1] == 0
@@ -79,13 +83,21 @@ using Dictionaries: IndexError
7983
@test storedlength(a) == 1
8084
@test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12])
8185
@test issetequal(storedvalues(a), [12])
86+
@test sum(a) == 12
87+
@test isreal(a)
88+
@test !iszero(a)
89+
@test mapreduce(x -> 2x, +, a) == 24
8290

8391
# AbstractArray functionality:
8492

8593
b = a .+ 2 .* a'
8694
@test b isa SparseMatrixDOK{Float64}
8795
@test b == [0 12; 24 0]
8896
@test storedlength(b) == 2
97+
@test sum(b) == 36
98+
@test isreal(b)
99+
@test !iszero(b)
100+
@test mapreduce(x -> 2x, +, b) == 72
89101

90102
b = permutedims(a, (2, 1))
91103
@test b isa SparseMatrixDOK{Float64}

src/abstractsparsearrayinterface.jl

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,27 @@ getunstoredindex(a::AbstractArray, I::Int...) = zero(eltype(a))
3232
# Derived interface.
3333
storedlength(a::AbstractArray) = length(storedvalues(a))
3434
storedpairs(a::AbstractArray) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))
35-
function storedvalues(a::AbstractArray)
36-
return @view a[collect(eachstoredindex(a))]
37-
end
35+
36+
# A view of the stored values of an array.
37+
# Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue
38+
# with that is it returns a `SubArray` wrapping a sparse array, which
39+
# is then interpreted as a sparse array so it can lead to recursion.
40+
# Also, that involves extra logic for determining if the indices are
41+
# stored or not, but we know the indices are stored so we can use
42+
# `getstoredindex` and `setstoredindex!`.
43+
# Most sparse arrays should overload `storedvalues` directly
44+
# and avoid this wrapper since it adds extra indirection to
45+
# access stored values.
46+
struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T}
47+
array::A
48+
storedindices::I
49+
end
50+
StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a)))
51+
Base.size(a::StoredValues) = size(a.storedindices)
52+
Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I])
53+
Base.setindex!(a::StoredValues, value, I::Int) = setstoredindex!(a.array, value, a.storedindices[I])
54+
55+
storedvalues(a::AbstractArray) = StoredValues(a)
3856

3957
function eachstoredindex(a1, a2, a_rest...)
4058
# TODO: Make this more customizable, say with a function
@@ -64,8 +82,8 @@ end
6482
@interface ::AbstractSparseArrayInterface function Base.setindex!(
6583
a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}
6684
) where {N}
67-
iszero(value) && return a
6885
if !isstored(a, I...)
86+
iszero(value) && return a
6987
setunstoredindex!(a, value, I...)
7088
return a
7189
end
@@ -94,6 +112,30 @@ end
94112
return dest
95113
end
96114

115+
# `f::typeof(norm)`, `op::typeof(max)` used by `norm`.
116+
function reduce_init(f, op, as...)
117+
# TODO: Generalize this.
118+
@assert isone(length(as))
119+
a = only(as)
120+
## TODO: Make this more efficient for block sparse
121+
## arrays, in that case it allocates a block. Maybe
122+
## it can use `FillArrays.Zeros`.
123+
return f(getunstoredindex(a, first(eachindex(a))))
124+
end
125+
126+
@interface ::AbstractSparseArrayInterface function Base.mapreduce(
127+
f, op, as::AbstractArray...; init=reduce_init(f, op, as...), kwargs...
128+
)
129+
# TODO: Generalize this.
130+
@assert isone(length(as))
131+
a = only(as)
132+
output = mapreduce(f, op, storedvalues(a); init, kwargs...)
133+
## TODO: Bring this check back, or make the function more general.
134+
## f_notstored = apply_notstored(f, a)
135+
## @assert isequal(op(output, eltype(output)(f_notstored)), output)
136+
return output
137+
end
138+
97139
abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
98140

99141
@derive AbstractSparseArrayStyle AbstractArrayStyleOps

src/sparsearraydok.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ function getunstoredindex(a::SparseArrayDOK, I::Int...)
4949
return a.getunstoredindex(a, I...)
5050
end
5151
function setstoredindex!(a::SparseArrayDOK, value, I::Int...)
52+
# TODO: Have a way to disable this check, analogous to `checkbounds`,
53+
# since this is already checked in `setindex!`.
5254
isstored(a, I...) || throw(IndexError("key $(CartesianIndex(I)) not found"))
55+
# TODO: If `iszero(value)`, unstore the index.
5356
storage(a)[CartesianIndex(I)] = value
5457
return a
5558
end

0 commit comments

Comments
 (0)