@@ -32,9 +32,27 @@ getunstoredindex(a::AbstractArray, I::Int...) = zero(eltype(a))
32
32
# Derived interface.
33
33
storedlength (a:: AbstractArray ) = length (storedvalues (a))
34
34
storedpairs (a:: AbstractArray ) = map (I -> I => getstoredindex (a, I), eachstoredindex (a))
35
- function storedvalues (a:: AbstractArray )
36
- return @view a[collect (eachstoredindex (a))]
37
- end
35
+
36
+ # A view of the stored values of an array.
37
+ # Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue
38
+ # with that is it returns a `SubArray` wrapping a sparse array, which
39
+ # is then interpreted as a sparse array so it can lead to recursion.
40
+ # Also, that involves extra logic for determining if the indices are
41
+ # stored or not, but we know the indices are stored so we can use
42
+ # `getstoredindex` and `setstoredindex!`.
43
+ # Most sparse arrays should overload `storedvalues` directly
44
+ # and avoid this wrapper since it adds extra indirection to
45
+ # access stored values.
46
+ struct StoredValues{T,A<: AbstractArray{T} ,I} <: AbstractVector{T}
47
+ array:: A
48
+ storedindices:: I
49
+ end
50
+ StoredValues (a:: AbstractArray ) = StoredValues (a, collect (eachstoredindex (a)))
51
+ Base. size (a:: StoredValues ) = size (a. storedindices)
52
+ Base. getindex (a:: StoredValues , I:: Int ) = getstoredindex (a. array, a. storedindices[I])
53
+ Base. setindex! (a:: StoredValues , value, I:: Int ) = setstoredindex! (a. array, value, a. storedindices[I])
54
+
55
+ storedvalues (a:: AbstractArray ) = StoredValues (a)
38
56
39
57
function eachstoredindex (a1, a2, a_rest... )
40
58
# TODO : Make this more customizable, say with a function
64
82
@interface :: AbstractSparseArrayInterface function Base. setindex! (
65
83
a:: AbstractArray{<:Any,N} , value, I:: Vararg{Int,N}
66
84
) where {N}
67
- iszero (value) && return a
68
85
if ! isstored (a, I... )
86
+ iszero (value) && return a
69
87
setunstoredindex! (a, value, I... )
70
88
return a
71
89
end
94
112
return dest
95
113
end
96
114
115
+ # `f::typeof(norm)`, `op::typeof(max)` used by `norm`.
116
+ function reduce_init (f, op, as... )
117
+ # TODO : Generalize this.
118
+ @assert isone (length (as))
119
+ a = only (as)
120
+ # # TODO : Make this more efficient for block sparse
121
+ # # arrays, in that case it allocates a block. Maybe
122
+ # # it can use `FillArrays.Zeros`.
123
+ return f (getunstoredindex (a, first (eachindex (a))))
124
+ end
125
+
126
+ @interface :: AbstractSparseArrayInterface function Base. mapreduce (
127
+ f, op, as:: AbstractArray... ; init= reduce_init (f, op, as... ), kwargs...
128
+ )
129
+ # TODO : Generalize this.
130
+ @assert isone (length (as))
131
+ a = only (as)
132
+ output = mapreduce (f, op, storedvalues (a); init, kwargs... )
133
+ # # TODO : Bring this check back, or make the function more general.
134
+ # # f_notstored = apply_notstored(f, a)
135
+ # # @assert isequal(op(output, eltype(output)(f_notstored)), output)
136
+ return output
137
+ end
138
+
97
139
abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
98
140
99
141
@derive AbstractSparseArrayStyle AbstractArrayStyleOps
0 commit comments