Skip to content

Commit 7a7633b

Browse files
committed
Add support for permutedims
1 parent 22eae14 commit 7a7633b

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

src/core.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ else
88
using Base: @pure
99
end
1010

11+
typealias Symbols Tuple{Symbol,Vararg{Symbol}}
12+
1113
@doc """
1214
Type-stable axis-specific indexing and identification with a
1315
parametric type.
@@ -263,6 +265,48 @@ Base.similar{T}(A::AxisArray{T}, S::Type, axs::Axis...) = similar(A, S, axs)
263265
AxisArray(d, $ax)
264266
end
265267
end
268+
269+
function Base.permutedims(A::AxisArray, perm)
270+
p = permutation(perm, axisnames(A))
271+
AxisArray(permutedims(A.data, p), axes(A)[[p...]])
272+
end
273+
permutation(to::Union{AbstractVector{Int},Tuple{Int,Vararg{Int}}}, from::Symbols) = to
274+
275+
"""
276+
permutation(to, from) -> p
277+
278+
Calculate the permutation of labels in `from` to produce the order in
279+
`to`. Any entries in `to` that are missing in `from` will receive an
280+
index of 0. Any entries in `from` that are missing in `to` will have
281+
their indices appended to the end of the permutation. Consequently,
282+
the length of `p` is equal to the longer of `to` and `from`.
283+
"""
284+
function permutation(to::Symbols, from::Symbols)
285+
n = length(to)
286+
nf = length(from)
287+
li = linearindices(from)
288+
d = Dict(from[i]=>i for i in li)
289+
covered = similar(dims->falses(length(li)), li)
290+
ind = Array(Int, max(n, nf))
291+
for (i,toi) in enumerate(to)
292+
j = get(d, toi, 0)
293+
ind[i] = j
294+
if j != 0
295+
covered[j] = true
296+
end
297+
end
298+
k = n
299+
for i in li
300+
if !covered[i]
301+
d[from[i]] != i && throw(ArgumentError("$(from[i]) is a duplicated argument"))
302+
k += 1
303+
k > nf && throw(ArgumentError("no incomplete containment allowed in $to and $from"))
304+
ind[k] = i
305+
end
306+
end
307+
ind
308+
end
309+
266310
# A simple display method to include axis information. It might be nice to
267311
# eventually display the axis labels alongside the data array, but that is
268312
# much more difficult.

test/core.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ C = similar(A, 0)
2626
D = similar(A)
2727
@test size(A) == size(D)
2828
@test eltype(A) == eltype(D)
29+
@test axisnames(permutedims(A, (2,1,3))) == (:col, :row, :page)
30+
@test axisnames(permutedims(A, (2,3,1))) == (:col, :page, :row)
31+
@test axisnames(permutedims(A, (3,2,1))) == (:page, :col, :row)
32+
@test axisnames(permutedims(A, (3,1,2))) == (:page, :row, :col)
33+
for perm in ((:col, :row, :page), (:col, :page, :row),
34+
(:page, :col, :row), (:page, :row, :col),
35+
(:row, :page, :col), (:row, :col, :page))
36+
@test axisnames(permutedims(A, perm)) == perm
37+
end
2938
# Test modifying a particular axis
3039
E = similar(A, Float64, Axis{:col}(1:2))
3140
@test size(E) == (2,2,4)

0 commit comments

Comments
 (0)