Skip to content

Commit 56872b7

Browse files
authored
Merge pull request #80 from JuliaArrays/mb/constructor2
AxisArray constructor improvements
2 parents 4e30a82 + 6f87f70 commit 56872b7

File tree

2 files changed

+48
-28
lines changed

2 files changed

+48
-28
lines changed

src/core.jl

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -160,34 +160,40 @@ immutable AxisArray{T,N,D,Ax} <: AbstractArray{T,N}
160160
axes::Ax # Ax<:NTuple{N, Axis}, but with specialized Axis{...} types
161161
(::Type{AxisArray{T,N,D,Ax}}){T,N,D,Ax}(data::AbstractArray{T,N}, axs::Tuple{Vararg{Axis,N}}) = new{T,N,D,Ax}(data, axs)
162162
end
163-
#
164-
_defaultdimname(i) = i == 1 ? (:row) : i == 2 ? (:col) : i == 3 ? (:page) : Symbol(:dim_, i)
165163

166-
default_axes(A::AbstractArray) = _default_axes(A, indices(A), ())
167-
_default_axes{T,N}(A::AbstractArray{T,N}, inds, axs::NTuple{N,Axis}) = axs
168-
@inline _default_axes{T,N,M}(A::AbstractArray{T,N}, inds, axs::NTuple{M,Axis}) =
169-
_default_axes(A, inds, (axs..., _nextaxistype(axs)(inds[M+1])))
164+
# Helper functions: Default axis names (if not provided)
165+
_defaultdimname(i) = i == 1 ? (:row) : i == 2 ? (:col) : i == 3 ? (:page) : Symbol(:dim_, i)
170166
# Why doesn't @pure work here?
171167
@generated function _nextaxistype{M}(axs::NTuple{M,Axis})
172168
name = _defaultdimname(M+1)
173169
:(Axis{$(Expr(:quote, name))})
174170
end
175171

176-
AxisArray(A::AbstractArray, axs::Axis...) = AxisArray(A, axs)
177-
function AxisArray{T,N}(A::AbstractArray{T,N}, axs::NTuple{N,Axis})
178-
checksizes(axs, _size(A)) || throw(ArgumentError("the length of each axis must match the corresponding size of data"))
179-
checknames(axisnames(axs...)...)
180-
AxisArray{T,N,typeof(A),typeof(axs)}(A, axs)
181-
end
182-
function AxisArray{L}(A::AbstractArray, axs::NTuple{L,Axis})
183-
newaxs = _default_axes(A, indices(A), axs)
184-
AxisArray(A, newaxs)
185-
end
172+
"""
173+
default_axes(A::AbstractArray)
174+
default_axes(A::AbstractArray, axs)
186175
176+
Return a tuple of Axis objects that appropriately index into the array A.
177+
178+
The optional second argument can take a tuple of vectors or axes, which will be
179+
wrapped with the appropriate axis name, and it will ensure no axis goes beyond
180+
the dimensionality of the array A.
181+
"""
182+
@inline default_axes(A::AbstractArray, args=indices(A)) = _default_axes(A, args, ())
183+
_default_axes{T,N}(A::AbstractArray{T,N}, args::Tuple{}, axs::NTuple{N,Axis}) = axs
184+
_default_axes{T,N}(A::AbstractArray{T,N}, args::Tuple{Any, Vararg{Any}}, axs::NTuple{N,Axis}) = throw(ArgumentError("too many axes provided"))
185+
_default_axes{T,N}(A::AbstractArray{T,N}, args::Tuple{Axis, Vararg{Any}}, axs::NTuple{N,Axis}) = throw(ArgumentError("too many axes provided"))
186+
@inline _default_axes{T,N}(A::AbstractArray{T,N}, args::Tuple{}, axs::Tuple) =
187+
_default_axes(A, args, (axs..., _nextaxistype(axs)(indices(A, length(axs)+1))))
188+
@inline _default_axes{T,N}(A::AbstractArray{T,N}, args::Tuple{Any, Vararg{Any}}, axs::Tuple) =
189+
_default_axes(A, Base.tail(args), (axs..., _nextaxistype(axs)(args[1])))
190+
@inline _default_axes{T,N}(A::AbstractArray{T,N}, args::Tuple{Axis, Vararg{Any}}, axs::Tuple) =
191+
_default_axes(A, Base.tail(args), (axs..., args[1]))
192+
193+
# Axis consistency checks — ensure sizes match and the names are unique
187194
@inline checksizes(axs, sz) =
188195
(length(axs[1]) == sz[1]) & checksizes(tail(axs), tail(sz))
189196
checksizes(::Tuple{}, sz) = true
190-
191197
@inline function checknames(name::Symbol, names...)
192198
matches = false
193199
for n in names
@@ -199,10 +205,18 @@ end
199205
checknames(name, names...) = throw(ArgumentError("the Axis names must be Symbols"))
200206
checknames() = ()
201207

202-
# Simple non-type-stable constructors to specify just the name or axis values
208+
# The primary AxisArray constructors — specify an array to wrap and the axes
209+
AxisArray(A::AbstractArray, vects::Union{AbstractVector, Axis}...) = AxisArray(A, vects)
210+
AxisArray(A::AbstractArray, vects::Tuple{Vararg{Union{AbstractVector, Axis}}}) = AxisArray(A, default_axes(A, vects))
211+
function AxisArray{T,N}(A::AbstractArray{T,N}, axs::NTuple{N,Axis})
212+
checksizes(axs, _size(A)) || throw(ArgumentError("the length of each axis must match the corresponding size of data"))
213+
checknames(axisnames(axs...)...)
214+
AxisArray{T,N,typeof(A),typeof(axs)}(A, axs)
215+
end
216+
217+
# Simple non-type-stable constructors to specify names as symbols
203218
AxisArray(A::AbstractArray) = AxisArray(A, ()) # Disambiguation
204219
AxisArray(A::AbstractArray, names::Symbol...) = (inds = indices(A); AxisArray(A, ntuple(i->Axis{names[i]}(inds[i]), length(names))))
205-
AxisArray(A::AbstractArray, vects::AbstractVector...) = AxisArray(A, ntuple(i->Axis{_defaultdimname(i)}(vects[i]), length(vects)))
206220
function AxisArray{T,N}(A::AbstractArray{T,N}, names::NTuple{N,Symbol}, steps::NTuple{N,Number}, offsets::NTuple{N,Number}=map(zero, steps))
207221
axs = ntuple(i->Axis{names[i]}(range(offsets[i], steps[i], size(A,i))), N)
208222
AxisArray(A, axs...)

test/core.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
A = AxisArray(reshape(1:24, 2,3,4), .1:.1:.2, .1:.1:.3, .1:.1:.4)
1+
A = @inferred(AxisArray(reshape(1:24, 2,3,4), .1:.1:.2, .1:.1:.3, .1:.1:.4))
22
@test_throws ArgumentError AxisArray(reshape(1:24, 2,3,4), .1:.1:.1, .1:.1:.3, .1:.1:.4)
33
@test_throws ArgumentError AxisArray(reshape(1:24, 2,3,4), .1:.1:.1, .1:.1:.3)
4+
@test_throws ArgumentError AxisArray(reshape(1:24, 2,3,4), .1:.1:.2, .1:.1:.3, .1:.1:.4, 1:1)
45
@test parent(A) === reshape(1:24, 2,3,4)
56
# Test iteration
67
for (a,b) in zip(A, A.data)
@@ -40,7 +41,7 @@ for perm in ((:col, :row, :page), (:col, :page, :row),
4041
end
4142
@test axisnames(permutedims(A, (:col,))) == (:col, :row, :page)
4243
@test axisnames(permutedims(A, (:page,))) == (:page, :row, :col)
43-
A2 = AxisArray(reshape(1:15, 3, 5))
44+
A2 = @inferred(AxisArray(reshape(1:15, 3, 5)))
4445
A1 = AxisArray(1:5, :t)
4546
for f in (transpose, ctranspose)
4647
@test f(A2).data == f(A2.data)
@@ -100,12 +101,12 @@ A = AxisArray([1 3; 2 4], :a)
100101
VERSION >= v"0.5.0-dev" && @inferred(axisnames(A))
101102
@test axisvalues(A) == (1:2, 1:2)
102103
# Just axis values
103-
A = AxisArray(1:3, .1:.1:.3)
104+
A = @inferred(AxisArray(1:3, .1:.1:.3))
104105
@test A.data == 1:3
105106
@test axisnames(A) == (:row,)
106107
VERSION >= v"0.5.0-dev" && @inferred(axisnames(A))
107108
@test axisvalues(A) == (.1:.1:.3,)
108-
A = AxisArray(reshape(1:16, 2,2,2,2), .5:.5:1)
109+
A = @inferred(AxisArray(reshape(1:16, 2,2,2,2), .5:.5:1))
109110
@test A.data == reshape(1:16, 2,2,2,2)
110111
@test axisnames(A) == (:row,:col,:page,:dim_4)
111112
VERSION >= v"0.5.0-dev" && @inferred(axisnames(A))
@@ -129,17 +130,22 @@ B = AxisArray([1 4; 2 5; 3 6], (:x, :y), (0.2, 100), (-3,14))
129130
@test AxisArrays.HasAxes(A) == AxisArrays.HasAxes{true}()
130131
@test AxisArrays.HasAxes([1]) == AxisArrays.HasAxes{false}()
131132

132-
# Test axisdim
133133
@test_throws ArgumentError AxisArray(reshape(1:24, 2,3,4),
134134
Axis{1}(.1:.1:.2),
135135
Axis{2}(1//10:1//10:3//10),
136136
Axis{3}(["a", "b", "c", "d"])) # Axis need to be symbols
137+
@test_throws ArgumentError AxisArray(reshape(1:24, 2,3,4),
138+
Axis{:x}(.1:.1:.2),
139+
Axis{:y}(1//10:1//10:3//10),
140+
Axis{:z}(["a", "b", "c", "d"]),
141+
Axis{:_}(1:1)) # Too many Axes
137142

138-
A = AxisArray(reshape(1:24, 2,3,4),
143+
A = @inferred(AxisArray(reshape(1:24, 2,3,4),
139144
Axis{:x}(.1:.1:.2),
140145
Axis{:y}(1//10:1//10:3//10),
141-
Axis{:z}(["a", "b", "c", "d"]))
146+
Axis{:z}(["a", "b", "c", "d"])))
142147

148+
# Test axisdim
143149
@test axisdim(A, Axis{:x}) == axisdim(A, Axis{:x}()) == 1
144150
@test axisdim(A, Axis{:y}) == axisdim(A, Axis{:y}()) == 2
145151
@test axisdim(A, Axis{:z}) == axisdim(A, Axis{:z}()) == 3
@@ -176,7 +182,7 @@ T = A[AxisArrays.Axis{:x}]
176182

177183
# Test Timetype axis construction
178184
dt, vals = DateTime(2010, 1, 2, 3, 40), randn(5,2)
179-
A = AxisArray(vals, Axis{:Timestamp}(dt-Dates.Hour(2):Dates.Hour(1):dt+Dates.Hour(2)), Axis{:Cols}([:A, :B]))
185+
A = @inferred(AxisArray(vals, Axis{:Timestamp}(dt-Dates.Hour(2):Dates.Hour(1):dt+Dates.Hour(2)), Axis{:Cols}([:A, :B])))
180186
@test A[:, :A].data == vals[:, 1]
181187
@test A[dt, :].data == vals[3, :]
182188

@@ -229,7 +235,7 @@ map!(*, A2, A, A)
229235

230236
# Reductions (issue #55)
231237
A = AxisArray(collect(reshape(1:15,3,5)), :y, :x)
232-
B = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50))
238+
B = @inferred(AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50)))
233239
for C in (A, B)
234240
for op in (sum, minimum) # together, cover both reduced_indices and reduced_indices0
235241
axv = axisvalues(C)

0 commit comments

Comments
 (0)