Skip to content

Commit 6ace815

Browse files
authored
Merge pull request #40 from JuliaArrays/teh/more_indexing
Indexing, equality, and dimension-changing operations
2 parents 8d8973a + 0b2c670 commit 6ace815

File tree

8 files changed

+153
-26
lines changed

8 files changed

+153
-26
lines changed

src/AxisArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module AxisArrays
22

3+
using Base: tail
34
using RangeArrays, Iterators, IntervalSets, Compat
45
using Compat.view
56

src/core.jl

Lines changed: 95 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ else
88
using Base: @pure
99
end
1010

11+
typealias Symbols Tuple{Symbol,Vararg{Symbol}}
12+
1113
@doc """
1214
Type-stable axis-specific indexing and identification with a
1315
parametric type.
@@ -51,7 +53,7 @@ immutable Axis{name,T}
5153
end
5254
# Constructed exclusively through Axis{:symbol}(...) or Axis{1}(...)
5355
@compat (::Type{Axis{name}}){name,T}(I::T=()) = Axis{name,T}(I)
54-
@compat Base.:(==){name,T}(A::Axis{name,T}, B::Axis{name,T}) = A.val == B.val
56+
@compat Base.:(==){name}(A::Axis{name}, B::Axis{name}) = A.val == B.val
5557
Base.hash{name}(A::Axis{name}, hx::UInt) = hash(A.val, hash(name, hx))
5658
axistype{name,T}(::Axis{name,T}) = T
5759
axistype{name,T}(::Type{Axis{name,T}}) = T
@@ -61,8 +63,12 @@ Base.getindex(A::Axis, i...) = A.val[i...]
6163
Base.unsafe_getindex(A::Axis, i...) = Base.unsafe_getindex(A, i...)
6264
Base.eltype{_,T}(::Type{Axis{_,T}}) = eltype(T)
6365
Base.size(A::Axis) = size(A.val)
66+
Base.indices(A::Axis) = indices(A.val)
67+
Base.indices(A::Axis, d) = indices(A.val, d)
6468
Base.length(A::Axis) = length(A.val)
6569
@compat (A::Axis{name}){name}(i) = Axis{name}(i)
70+
Base.convert{name,T}(::Type{Axis{name,T}}, ax::Axis{name,T}) = ax
71+
Base.convert{name,T}(::Type{Axis{name,T}}, ax::Axis{name}) = Axis{name}(convert(T, ax.val))
6672

6773
@doc """
6874
An AxisArray is an AbstractArray that wraps another AbstractArray and
@@ -95,11 +101,12 @@ AxisArray(A::AbstractArray, vectors::AbstractVector...)
95101
* `A::AbstractArray` : the wrapped array data
96102
* `axes` or `names` or `vectors` : dimensional information for the wrapped array
97103
98-
The dimensional information may be passed in one of three ways and is entirely
99-
optional. When the axis name or value is missing for a dimension, a default is
100-
substituted. The default axis names for dimensions `(1, 2, 3, 4, 5, ...)` are
101-
`(:row, :col, :page, :dim_4, :dim_5, ...)`. The default axis values are the
102-
integer unit ranges: `1:size(A, d)` for each missing dimension `d`.
104+
The dimensional information may be passed in one of three ways and is
105+
entirely optional. When the axis name or value is missing for a
106+
dimension, a default is substituted. The default axis names for
107+
dimensions `(1, 2, 3, 4, 5, ...)` are `(:row, :col, :page, :dim_4,
108+
:dim_5, ...)`. The default axis values are `indices(A, d)` for each
109+
missing dimension `d`.
103110
104111
### Indexing
105112
@@ -166,12 +173,12 @@ AxisArray(A::AbstractArray, axs::Axis...) = AxisArray(A, axs)
166173
push!(ax.args, :(axs[$i]))
167174
end
168175
for i=L+1:N
169-
push!(ax.args, :(Axis{_defaultdimname($i)}(1:size(A, $i))))
176+
push!(ax.args, :(Axis{_defaultdimname($i)}(indices(A, $i))))
170177
end
171178
quote
172179
for i = 1:length(axs)
173180
checkaxis(axs[i].val)
174-
if length(axs[i].val) != size(A, i)
181+
if _length(axs[i].val) != _size(A, i)
175182
throw(ArgumentError("the length of each axis must match the corresponding size of data"))
176183
end
177184
end
@@ -183,7 +190,7 @@ AxisArray(A::AbstractArray, axs::Axis...) = AxisArray(A, axs)
183190
end
184191
# Simple non-type-stable constructors to specify just the name or axis values
185192
AxisArray(A::AbstractArray) = AxisArray(A, ()) # Disambiguation
186-
AxisArray(A::AbstractArray, names::Symbol...) = AxisArray(A, ntuple(i->Axis{names[i]}(1:size(A, i)), length(names)))
193+
AxisArray(A::AbstractArray, names::Symbol...) = AxisArray(A, map((name,ind)->Axis{name}(ind), names, indices(A)))
187194
AxisArray(A::AbstractArray, vects::AbstractVector...) = AxisArray(A, ntuple(i->Axis{_defaultdimname(i)}(vects[i]), length(vects)))
188195

189196
# Axis definitions
@@ -214,47 +221,103 @@ end
214221
Base.size(A::AxisArray) = size(A.data)
215222
Base.size(A::AxisArray, Ax::Axis) = size(A.data, axisdim(A, Ax))
216223
Base.size{Ax<:Axis}(A::AxisArray, ::Type{Ax}) = size(A.data, axisdim(A, Ax))
224+
Base.indices(A::AxisArray) = indices(A.data)
225+
Base.indices(A::AxisArray, Ax::Axis) = indices(A.data, axisdim(A, Ax))
226+
Base.indices{Ax<:Axis}(A::AxisArray, ::Type{Ax}) = indices(A.data, axisdim(A, Ax))
217227
Base.linearindexing(A::AxisArray) = Base.linearindexing(A.data)
218228
Base.convert{T,N}(::Type{Array{T,N}}, A::AxisArray{T,N}) = convert(Array{T,N}, A.data)
219229
# Similar is tricky. If we're just changing the element type, it can stay as an
220230
# AxisArray. But if we're changing dimensions, there's no way it can know how
221231
# to keep track of the axes, so just punt and return a regular old Array.
222232
# TODO: would it feel more consistent to return an AxisArray without any axes?
223-
Base.similar{T}(A::AxisArray{T}) = (d = similar(A.data, T); AxisArray(d, A.axes))
224-
Base.similar{T}(A::AxisArray{T}, S::Type) = (d = similar(A.data, S); AxisArray(d, A.axes))
225-
Base.similar{T}(A::AxisArray{T}, S::Type, ::Tuple{}) = (d = similar(A.data, S); AxisArray(d, A.axes))
226-
Base.similar{T}(A::AxisArray{T}, dims::Int) = similar(A, T, (dims,))
227-
Base.similar{T}(A::AxisArray{T}, dims::Int...) = similar(A, T, dims)
228-
Base.similar{T}(A::AxisArray{T}, dims::Tuple{Vararg{Int}}) = similar(A, T, dims)
229-
Base.similar{T}(A::AxisArray{T}, S::Type, dims::Int...) = similar(A.data, S, dims)
230-
Base.similar{T}(A::AxisArray{T}, S::Type, dims::Tuple{Vararg{Int}}) = similar(A.data, S, dims)
233+
Base.similar{S}(A::AxisArray, ::Type{S}) = (d = similar(A.data, S); AxisArray(d, A.axes))
234+
Base.similar{S,N}(A::AxisArray, ::Type{S}, dims::Dims{N}) = similar(A.data, S, dims)
231235
# If, however, we pass Axis objects containing the new axis for that dimension,
232236
# we can return a similar AxisArray with an appropriately modified size
233-
Base.similar{T}(A::AxisArray{T}, axs::Axis...) = similar(A, T, axs)
234-
Base.similar{T}(A::AxisArray{T}, S::Type, axs::Axis...) = similar(A, S, axs)
235-
@generated function Base.similar{T,N}(A::AxisArray{T,N}, S::Type, axs::Tuple{Vararg{Axis}})
236-
sz = Expr(:tuple)
237+
Base.similar{T}(A::AxisArray{T}, ax1::Axis, axs::Axis...) = similar(A, T, (ax1, axs...))
238+
Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S, (ax1, axs...))
239+
@generated function Base.similar{T,S,N}(A::AxisArray{T,N}, ::Type{S}, axs::Tuple{Axis,Vararg{Axis}})
240+
inds = Expr(:tuple)
237241
ax = Expr(:tuple)
238242
for d=1:N
239-
push!(sz.args, :(size(A, Axis{$d})))
243+
push!(inds.args, :(indices(A, Axis{$d})))
240244
push!(ax.args, :(axes(A, Axis{$d})))
241245
end
242246
to_delete = Int[]
243247
for i=1:length(axs.parameters)
244248
a = axs.parameters[i]
245249
d = axisdim(A, a)
246250
axistype(a) <: Tuple{} && push!(to_delete, d)
247-
sz.args[d] = :(length(axs[$i].val))
251+
inds.args[d] = :(indices(axs[$i].val, 1))
248252
ax.args[d] = :(axs[$i])
249253
end
250254
sort!(to_delete)
251-
deleteat!(sz.args, to_delete)
255+
deleteat!(inds.args, to_delete)
252256
deleteat!(ax.args, to_delete)
253257
quote
254-
d = similar(A.data, S, $sz)
258+
d = similar(A.data, S, $inds)
255259
AxisArray(d, $ax)
256260
end
257261
end
262+
263+
function Base.permutedims(A::AxisArray, perm)
264+
p = permutation(perm, axisnames(A))
265+
AxisArray(permutedims(A.data, p), axes(A)[[p...]])
266+
end
267+
permutation(to::Union{AbstractVector{Int},Tuple{Int,Vararg{Int}}}, from::Symbols) = to
268+
269+
"""
270+
permutation(to, from) -> p
271+
272+
Calculate the permutation of labels in `from` to produce the order in
273+
`to`. Any entries in `to` that are missing in `from` will receive an
274+
index of 0. Any entries in `from` that are missing in `to` will have
275+
their indices appended to the end of the permutation. Consequently,
276+
the length of `p` is equal to the longer of `to` and `from`.
277+
"""
278+
function permutation(to::Symbols, from::Symbols)
279+
n = length(to)
280+
nf = length(from)
281+
li = linearindices(from)
282+
d = Dict(from[i]=>i for i in li)
283+
covered = similar(dims->falses(length(li)), li)
284+
ind = Array(Int, max(n, nf))
285+
for (i,toi) in enumerate(to)
286+
j = get(d, toi, 0)
287+
ind[i] = j
288+
if j != 0
289+
covered[j] = true
290+
end
291+
end
292+
k = n
293+
for i in li
294+
if !covered[i]
295+
d[from[i]] != i && throw(ArgumentError("$(from[i]) is a duplicated argument"))
296+
k += 1
297+
k > nf && throw(ArgumentError("no incomplete containment allowed in $to and $from"))
298+
ind[k] = i
299+
end
300+
end
301+
ind
302+
end
303+
304+
function Base.squeeze(A::AxisArray, dims::Dims)
305+
keepdims = setdiff(1:ndims(A), dims)
306+
AxisArray(squeeze(A.data, dims), axes(A)[keepdims])
307+
end
308+
# This version is type-stable
309+
function Base.squeeze{Ax<:Axis}(A::AxisArray, ::Type{Ax})
310+
dim = axisdim(A, Ax)
311+
AxisArray(squeeze(A.data, dim), dropax(Ax, axes(A)...))
312+
end
313+
314+
@inline dropax(ax, ax1, axs...) = (ax1, dropax(ax, axs...)...)
315+
@inline dropax{name}(ax::Axis{name}, ax1::Axis{name}, axs...) = dropax(ax, axs...)
316+
@inline dropax{name}(ax::Type{Axis{name}}, ax1::Axis{name}, axs...) = dropax(ax, axs...)
317+
@inline dropax{name,T}(ax::Type{Axis{name,T}}, ax1::Axis{name}, axs...) = dropax(ax, axs...)
318+
dropax(ax) = ()
319+
320+
258321
# A simple display method to include axis information. It might be nice to
259322
# eventually display the axis labels alongside the data array, but that is
260323
# much more difficult.
@@ -356,3 +419,10 @@ function checkaxis(::Type{Categorical}, ax)
356419
push!(seen, elt)
357420
end
358421
end
422+
423+
_length(A::AbstractArray) = length(linearindices(A))
424+
_length(A) = length(A)
425+
_size(A::AbstractArray) = map(length, indices(A))
426+
_size(A) = size(A)
427+
_size(A::AbstractArray, d) = length(indices(A, d))
428+
_size(A, d) = size(A, d)

src/indexing.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878
push!(newaxes, :($(Axis{names[d]})(A.axes[$d].val[J[$d]])))
7979
elseif I[d] <: AbstractArray
8080
for i=1:ndims(I[d])
81-
push!(newaxes, :($(Axis{Symbol(names[d], "_", i)})(1:size(I[$d], $i))))
81+
push!(newaxes, :($(Axis{Symbol(names[d], "_", i)})(indices(I[$d], $i))))
8282
end
8383
end
8484
end
@@ -198,6 +198,10 @@ end
198198
push!(ex.args, :(I[$i]))
199199
elseif I[i] <: AbstractArray{Bool}
200200
push!(ex.args, :(find(I[$i])))
201+
elseif I[i] <: CartesianIndex
202+
for j = 1:length(I[i])
203+
push!(ex.args, :(I[$i][$j]))
204+
end
201205
elseif i <= length(Ax.parameters)
202206
push!(ex.args, :(axisindexes(A.axes[$i], I[$i])))
203207
else

src/sortedvector.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ Base.getindex(v::SortedVector, idx::AbstractVector) =
6363
Base.length(v::SortedVector) = length(v.data)
6464
Base.size(v::SortedVector) = size(v.data)
6565
Base.size(v::SortedVector, i) = size(v.data, i)
66+
Base.indices(v::SortedVector) = indices(v.data)
6667

6768
axistrait(::SortedVector) = Dimensional
6869
checkaxis(::SortedVector) = nothing

test/REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
OffsetArrays

test/core.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ C = similar(A, 0)
2626
D = similar(A)
2727
@test size(A) == size(D)
2828
@test eltype(A) == eltype(D)
29+
@test axisnames(permutedims(A, (2,1,3))) == (:col, :row, :page)
30+
@test axisnames(permutedims(A, (2,3,1))) == (:col, :page, :row)
31+
@test axisnames(permutedims(A, (3,2,1))) == (:page, :col, :row)
32+
@test axisnames(permutedims(A, (3,1,2))) == (:page, :row, :col)
33+
for perm in ((:col, :row, :page), (:col, :page, :row),
34+
(:page, :col, :row), (:page, :row, :col),
35+
(:row, :page, :col), (:row, :col, :page))
36+
@test axisnames(permutedims(A, perm)) == perm
37+
end
2938
# Test modifying a particular axis
3039
E = similar(A, Float64, Axis{:col}(1:2))
3140
@test size(E) == (2,2,4)
@@ -87,6 +96,14 @@ A = AxisArray(reshape(1:16, 2,2,2,2), .5:.5:1)
8796
@test axisnames(A) == (:row,:col,:page,:dim_4)
8897
VERSION >= v"0.5.0-dev" && @inferred(axisnames(A))
8998
@test axisvalues(A) == (.5:.5:1, 1:2, 1:2, 1:2)
99+
A = AxisArray([0]', :x, :y)
100+
@test axisnames(squeeze(A, 1)) == (:y,)
101+
@test axisnames(squeeze(A, 2)) == (:x,)
102+
@test axisnames(squeeze(A, (1,2))) == axisnames(squeeze(A, (2,1))) == ()
103+
@test axisnames(@inferred(squeeze(A, Axis{:x}))) == (:y,)
104+
@test axisnames(@inferred(squeeze(A, Axis{:x,UnitRange{Int}}))) == (:y,)
105+
@test axisnames(@inferred(squeeze(A, Axis{:y}))) == (:x,)
106+
@test axisnames(@inferred(squeeze(squeeze(A, Axis{:x}), Axis{:y}))) == ()
90107

91108
# Test axisdim
92109
@test_throws ArgumentError AxisArray(reshape(1:24, 2,3,4),
@@ -107,13 +124,25 @@ A = AxisArray(reshape(1:24, 2,3,4),
107124
@test @inferred(axes(A, Axis{:x})) == @inferred(axes(A, Axis{:x}())) == Axis{:x}(.1:.1:.2)
108125
@test @inferred(axes(A, Axis{:y})) == @inferred(axes(A, Axis{:y}())) == Axis{:y}(1//10:1//10:3//10)
109126
@test @inferred(axes(A, Axis{:z})) == @inferred(axes(A, Axis{:z}())) == Axis{:z}(["a", "b", "c", "d"])
127+
@test axes(A, 2) == Axis{:y}(1//10:1//10:3//10)
110128

111129
@test Axis{:col}(1) == Axis{:col}(1)
112130
@test Axis{:col}(1) != Axis{:com}(1)
131+
@test Axis{:x}(1:3) == Axis{:x}(Base.OneTo(3))
113132
@test hash(Axis{:col}(1)) == hash(Axis{:col}(1.0))
114133
@test hash(Axis{:row}()) != hash(Axis{:col}())
134+
@test hash(Axis{:x}(1:3)) == hash(Axis{:x}(Base.OneTo(3)))
115135
@test AxisArrays.axistype(Axis{1}(1:2)) == typeof(1:2)
136+
@test AxisArrays.axistype(Axis{1,UInt32}) == UInt32
116137
@test axisnames(Axis{1}, Axis{2}, Axis{3}) == (1,2,3)
138+
@test Axis{:row}(2:7)[4] == 5
139+
@test eltype(Axis{:row}(1.0:1.0:3.0)) == Float64
140+
@test size(Axis{:row}(2:7)) === (6,)
141+
@test indices(Axis{:row}(2:7)) === (Base.OneTo(6),)
142+
@test indices(Axis{:row}(-1:1), 1) === Base.OneTo(3)
143+
@test length(Axis{:col}(-1:2)) === 4
144+
@test AxisArrays.axisname(Axis{:foo}(1:2)) == :foo
145+
@test AxisArrays.axisname(Axis{:foo}) == :foo
117146

118147
# Test Timetype axis construction
119148
dt, vals = DateTime(2010, 1, 2, 3, 40), randn(5,2)
@@ -123,3 +152,16 @@ A = AxisArray(vals, Axis{:Timestamp}(dt-Dates.Hour(2):Dates.Hour(1):dt+Dates.Hou
123152

124153
# Simply run the display method to ensure no stupid errors
125154
@compat show(IOBuffer(),MIME("text/plain"),A)
155+
156+
# With unconventional indices
157+
import OffsetArrays # import rather than using because OffsetArrays has a deprecation for ..
158+
A = AxisArray(OffsetArrays.OffsetArray([5,3,4], -1:1), :x)
159+
@test axes(A) == (Axis{:x}(-1:1),)
160+
@test A[-1] == 5
161+
A[0] = 12
162+
@test A.data[0] == 12
163+
@test indices(A) == (-1:1,)
164+
@test linearindices(A) == -1:1
165+
A = AxisArray(OffsetArrays.OffsetArray(rand(4,5), -1:2, 5:9), :x, :y)
166+
@test indices(A) == (-1:2, 5:9)
167+
@test linearindices(A) == 1:20

test/indexing.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,9 @@ A = AxisArray([1:100 -1:-1:-100], .1:.1:10.0, [:c1, :c2])
8181
@test A[atindex(-0.5..0.5, [25, 35]), :c1] == [20:30 30:40]
8282
@test_throws BoundsError A[atindex(-0.5..0.5, 5), :c1]
8383
@test_throws BoundsError A[atindex(-0.5..0.5, [5, 15, 25]), :]
84+
85+
# Indexing with CartesianIndex{0}
86+
A = AxisArray(reshape(1:15, 3, 5), :x, :y)
87+
@test A[2,2,CartesianIndex(())] == 5
88+
@test A[2,CartesianIndex(()),2] == 5
89+
@test A[CartesianIndex(()),2,2] == 5

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using AxisArrays
22
using Base.Test, Compat
33

4+
@test isempty(detect_ambiguities(AxisArrays, Base, Core))
5+
46
include("core.jl")
57
include("intervals.jl")
68
include("indexing.jl")

0 commit comments

Comments
 (0)