Skip to content

Commit 263ac28

Browse files
authored
Slicing and trailing indices (#13)
1 parent 3763330 commit 263ac28

File tree

3 files changed

+79
-33
lines changed

3 files changed

+79
-33
lines changed

src/abstractarrayinterface.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# TODO: Add `ndims` type parameter.
22
abstract type AbstractArrayInterface <: AbstractInterface end
33

4+
# TODO: Define as `DefaultArrayInterface()`.
45
function interface(::Type{<:Broadcast.AbstractArrayStyle})
56
return error("Not defined.")
67
end
@@ -18,23 +19,25 @@ using ArrayLayouts: ArrayLayouts
1819
return ArrayLayouts.layout_getindex(a, I...)
1920
end
2021

21-
@interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I::Int...)
22-
# TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or
23-
# `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`.
24-
# TODO: Use `MethodError`?
22+
# TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or
23+
# `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`.
24+
# TODO: Use `MethodError`?
25+
@interface ::AbstractArrayInterface function Base.getindex(
26+
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
27+
) where {N}
2528
return error("Not implemented.")
2629
end
2730

2831
@interface ::AbstractArrayInterface function Broadcast.BroadcastStyle(type::Type)
2932
return Broadcast.DefaultArrayStyle{ndims(type)}()
3033
end
3134

35+
# TODO: Maybe define as `Array{T}(undef, size...)` or
36+
# `invoke(Base.similar, Tuple{AbstractArray,Type,Vararg{Int}}, a, T, size)`.
37+
# TODO: Use `MethodError`?
3238
@interface interface::AbstractArrayInterface function Base.similar(
3339
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
3440
)
35-
# TODO: Maybe define as `Array{T}(undef, size...)` or
36-
# `invoke(Base.similar, Tuple{AbstractArray,Type,Vararg{Int}}, a, T, size)`.
37-
# TODO: Use `MethodError`?
3841
return similar(arraytype(interface, T), size)
3942
end
4043

@@ -43,12 +46,12 @@ end
4346
return a_dest .= a
4447
end
4548

49+
# TODO: Use `Base.to_shape(axes)` or
50+
# `Base.invoke(similar, Tuple{AbstractArray,Type,Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}}, a, T, axes)`.
4651
# TODO: Make this more general, handle mixtures of integers and ranges (`Union{Integer,Base.OneTo}`).
4752
@interface interface::AbstractArrayInterface function Base.similar(
4853
a::AbstractArray, T::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
4954
)
50-
# TODO: Use `Base.to_shape(axes)` or
51-
# `Base.invoke(similar, Tuple{AbstractArray,Type,Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}}, a, T, axes)`.
5255
return @interface interface similar(a, T, Base.to_shape(axes))
5356
end
5457

@@ -80,12 +83,12 @@ end
8083
return f.(as...)
8184
end
8285

86+
# TODO: Maybe define as
87+
# `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`.
88+
# TODO: Use `MethodError`?
8389
@interface ::AbstractArrayInterface function Base.map!(
8490
f, dest::AbstractArray, as::AbstractArray...
8591
)
86-
# TODO: Maybe define as
87-
# `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`.
88-
# TODO: Use `MethodError`?
8992
return error("Not implemented.")
9093
end
9194

test/basics/SparseArrayDOKs.jl

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
module SparseArrayDOKs
22

3+
isstored(a::AbstractArray, I::CartesianIndex) = isstored(a, Tuple(I)...)
4+
getstoredindex(a::AbstractArray, I::CartesianIndex) = getstoredindex(a, Tuple(I)...)
5+
getunstoredindex(a::AbstractArray, I::CartesianIndex) = getunstoredindex(a, Tuple(I)...)
6+
function setstoredindex!(a::AbstractArray, value, I::CartesianIndex)
7+
return setstoredindex!(a, value, Tuple(I)...)
8+
end
9+
function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex)
10+
return setunstoredindex!(a, value, Tuple(I)...)
11+
end
12+
313
using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout
414
using Derive: Derive, @array_aliases, @derive, @interface, AbstractArrayInterface, interface
515
using LinearAlgebra: LinearAlgebra
@@ -8,14 +18,16 @@ using LinearAlgebra: LinearAlgebra
818
struct SparseArrayInterface <: AbstractArrayInterface end
919

1020
# Define interface functions.
11-
@interface ::SparseArrayInterface function Base.getindex(a::AbstractArray, I::Int...)
21+
@interface ::SparseArrayInterface function Base.getindex(
22+
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
23+
) where {N}
1224
checkbounds(a, I...)
1325
!isstored(a, I...) && return getunstoredindex(a, I...)
1426
return getstoredindex(a, I...)
1527
end
1628
@interface ::SparseArrayInterface function Base.setindex!(
17-
a::AbstractArray, value, I::Int...
18-
)
29+
a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}
30+
) where {N}
1931
checkbounds(a, I...)
2032
iszero(value) && return a
2133
if !isstored(a, I...)
@@ -93,19 +105,42 @@ function eachstoredindex(a::Adjoint)
93105
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
94106
end
95107

108+
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
109+
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip
110+
111+
# TODO: Use `Base.PermutedDimsArrays.genperm` or
112+
# https://github.com/jipolanco/StaticPermutations.jl?
113+
genperm(v, perm) = map(j -> v[j], perm)
114+
96115
function isstored(a::PermutedDimsArray, I::Int...)
97-
return isstored(parent(a), reverse(I)...)
116+
return isstored(parent(a), genperm(I, iperm(a))...)
98117
end
99118
function getstoredindex(a::PermutedDimsArray, I::Int...)
100-
return getstoredindex(parent(a), reverse(I)...)
119+
return getstoredindex(parent(a), genperm(I, iperm(a))...)
101120
end
102121
function getunstoredindex(a::PermutedDimsArray, I::Int...)
103-
return getunstoredindex(parent(a), reverse(I)...)
122+
return getunstoredindex(parent(a), genperm(I, iperm(a))...)
104123
end
105124
function eachstoredindex(a::PermutedDimsArray)
106-
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
125+
return map(collect(eachstoredindex(parent(a)))) do I
126+
return CartesianIndex(genperm(I, perm(a)))
127+
end
107128
end
108129

130+
tuple_oneto(n) = ntuple(identity, n)
131+
## This is an optimization for `storedvalues` for DOK.
132+
## function valuesview(d::Dict, keys)
133+
## return @view d.vals[[Base.ht_keyindex(d, key) for key in keys]]
134+
## end
135+
136+
function eachstoredparentindex(a::SubArray)
137+
return filter(eachstoredindex(parent(a))) do I
138+
return all(d -> I[d] parentindices(a)[d], 1:ndims(parent(a)))
139+
end
140+
end
141+
function storedvalues(a::SubArray)
142+
return @view parent(a)[collect(eachstoredparentindex(a))]
143+
end
109144
function isstored(a::SubArray, I::Int...)
110145
return isstored(parent(a), Base.reindex(parentindices(a), I)...)
111146
end
@@ -115,18 +150,23 @@ end
115150
function getunstoredindex(a::SubArray, I::Int...)
116151
return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
117152
end
153+
function setstoredindex!(a::SubArray, value, I::Int...)
154+
return setstoredindex!(parent(a), value, Base.reindex(parentindices(a), I)...)
155+
end
156+
function setunstoredindex!(a::SubArray, value, I::Int...)
157+
return setunstoredindex!(parent(a), value, Base.reindex(parentindices(a), I)...)
158+
end
118159
function eachstoredindex(a::SubArray)
119-
nonscalardims = filter(ntuple(identity, ndims(parent(a)))) do d
160+
nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d
120161
return !(parentindices(a)[d] isa Real)
121162
end
122-
nonscalar_parentindices = map(d -> parentindices(a)[d], nonscalardims)
123-
subindices = filter(eachstoredindex(parent(a))) do I
124-
return all(d -> I[d] parentindices(a)[d], 1:ndims(parent(a)))
125-
end
126-
return map(collect(subindices)) do I
127-
I_nonscalar = CartesianIndex(map(d -> I[d], nonscalardims))
128-
return CartesianIndex(Base.reindex(nonscalar_parentindices, Tuple(I_nonscalar)))
129-
end
163+
return collect((
164+
CartesianIndex(
165+
map(nonscalardims) do d
166+
return findfirst(==(I[d]), parentindices(a)[d])
167+
end,
168+
) for I in eachstoredparentindex(a)
169+
))
130170
end
131171

132172
# Define a type that will derive the interface.

test/basics/test_basics.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
88
a[1, 2] = 12
99
@test a isa SparseArrayDOK{elt,2}
1010
@test size(a) == (2, 2)
11+
@test a[1, 1] == 0
12+
@test a[1, 1, 1] == 0
1113
@test a[1, 2] == 12
14+
@test a[1, 2, 1] == 12
1215
@test storedlength(a) == 1
1316

1417
a = SparseArrayDOK{elt}(2, 2)
@@ -28,11 +31,11 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
2831
@test b == [0 24; 0 0]
2932
@test storedlength(b) == 1
3033

31-
a = SparseArrayDOK{elt}(2, 2)
32-
a[1, 2] = 12
33-
b = permutedims(a, (2, 1))
34-
@test b isa SparseArrayDOK{elt,2}
35-
@test b == [0 0; 12 0]
34+
a = SparseArrayDOK{elt}(3, 3, 3)
35+
a[1, 2, 3] = 123
36+
b = permutedims(a, (2, 3, 1))
37+
@test b isa SparseArrayDOK{elt,3}
38+
@test b[2, 3, 1] == 123
3639
@test storedlength(b) == 1
3740

3841
a = SparseArrayDOK{elt}(2, 2)

0 commit comments

Comments
 (0)