Skip to content

Commit 6a685c2

Browse files
committed
Support squeeze
1 parent 7a7633b commit 6a685c2

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

src/AxisArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module AxisArrays
22

3+
using Base: tail
34
using RangeArrays, Iterators, IntervalSets, Compat
45
using Compat.view
56

src/core.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,23 @@ function permutation(to::Symbols, from::Symbols)
307307
ind
308308
end
309309

310+
function Base.squeeze(A::AxisArray, dims::Dims)
311+
keepdims = setdiff(1:ndims(A), dims)
312+
AxisArray(squeeze(A.data, dims), axes(A)[keepdims])
313+
end
314+
# This version is type-stable
315+
function Base.squeeze{Ax<:Axis}(A::AxisArray, ::Type{Ax})
316+
dim = axisdim(A, Ax)
317+
AxisArray(squeeze(A.data, dim), dropax(Ax, axes(A)...))
318+
end
319+
320+
@inline dropax(ax, ax1, axs...) = (ax1, dropax(ax, axs...)...)
321+
@inline dropax{name}(ax::Axis{name}, ax1::Axis{name}, axs...) = dropax(ax, axs...)
322+
@inline dropax{name}(ax::Type{Axis{name}}, ax1::Axis{name}, axs...) = dropax(ax, axs...)
323+
@inline dropax{name,T}(ax::Type{Axis{name,T}}, ax1::Axis{name}, axs...) = dropax(ax, axs...)
324+
dropax(ax) = ()
325+
326+
310327
# A simple display method to include axis information. It might be nice to
311328
# eventually display the axis labels alongside the data array, but that is
312329
# much more difficult.

test/core.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,14 @@ A = AxisArray(reshape(1:16, 2,2,2,2), .5:.5:1)
9696
@test axisnames(A) == (:row,:col,:page,:dim_4)
9797
VERSION >= v"0.5.0-dev" && @inferred(axisnames(A))
9898
@test axisvalues(A) == (.5:.5:1, 1:2, 1:2, 1:2)
99+
A = AxisArray([0]', :x, :y)
100+
@test axisnames(squeeze(A, 1)) == (:y,)
101+
@test axisnames(squeeze(A, 2)) == (:x,)
102+
@test axisnames(squeeze(A, (1,2))) == axisnames(squeeze(A, (2,1))) == ()
103+
@test axisnames(@inferred(squeeze(A, Axis{:x}))) == (:y,)
104+
@test axisnames(@inferred(squeeze(A, Axis{:x,UnitRange{Int}}))) == (:y,)
105+
@test axisnames(@inferred(squeeze(A, Axis{:y}))) == (:x,)
106+
@test axisnames(@inferred(squeeze(squeeze(A, Axis{:x}), Axis{:y}))) == ()
99107

100108
# Test axisdim
101109
@test_throws ArgumentError AxisArray(reshape(1:24, 2,3,4),

0 commit comments

Comments
 (0)