diff --git a/src/combine.jl b/src/combine.jl index c172c8d..6cd82a0 100644 --- a/src/combine.jl +++ b/src/combine.jl @@ -14,17 +14,22 @@ matchingdims(As::Tuple{Vararg{AxisArray}}) = all(equalvalued, sizes(As...)) matchingdimsexcept(As::Tuple{Vararg{AxisArray}}, n::Int) = all(equalvalued, sizes(As...)[[1:n-1; n+1:end]]) Base.cat(A1::AxisArray{T}, As::AxisArray{T}...; dims) where {T} = _cat(dims, A1, As...) +Base.vcat(A1::AxisArray{T}, As::AxisArray{T}...) where {T} = _cat(1, A1, As...) +Base.hcat(A1::AxisArray{T}, As::AxisArray{T}...) where {T} = _cat(2, A1, As...) + _cat(::Val{n}, As...) where {n} = _cat(n, As...) +fastcat(n::Integer, As...) = n == 1 ? vcat(As...) : n == 2 ? hcat(As...) : cat(As...; dims = n) @inline function _cat(n::Integer, As...) if n <= ndims(As[1]) matchingdimsexcept(As, n) || error("All non-concatenated axes must be identically-valued") newaxis = Axis{axisnames(As[1])[n]}(vcat(map(A -> A.axes[n].val, As)...)) checkaxis(newaxis) - return AxisArray(cat(map(A->A.data, As)..., dims=n), (As[1].axes[1:n-1]..., newaxis, As[1].axes[n+1:end]...)) + axes = ntuple(d -> d == n ? newaxis : As[1].axes[d], ndims(As[1])) + return AxisArray(fastcat(n, map(A->A.data, As)...), axes) else matchingdims(As) || error("All axes must be identically-valued") - return AxisArray(cat(map(A->A.data, As)..., dims=n), As[1].axes) + return AxisArray(fastcat(n, map(A->A.data, As)...), As[1].axes) end #if end diff --git a/test/combine.jl b/test/combine.jl index c7eb019..e4c06ab 100644 --- a/test/combine.jl +++ b/test/combine.jl @@ -1,14 +1,17 @@ # cat +using AxisArrays, Test A1data, A2data = [1 3; 2 4], [5 7; 6 8] A1 = AxisArray(A1data, Axis{:Row}([:First, :Second]), Axis{:Col}([:A, :B])) A2 = AxisArray(A2data, Axis{:Row}([:Third, :Fourth]), Axis{:Col}([:A, :B])) @test isa(cat(A1, A2, dims=1), AxisArray) +@test @inferred(vcat(A1, A2)) == cat(A1, A2, dims=1) @test cat(A1, A2, dims=1) == AxisArray(vcat(A1data, A2data), Axis{:Row}([:First, :Second, :Third, :Fourth]), Axis{:Col}([:A, :B])) A2 = AxisArray(A2data, Axis{:Row}([:First, :Second]), Axis{:Col}([:C, :D])) @test isa(cat(A1, A2, dims=2), AxisArray) +@test @inferred(hcat(A1, A2)) == cat(A1, A2, dims=2) @test cat(A1, A2, dims=2) == AxisArray(hcat(A1data, A2data), Axis{:Row}([:First, :Second]), Axis{:Col}([:A, :B, :C, :D]))