Skip to content

Commit 4e30a82

Browse files
authored
Refactor reaxis to be slightly less generated (#79)
* Refactor reaxis to be slightly less generated Mostly use dispatch now, except in cases where we change axis names. * Add tests Test indexing beyond the dimensionality and adding dimensions. Add some inferred tests. * 0.5 compatibility
1 parent 2b723be commit 4e30a82

File tree

4 files changed

+59
-46
lines changed

4 files changed

+59
-46
lines changed

src/combine.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ Combines AxisArrays with matching axis names into a single AxisArray. Unlike `me
122122
If an array value in the output array is not defined in any of the input arrays (i.e. in the case of a left, right, or outer join), it takes the value of the optional `fillvalue` keyword argument (default zero).
123123
"""
124124
function Base.join{T,N,D,Ax}(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
125-
newaxis::Axis=_nextaxistype(As[1].data, As[1].axes)(1:length(As)),
125+
newaxis::Axis=_nextaxistype(As[1].axes)(1:length(As)),
126126
method::Symbol=:outer)
127127

128128
prejoin_resultaxes = map(as -> axismerge(method, as...), map(tuple, axes.(As)...))

src/core.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,9 @@ _defaultdimname(i) = i == 1 ? (:row) : i == 2 ? (:col) : i == 3 ? (:page) : Symb
166166
default_axes(A::AbstractArray) = _default_axes(A, indices(A), ())
167167
_default_axes{T,N}(A::AbstractArray{T,N}, inds, axs::NTuple{N,Axis}) = axs
168168
@inline _default_axes{T,N,M}(A::AbstractArray{T,N}, inds, axs::NTuple{M,Axis}) =
169-
_default_axes(A, inds, (axs..., _nextaxistype(A, axs)(inds[M+1])))
169+
_default_axes(A, inds, (axs..., _nextaxistype(axs)(inds[M+1])))
170170
# Why doesn't @pure work here?
171-
@generated function _nextaxistype{T,M}(A::AbstractArray{T}, axs::NTuple{M,Axis})
171+
@generated function _nextaxistype{M}(axs::NTuple{M,Axis})
172172
name = _defaultdimname(M+1)
173173
:(Axis{$(Expr(:quote, name))})
174174
end

src/indexing.jl

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
const Idx = Union{Real,Colon,AbstractArray{Int}}
22

3-
using Base: ViewIndex, @propagate_inbounds
3+
using Base: ViewIndex, @propagate_inbounds, tail
44

55
# Defer IndexStyle to the wrapped array
66
@compat Base.IndexStyle{T,N,D,Ax}(::Type{AxisArray{T,N,D,Ax}}) = IndexStyle(D)
@@ -12,46 +12,58 @@ using Base: ViewIndex, @propagate_inbounds
1212
# Cartesian iteration
1313
Base.eachindex(A::AxisArray) = eachindex(A.data)
1414

15-
@generated function reaxis(A::AxisArray, I::Idx...)
16-
N = length(I)
17-
# Determine the new axes:
18-
# Drop linear indexing over multiple axes
19-
droplastaxis = ndims(A) > N && !(I[end] <: Real) ? 1 : 0
20-
# Drop trailing scalar dimensions
21-
lastnonscalar = N
22-
while lastnonscalar > 0 && I[lastnonscalar] <: Real
23-
lastnonscalar -= 1
15+
"""
16+
reaxis(A::AxisArray, I...)
17+
18+
This internal function determines the new set of axes that are constructed upon
19+
indexing with I.
20+
"""
21+
reaxis(A::AxisArray, I::Idx...) = _reaxis(make_axes_match(axes(A), I), I)
22+
# Ensure the number of axes matches the number of indexing dimensions
23+
@inline make_axes_match(axs, idxs) = _make_axes_match((), axs, Base.index_ndims(idxs...))
24+
# Move the axes into newaxes, until we run out of both simultaneously
25+
@inline _make_axes_match(newaxes, axs::Tuple, nidxs::Tuple) =
26+
_make_axes_match((newaxes..., axs[1]), tail(axs), tail(nidxs))
27+
@inline _make_axes_match(newaxes, axs::Tuple{}, nidxs::Tuple{}) = newaxes
28+
# Drop trailing axes, replacing it with a default name for the linear span
29+
@inline _make_axes_match(newaxes, axs::Tuple, nidxs::Tuple{}) =
30+
(maybefront(newaxes)..., _nextaxistype(newaxes)(Base.OneTo(length(newaxes[end]) * prod(map(length, axs)))))
31+
# Insert phony singleton trailing axes
32+
@inline _make_axes_match(newaxes, axs::Tuple{}, nidxs::Tuple) =
33+
_make_axes_match((newaxes..., _nextaxistype(newaxes)(Base.OneTo(1))), (), tail(nidxs))
34+
35+
@inline maybefront(::Tuple{}) = ()
36+
@inline maybefront(t::Tuple) = Base.front(t)
37+
38+
# Now we can reaxis without worrying about mismatched axes/indices
39+
@inline _reaxis(axs::Tuple{}, idxs::Tuple{}) = ()
40+
# Scalars are dropped
41+
const ScalarIndex = @compat Union{Real, AbstractArray{<:Any, 0}}
42+
@inline _reaxis(axs::Tuple, idxs::Tuple{ScalarIndex, Vararg{Any}}) = _reaxis(tail(axs), tail(idxs))
43+
# Colon passes straight through
44+
@inline _reaxis(axs::Tuple, idxs::Tuple{Colon, Vararg{Any}}) = (axs[1], _reaxis(tail(axs), tail(idxs))...)
45+
# But arrays can add or change dimensions and accompanying axis names
46+
@inline _reaxis(axs::Tuple, idxs::Tuple{AbstractArray, Vararg{Any}}) =
47+
(_new_axes(axs[1], idxs[1])..., _reaxis(tail(axs), tail(idxs))...)
48+
49+
# Vectors simply create new axes with the same name; just subsetted by their value
50+
@inline _new_axes{name}(ax::Axis{name}, idx::AbstractVector) = (Axis{name}(ax.val[idx]),)
51+
# Arrays create multiple axes with _N appended to the axis name containing their indices
52+
@generated function _new_axes{name, N}(ax::Axis{name}, idx::@compat(AbstractArray{<:Any,N}))
53+
newaxes = Expr(:tuple)
54+
for i=1:N
55+
push!(newaxes.args, :($(Axis{Symbol(name, "_", i)})(indices(idx, $i))))
2456
end
25-
names = axisnames(A)
26-
newaxes = Expr[]
27-
drange = 1:lastnonscalar-droplastaxis
28-
for d=drange
29-
if I[d] <: AxisArray
30-
# Indexing with an AxisArray joins the axis names
31-
idxnames = axisnames(I[d])
32-
for i=1:ndims(I[d])
33-
push!(newaxes, :($(Axis{Symbol(names[d], "_", idxnames[i])})(I[$d].axes[$i].val)))
34-
end
35-
elseif I[d] <: Real
36-
elseif I[d] <: AbstractVector
37-
push!(newaxes, :($(Axis{names[d]})(A.axes[$d].val[Base.to_index(I[$d])])))
38-
elseif I[d] <: Colon
39-
if d < length(I) || d <= ndims(A)
40-
push!(newaxes, :($(Axis{names[d]})(A.axes[$d].val)))
41-
else
42-
dimname = _defaultdimname(d)
43-
push!(newaxes, :($(Axis{dimname})(Base.OneTo(Base.trailingsize(A, $d)))))
44-
end
45-
elseif I[d] <: AbstractArray
46-
for i=1:ndims(I[d])
47-
# When we index with non-vector arrays, we *add* dimensions.
48-
push!(newaxes, :($(Axis{Symbol(names[d], "_", i)})(indices(I[$d], $i))))
49-
end
50-
end
51-
end
52-
quote
53-
($(newaxes...),)
57+
newaxes
58+
end
59+
# And indexing with an AxisArray joins the name and overrides the values
60+
@generated function _new_axes{name, N}(ax::Axis{name}, idx::@compat(AxisArray{<:Any, N}))
61+
newaxes = Expr(:tuple)
62+
idxnames = axisnames(idx)
63+
for i=1:N
64+
push!(newaxes.args, :($(Axis{Symbol(name, "_", idxnames[i])})(idx.axes[$i].val)))
5465
end
66+
newaxes
5567
end
5668

5769
@propagate_inbounds function Base.getindex(A::AxisArray, idxs::Idx...)

test/indexing.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
A = AxisArray(reshape(1:24, 2,3,4), .1:.1:.2, .1:.1:.3, .1:.1:.4)
22
D = similar(A)
33
D[1,1,1,1,1] = 10
4-
@test D[1,1,1,1,1] == D[1] == D.data[1] == 10
4+
@test @inferred(D[1,1,1,1,1]) == @inferred(D[1]) == D.data[1] == 10
5+
@test @inferred(D[1,1,1,:]) == @inferred(D[1,1,1,1:1]) == @inferred(D[1,1,1,[1]]) == AxisArray([10], Axis{:dim_4}(Base.OneTo(1)))
56

67
# Test slices
78

89
@test A == A.data
910
@test A[:,:,:] == A[Axis{:row}(:)] == A[Axis{:col}(:)] == A[Axis{:page}(:)] == A.data[:,:,:]
1011
# Test UnitRange slices
11-
@test A[1:2,:,:] == A.data[1:2,:,:] == A[Axis{:row}(1:2)] == A[Axis{1}(1:2)] == A[Axis{:row}(ClosedInterval(-Inf,Inf))] == A[[true,true],:,:]
12-
@test @view(A[1:2,:,:]) == A.data[1:2,:,:] == @view(A[Axis{:row}(1:2)]) == @view(A[Axis{1}(1:2)]) == @view(A[Axis{:row}(ClosedInterval(-Inf,Inf))]) == @view(A[[true,true],:,:])
13-
@test A[:,1:2,:] == A.data[:,1:2,:] == A[Axis{:col}(1:2)] == A[Axis{2}(1:2)] == A[Axis{:col}(ClosedInterval(0.0, .25))] == A[:,[true,true,false],:]
12+
@test @inferred(A[1:2,:,:]) == A.data[1:2,:,:] == @inferred(A[Axis{:row}(1:2)]) == @inferred(A[Axis{1}(1:2)]) == @inferred(A[Axis{:row}(ClosedInterval(-Inf,Inf))]) == @inferred(A[[true,true],:,:])
13+
@test @inferred(view(A,1:2,:,:)) == A.data[1:2,:,:] == @inferred(view(A,Axis{:row}(1:2))) == @inferred(view(A,Axis{1}(1:2))) == @inferred(view(A,Axis{:row}(ClosedInterval(-Inf,Inf)))) == @inferred(view(A,[true,true],:,:))
14+
@test @inferred(A[:,1:2,:]) == A.data[:,1:2,:] == @inferred(A[Axis{:col}(1:2)]) == @inferred(A[Axis{2}(1:2)]) == @inferred(A[Axis{:col}(ClosedInterval(0.0, .25))]) == @inferred(A[:,[true,true,false],:])
1415
@test @view(A[:,1:2,:]) == A.data[:,1:2,:] == @view(A[Axis{:col}(1:2)]) == @view(A[Axis{2}(1:2)]) == @view(A[Axis{:col}(ClosedInterval(0.0, .25))]) == @view(A[:,[true,true,false],:])
1516
@test A[:,:,1:2] == A.data[:,:,1:2] == A[Axis{:page}(1:2)] == A[Axis{3}(1:2)] == A[Axis{:page}(ClosedInterval(-1., .22))] == A[:,:,[true,true,false,false]]
1617
@test @view(A[:,:,1:2]) == @view(A.data[:,:,1:2]) == @view(A[Axis{:page}(1:2)]) == @view(A[Axis{3}(1:2)]) == @view(A[Axis{:page}(ClosedInterval(-1., .22))]) == @view(A[:,:,[true,true,false,false]])

0 commit comments

Comments
 (0)