1
1
module SparseArrayDOKs
2
2
3
+ isstored (a:: AbstractArray , I:: CartesianIndex ) = isstored (a, Tuple (I)... )
4
+ getstoredindex (a:: AbstractArray , I:: CartesianIndex ) = getstoredindex (a, Tuple (I)... )
5
+ getunstoredindex (a:: AbstractArray , I:: CartesianIndex ) = getunstoredindex (a, Tuple (I)... )
6
+ function setstoredindex! (a:: AbstractArray , value, I:: CartesianIndex )
7
+ return setstoredindex! (a, value, Tuple (I)... )
8
+ end
9
+ function setunstoredindex! (a:: AbstractArray , value, I:: CartesianIndex )
10
+ return setunstoredindex! (a, value, Tuple (I)... )
11
+ end
12
+
3
13
using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout
4
14
using Derive: Derive, @array_aliases , @derive , @interface , AbstractArrayInterface, interface
5
15
using LinearAlgebra: LinearAlgebra
@@ -8,14 +18,16 @@ using LinearAlgebra: LinearAlgebra
8
18
struct SparseArrayInterface <: AbstractArrayInterface end
9
19
10
20
# Define interface functions.
11
- @interface :: SparseArrayInterface function Base. getindex (a:: AbstractArray , I:: Int... )
21
+ @interface :: SparseArrayInterface function Base. getindex (
22
+ a:: AbstractArray{<:Any,N} , I:: Vararg{Int,N}
23
+ ) where {N}
12
24
checkbounds (a, I... )
13
25
! isstored (a, I... ) && return getunstoredindex (a, I... )
14
26
return getstoredindex (a, I... )
15
27
end
16
28
@interface :: SparseArrayInterface function Base. setindex! (
17
- a:: AbstractArray , value, I:: Int...
18
- )
29
+ a:: AbstractArray{<:Any,N} , value, I:: Vararg{ Int,N}
30
+ ) where {N}
19
31
checkbounds (a, I... )
20
32
iszero (value) && return a
21
33
if ! isstored (a, I... )
@@ -93,19 +105,42 @@ function eachstoredindex(a::Adjoint)
93
105
return map (CartesianIndex ∘ reverse ∘ Tuple, collect (eachstoredindex (parent (a))))
94
106
end
95
107
108
+ perm (:: PermutedDimsArray{<:Any,<:Any,p} ) where {p} = p
109
+ iperm (:: PermutedDimsArray{<:Any,<:Any,<:Any,ip} ) where {ip} = ip
110
+
111
+ # TODO : Use `Base.PermutedDimsArrays.genperm` or
112
+ # https://github.com/jipolanco/StaticPermutations.jl?
113
+ genperm (v, perm) = map (j -> v[j], perm)
114
+
96
115
function isstored (a:: PermutedDimsArray , I:: Int... )
97
- return isstored (parent (a), reverse (I )... )
116
+ return isstored (parent (a), genperm (I, iperm (a) )... )
98
117
end
99
118
function getstoredindex (a:: PermutedDimsArray , I:: Int... )
100
- return getstoredindex (parent (a), reverse (I )... )
119
+ return getstoredindex (parent (a), genperm (I, iperm (a) )... )
101
120
end
102
121
function getunstoredindex (a:: PermutedDimsArray , I:: Int... )
103
- return getunstoredindex (parent (a), reverse (I )... )
122
+ return getunstoredindex (parent (a), genperm (I, iperm (a) )... )
104
123
end
105
124
function eachstoredindex (a:: PermutedDimsArray )
106
- return map (CartesianIndex ∘ reverse ∘ Tuple, collect (eachstoredindex (parent (a))))
125
+ return map (collect (eachstoredindex (parent (a)))) do I
126
+ return CartesianIndex (genperm (I, perm (a)))
127
+ end
107
128
end
108
129
130
+ tuple_oneto (n) = ntuple (identity, n)
131
+ # # This is an optimization for `storedvalues` for DOK.
132
+ # # function valuesview(d::Dict, keys)
133
+ # # return @view d.vals[[Base.ht_keyindex(d, key) for key in keys]]
134
+ # # end
135
+
136
+ function eachstoredparentindex (a:: SubArray )
137
+ return filter (eachstoredindex (parent (a))) do I
138
+ return all (d -> I[d] ∈ parentindices (a)[d], 1 : ndims (parent (a)))
139
+ end
140
+ end
141
+ function storedvalues (a:: SubArray )
142
+ return @view parent (a)[collect (eachstoredparentindex (a))]
143
+ end
109
144
function isstored (a:: SubArray , I:: Int... )
110
145
return isstored (parent (a), Base. reindex (parentindices (a), I)... )
111
146
end
@@ -115,18 +150,23 @@ end
115
150
function getunstoredindex (a:: SubArray , I:: Int... )
116
151
return getunstoredindex (parent (a), Base. reindex (parentindices (a), I)... )
117
152
end
153
+ function setstoredindex! (a:: SubArray , value, I:: Int... )
154
+ return setstoredindex! (parent (a), value, Base. reindex (parentindices (a), I)... )
155
+ end
156
+ function setunstoredindex! (a:: SubArray , value, I:: Int... )
157
+ return setunstoredindex! (parent (a), value, Base. reindex (parentindices (a), I)... )
158
+ end
118
159
function eachstoredindex (a:: SubArray )
119
- nonscalardims = filter (ntuple (identity, ndims (parent (a)))) do d
160
+ nonscalardims = filter (tuple_oneto ( ndims (parent (a)))) do d
120
161
return ! (parentindices (a)[d] isa Real)
121
162
end
122
- nonscalar_parentindices = map (d -> parentindices (a)[d], nonscalardims)
123
- subindices = filter (eachstoredindex (parent (a))) do I
124
- return all (d -> I[d] ∈ parentindices (a)[d], 1 : ndims (parent (a)))
125
- end
126
- return map (collect (subindices)) do I
127
- I_nonscalar = CartesianIndex (map (d -> I[d], nonscalardims))
128
- return CartesianIndex (Base. reindex (nonscalar_parentindices, Tuple (I_nonscalar)))
129
- end
163
+ return collect ((
164
+ CartesianIndex (
165
+ map (nonscalardims) do d
166
+ return findfirst (== (I[d]), parentindices (a)[d])
167
+ end ,
168
+ ) for I in eachstoredparentindex (a)
169
+ ))
130
170
end
131
171
132
172
# Define a type that will derive the interface.
0 commit comments