Skip to content

Commit 447bbda

Browse files
authored
Merge pull request #41 from JuliaArrays/teh/getindex_view
Split getindex and view apart. Closes #38.
2 parents 6ace815 + 8b55b92 commit 447bbda

File tree

2 files changed

+72
-103
lines changed

2 files changed

+72
-103
lines changed

src/indexing.jl

Lines changed: 57 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,23 @@
1-
### Indexing returns either a scalar or a smartly-subindexed AxisArray ###
1+
typealias Idx Union{Real,Colon,AbstractArray{Int}}
22

3-
# Limit indexing to types supported by SubArrays, at least initially
4-
typealias Idx Union{Colon,Int,AbstractVector{Int}}
3+
using Base: ViewIndex, linearindexing, unsafe_getindex, unsafe_setindex!
54

65
# Defer linearindexing to the wrapped array
7-
import Base: linearindexing, unsafe_getindex, unsafe_setindex!
86
Base.linearindexing{T,N,D}(::AxisArray{T,N,D}) = linearindexing(D)
97

108
# Simple scalar indexing where we just set or return scalars
11-
Base.getindex(A::AxisArray, idxs::Int...) = A.data[idxs...]
12-
Base.setindex!(A::AxisArray, v, idxs::Int...) = (A.data[idxs...] = v)
13-
14-
# Default to views already
15-
Base.getindex{T}(A::AxisArray{T,1}, idx::Colon) = A
9+
@inline Base.getindex(A::AxisArray, idxs::Int...) = A.data[idxs...]
10+
@inline Base.setindex!(A::AxisArray, v, idxs::Int...) = (A.data[idxs...] = v)
1611

1712
# Cartesian iteration
1813
Base.eachindex(A::AxisArray) = eachindex(A.data)
1914
Base.getindex(A::AxisArray, idx::Base.IteratorsMD.CartesianIndex) = A.data[idx]
2015
Base.setindex!(A::AxisArray, v, idx::Base.IteratorsMD.CartesianIndex) = (A.data[idx] = v)
2116

22-
# More complicated cases where we must create a subindexed AxisArray
23-
# TODO: do we want to be dogmatic about using views? For the data? For the axes?
24-
# TODO: perhaps it would be better to return an entirely lazy SubAxisArray view
25-
@generated function Base.getindex{T,N,D,Ax}(A::AxisArray{T,N,D,Ax}, idxs::Idx...)
26-
newdims = length(idxs)
27-
# If the last index is a linear indexing range that may span multiple
28-
# dimensions in the original AxisArray, we can no longer track those axes.
29-
droplastaxis = N > newdims && !(idxs[end] <: Real) ? 1 : 0
30-
# Drop trailing scalar dimensions
31-
while newdims > 0 && idxs[newdims] <: Real
32-
newdims -= 1
33-
end
34-
names = axisnames(A)
35-
axes = Expr(:tuple)
36-
Isplat = Expr[]
37-
reshape = false
38-
newshape = Expr[]
39-
for i = 1:newdims-droplastaxis
40-
prepaxis!(axes.args, Isplat, idxs[i], names, i)
41-
end
42-
for i = newdims-droplastaxis+1:length(idxs)
43-
push!(Isplat, :(idxs[$i]))
44-
end
45-
quote
46-
data = view(A.data, $(Isplat...))
47-
AxisArray(data, $axes) # TODO: avoid checking the axes here
48-
end
49-
end
50-
51-
# When we index with non-vector arrays, we *add* dimensions. This isn't
52-
# supported by SubArray currently, so we instead return a copy.
53-
# TODO: we probably shouldn't hack Base like this, but it's so convenient...
54-
if VERSION < v"0.5.0-dev"
55-
@inline Base.index_shape_dim(A, dim, i::AbstractArray{Bool}, I...) = (sum(i), Base.index_shape_dim(A, dim+1, I...)...)
56-
@inline Base.index_shape_dim(A, dim, i::AbstractArray, I...) = (size(i)..., Base.index_shape_dim(A, dim+1, I...)...)
57-
end
58-
@generated function Base.getindex(A::AxisArray, I::Union{Idx, AbstractArray{Int}}...)
17+
@generated function reaxis(A::AxisArray, I::Idx...)
5918
N = length(I)
60-
Isplat = [:(I[$d]) for d=1:N]
6119
# Determine the new axes:
62-
# Like above, drop linear indexing over multiple axes
20+
# Drop linear indexing over multiple axes
6321
droplastaxis = ndims(A) > N && !(I[end] <: Real) ? 1 : 0
6422
# Drop trailing scalar dimensions
6523
lastnonscalar = N
@@ -70,44 +28,74 @@ end
7028
newaxes = Expr[]
7129
for d=1:lastnonscalar-droplastaxis
7230
if I[d] <: AxisArray
31+
# Indexing with an AxisArray joins the axis names
7332
idxnames = axisnames(I[d])
7433
for i=1:ndims(I[d])
7534
push!(newaxes, :($(Axis{Symbol(names[d], "_", idxnames[i])})(I[$d].axes[$i].val)))
7635
end
77-
elseif I[d] <: Idx
78-
push!(newaxes, :($(Axis{names[d]})(A.axes[$d].val[J[$d]])))
36+
elseif I[d] <: Real
37+
elseif I[d] <: Union{AbstractVector,Colon}
38+
push!(newaxes, :($(Axis{names[d]})(A.axes[$d].val[Base.to_index(I[$d])])))
7939
elseif I[d] <: AbstractArray
8040
for i=1:ndims(I[d])
41+
# When we index with non-vector arrays, we *add* dimensions.
8142
push!(newaxes, :($(Axis{Symbol(names[d], "_", i)})(indices(I[$d], $i))))
8243
end
8344
end
8445
end
8546
quote
86-
# First copy the data using scalar indexing - an adaptation of Base
87-
checkbounds(A, I...)
88-
J = Base.to_indexes($(Isplat...))
89-
sz = Base.index_shape(A, J...)
90-
idx_lens = Base.index_lengths(A, J...)
91-
src = A.data
92-
dest = similar(A.data, sz)
93-
D = eachindex(dest)
94-
Ds = start(D)
95-
Base.Cartesian.@nloops $N i d->(1:idx_lens[d]) d->(@inbounds j_d = J[d][i_d]) begin
96-
d, Ds = next(D, Ds)
97-
v = Base.Cartesian.@ncall $N unsafe_getindex src j
98-
unsafe_setindex!(dest, v, d)
99-
end
100-
# And now create the AxisArray:
101-
AxisArray(dest, $(newaxes...))
47+
($(newaxes...),)
48+
end
49+
end
50+
51+
@inline function Base.getindex(A::AxisArray, idxs::Idx...)
52+
AxisArray(A.data[idxs...], reaxis(A, idxs...))
53+
end
54+
55+
# To resolve ambiguities, we need several definitions
56+
if VERSION >= v"0.6.0-dev.672"
57+
using Base.AbstractCartesianIndex
58+
Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
59+
else
60+
@inline function Base.view{T,N}(A::AxisArray{T,N}, idxs::Vararg{Idx,N})
61+
AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
62+
end
63+
function Base.view(A::AxisArray, idx::Idx)
64+
AxisArray(view(A.data, idx), reaxis(A, idx))
65+
end
66+
@inline function Base.view{N}(A::AxisArray, idxs::Vararg{Idx,N})
67+
# this should eventually be deleted, see julia #14770
68+
AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
10269
end
10370
end
10471

10572
# Setindex is so much simpler. Just assign it to the data:
106-
Base.setindex!(A::AxisArray, v, idxs::Idx...) = (A.data[idxs...] = v)
73+
@inline Base.setindex!(A::AxisArray, v, idxs::Idx...) = (A.data[idxs...] = v)
10774

10875
### Fancier indexing capabilities provided only by AxisArrays ###
109-
Base.getindex(A::AxisArray, idxs...) = A[to_index(A,idxs...)...]
110-
Base.setindex!(A::AxisArray, v, idxs...) = (A[to_index(A,idxs...)...] = v)
76+
@inline Base.getindex(A::AxisArray, idxs...) = A[to_index(A,idxs...)...]
77+
@inline Base.setindex!(A::AxisArray, v, idxs...) = (A[to_index(A,idxs...)...] = v)
78+
# Deal with lots of ambiguities here
79+
if VERSION >= v"0.6.0-dev.672"
80+
Base.view(A::AxisArray, idxs::ViewIndex...) = view(A, to_index(A,idxs...)...)
81+
Base.view(A::AxisArray, idxs::Union{ViewIndex,AbstractCartesianIndex}...) = view(A, to_index(A,Base.IteratorsMD.flatten(idxs)...)...)
82+
Base.view(A::AxisArray, idxs...) = view(A, to_index(A,idxs...)...)
83+
else
84+
for T in (:ViewIndex, :Any)
85+
@eval begin
86+
@inline function Base.view{T,N}(A::AxisArray{T,N}, idxs::Vararg{$T,N})
87+
view(A, to_index(A,idxs...)...)
88+
end
89+
function Base.view(A::AxisArray, idx::$T)
90+
view(A, to_index(A,idx)...)
91+
end
92+
@inline function Base.view{N}(A::AxisArray, idsx::Vararg{$T,N})
93+
# this should eventually be deleted, see julia #14770
94+
view(A, to_index(A,idxs...)...)
95+
end
96+
end
97+
end
98+
end
11199

112100
# First is indexing by named axis. We simply sort the axes and re-dispatch.
113101
# When indexing by named axis the shapes of omitted dimensions are preserved
@@ -214,30 +202,3 @@ end
214202
meta = Expr(:meta, :inline)
215203
return :($meta; $ex)
216204
end
217-
218-
function prepaxis!{I<:Union{AbstractVector,Colon}}(axesargs, Isplat, ::Type{I}, names, i)
219-
idx = :(idxs[$i])
220-
push!(axesargs, :($(Axis{names[i]})(A.axes[$i].val[$idx])))
221-
push!(Isplat, :(idxs[$i]))
222-
axesargs, Isplat
223-
end
224-
function prepaxis!{I<:AxisArray}(axesargs, Isplat, ::Type{I}, names, i)
225-
idxnames = axisnames(I)
226-
push!(axesargs, :($(Axis{Symbol(names[i], "_", idxnames[1])})(idxs[$i].axes[1].val)))
227-
push!(Isplat, :(idxs[$i]))
228-
axesargs, Isplat
229-
end
230-
# For anything scalar-like
231-
if VERSION < v"0.5.0-dev"
232-
function prepaxis!{I}(axesargs, Isplat, ::Type{I}, names, i)
233-
idx = :(idxs[$i]:idxs[$i])
234-
push!(axesargs, :($(Axis{names[i]})(A.axes[$i].val[$idx])))
235-
push!(Isplat, idx)
236-
axesargs, Isplat
237-
end
238-
else
239-
function prepaxis!{I}(axesargs, Isplat, ::Type{I}, names, i)
240-
push!(Isplat, :(idxs[$i]))
241-
axesargs, Isplat
242-
end
243-
end

test/indexing.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,18 @@ D[1,1,1,1,1] = 10
99
@test A[:,:,:] == A[Axis{:row}(:)] == A[Axis{:col}(:)] == A[Axis{:page}(:)] == A.data[:,:,:]
1010
# Test UnitRange slices
1111
@test A[1:2,:,:] == A.data[1:2,:,:] == A[Axis{:row}(1:2)] == A[Axis{1}(1:2)] == A[Axis{:row}(ClosedInterval(-Inf,Inf))] == A[[true,true],:,:]
12+
@test @view(A[1:2,:,:]) == A.data[1:2,:,:] == @view(A[Axis{:row}(1:2)]) == @view(A[Axis{1}(1:2)]) == @view(A[Axis{:row}(ClosedInterval(-Inf,Inf))]) == @view(A[[true,true],:,:])
1213
@test A[:,1:2,:] == A.data[:,1:2,:] == A[Axis{:col}(1:2)] == A[Axis{2}(1:2)] == A[Axis{:col}(ClosedInterval(0.0, .25))] == A[:,[true,true,false],:]
14+
@test @view(A[:,1:2,:]) == A.data[:,1:2,:] == @view(A[Axis{:col}(1:2)]) == @view(A[Axis{2}(1:2)]) == @view(A[Axis{:col}(ClosedInterval(0.0, .25))]) == @view(A[:,[true,true,false],:])
1315
@test A[:,:,1:2] == A.data[:,:,1:2] == A[Axis{:page}(1:2)] == A[Axis{3}(1:2)] == A[Axis{:page}(ClosedInterval(-1., .22))] == A[:,:,[true,true,false,false]]
16+
@test @view(A[:,:,1:2]) == @view(A.data[:,:,1:2]) == @view(A[Axis{:page}(1:2)]) == @view(A[Axis{3}(1:2)]) == @view(A[Axis{:page}(ClosedInterval(-1., .22))]) == @view(A[:,:,[true,true,false,false]])
1417
# Test scalar slices
1518
@test A[2,:,:] == A.data[2,:,:] == A[Axis{:row}(2)]
1619
@test A[:,2,:] == A.data[:,2,:] == A[Axis{:col}(2)]
1720
@test A[:,:,2] == A.data[:,:,2] == A[Axis{:page}(2)]
1821

1922
# Test fallback methods
20-
@test A[[1 2; 3 4]] == A.data[[1 2; 3 4]]
23+
@test A[[1 2; 3 4]] == @view(A[[1 2; 3 4]]) == A.data[[1 2; 3 4]]
2124
@test A[] == A.data[]
2225

2326
# Test axis restrictions
@@ -45,14 +48,19 @@ B = AxisArray(reshape(1:15, 5,3), .1:.1:0.5, [:a, :b, :c])
4548
@test B[ClosedInterval(0.15, 0.3), :] == B[ClosedInterval(0.15, 0.3)] == B[2:3,:]
4649
@test B[ClosedInterval(0.2, 0.5), :] == B[ClosedInterval(0.2, 0.5)] == B[2:end,:]
4750
@test B[ClosedInterval(0.2, 0.6), :] == B[ClosedInterval(0.2, 0.6)] == B[2:end,:]
51+
@test @view(B[ClosedInterval(0.0, 0.5), :]) == @view(B[ClosedInterval(0.0, 0.5)]) == B[:,:]
52+
@test @view(B[ClosedInterval(0.0, 0.3), :]) == @view(B[ClosedInterval(0.0, 0.3)]) == B[1:3,:]
53+
@test @view(B[ClosedInterval(0.15, 0.3), :]) == @view(B[ClosedInterval(0.15, 0.3)]) == B[2:3,:]
54+
@test @view(B[ClosedInterval(0.2, 0.5), :]) == @view(B[ClosedInterval(0.2, 0.5)]) == B[2:end,:]
55+
@test @view(B[ClosedInterval(0.2, 0.6), :]) == @view(B[ClosedInterval(0.2, 0.6)]) == B[2:end,:]
4856

4957
# Test Categorical indexing
50-
@test B[:, :a] == B[:,1]
51-
@test B[:, :c] == B[:,3]
52-
@test B[:, [:a]] == B[:,[1]]
53-
@test B[:, [:a,:c]] == B[:,[1,3]]
58+
@test B[:, :a] == @view(B[:, :a]) == B[:,1]
59+
@test B[:, :c] == @view(B[:, :c]) == B[:,3]
60+
@test B[:, [:a]] == @view(B[:, [:a]]) == B[:,[1]]
61+
@test B[:, [:a,:c]] == @view(B[:, [:a,:c]]) == B[:,[1,3]]
5462

55-
@test B[Axis{:row}(ClosedInterval(0.15, 0.3))] == B[2:3,:]
63+
@test B[Axis{:row}(ClosedInterval(0.15, 0.3))] == @view(B[Axis{:row}(ClosedInterval(0.15, 0.3))]) == B[2:3,:]
5664

5765
A = AxisArray(reshape(1:256, 4,4,4,4), Axis{:d1}(.1:.1:.4), Axis{:d2}(1//10:1//10:4//10), Axis{:d3}(["1","2","3","4"]), Axis{:d4}([:a, :b, :c, :d]))
5866
ax1 = axes(A)[1]
@@ -68,7 +76,7 @@ A = AxisArray(reshape(1:32, 2, 2, 2, 2, 2), .1:.1:.2, .1:.1:.2, .1:.1:.2, [:a, :
6876

6977
# Test vectors
7078
v = AxisArray(collect(.1:.1:10.0), .1:.1:10.0)
71-
@test v[Colon()] === v
79+
@test v[Colon()] == v
7280
@test v[:] == v.data[:] == v[Axis{:row}(:)]
7381
@test v[3:8] == v.data[3:8] == v[ClosedInterval(.25,.85)] == v[Axis{:row}(3:8)] == v[Axis{:row}(ClosedInterval(.22,.88))]
7482

0 commit comments

Comments
 (0)