From 0d7eb3aec652eedf6a8a5d2a584c1c632a85ca89 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Wed, 10 Jun 2020 18:53:21 +0800 Subject: [PATCH 1/2] add hcat and vcat --- src/combine.jl | 8 ++++++-- test/combine.jl | 3 +++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/combine.jl b/src/combine.jl index c172c8d..01b0e7b 100644 --- a/src/combine.jl +++ b/src/combine.jl @@ -14,17 +14,21 @@ matchingdims(As::Tuple{Vararg{AxisArray}}) = all(equalvalued, sizes(As...)) matchingdimsexcept(As::Tuple{Vararg{AxisArray}}, n::Int) = all(equalvalued, sizes(As...)[[1:n-1; n+1:end]]) Base.cat(A1::AxisArray{T}, As::AxisArray{T}...; dims) where {T} = _cat(dims, A1, As...) +Base.vcat(A1::AxisArray{T}, As::AxisArray{T}...) where {T} = _cat(1, A1, As...) +Base.hcat(A1::AxisArray{T}, As::AxisArray{T}...) where {T} = _cat(2, A1, As...) + _cat(::Val{n}, As...) where {n} = _cat(n, As...) +fastcat(n::Integer, As...) = n == 1 ? vcat(As...) : n == 2 ? hcat(As...) : cat(As...; dims = n) @inline function _cat(n::Integer, As...) if n <= ndims(As[1]) matchingdimsexcept(As, n) || error("All non-concatenated axes must be identically-valued") newaxis = Axis{axisnames(As[1])[n]}(vcat(map(A -> A.axes[n].val, As)...)) checkaxis(newaxis) - return AxisArray(cat(map(A->A.data, As)..., dims=n), (As[1].axes[1:n-1]..., newaxis, As[1].axes[n+1:end]...)) + return AxisArray(fastcat(n, map(A->A.data, As)...), (As[1].axes[1:n-1]..., newaxis, As[1].axes[n+1:end]...)) else matchingdims(As) || error("All axes must be identically-valued") - return AxisArray(cat(map(A->A.data, As)..., dims=n), As[1].axes) + return AxisArray(fastcat(n, map(A->A.data, As)...), As[1].axes) end #if end diff --git a/test/combine.jl b/test/combine.jl index c7eb019..44ac40f 100644 --- a/test/combine.jl +++ b/test/combine.jl @@ -1,14 +1,17 @@ # cat +using AxisArrays, Test A1data, A2data = [1 3; 2 4], [5 7; 6 8] A1 = AxisArray(A1data, Axis{:Row}([:First, :Second]), Axis{:Col}([:A, :B])) A2 = AxisArray(A2data, Axis{:Row}([:Third, :Fourth]), Axis{:Col}([:A, :B])) @test isa(cat(A1, A2, dims=1), AxisArray) +@test vcat(A1, A2) == cat(A1, A2, dims=1) @test cat(A1, A2, dims=1) == AxisArray(vcat(A1data, A2data), Axis{:Row}([:First, :Second, :Third, :Fourth]), Axis{:Col}([:A, :B])) A2 = AxisArray(A2data, Axis{:Row}([:First, :Second]), Axis{:Col}([:C, :D])) @test isa(cat(A1, A2, dims=2), AxisArray) +@test hcat(A1, A2) == cat(A1, A2, dims=2) @test cat(A1, A2, dims=2) == AxisArray(hcat(A1data, A2data), Axis{:Row}([:First, :Second]), Axis{:Col}([:A, :B, :C, :D])) From d1e38037f611cff073011cc4cb1f6a45393bf008 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Mon, 15 Jun 2020 09:17:47 +0800 Subject: [PATCH 2/2] fix type instability of _cat --- src/combine.jl | 3 ++- test/combine.jl | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/combine.jl b/src/combine.jl index 01b0e7b..6cd82a0 100644 --- a/src/combine.jl +++ b/src/combine.jl @@ -25,7 +25,8 @@ fastcat(n::Integer, As...) = n == 1 ? vcat(As...) : n == 2 ? hcat(As...) : cat(A matchingdimsexcept(As, n) || error("All non-concatenated axes must be identically-valued") newaxis = Axis{axisnames(As[1])[n]}(vcat(map(A -> A.axes[n].val, As)...)) checkaxis(newaxis) - return AxisArray(fastcat(n, map(A->A.data, As)...), (As[1].axes[1:n-1]..., newaxis, As[1].axes[n+1:end]...)) + axes = ntuple(d -> d == n ? newaxis : As[1].axes[d], ndims(As[1])) + return AxisArray(fastcat(n, map(A->A.data, As)...), axes) else matchingdims(As) || error("All axes must be identically-valued") return AxisArray(fastcat(n, map(A->A.data, As)...), As[1].axes) diff --git a/test/combine.jl b/test/combine.jl index 44ac40f..e4c06ab 100644 --- a/test/combine.jl +++ b/test/combine.jl @@ -5,13 +5,13 @@ A1data, A2data = [1 3; 2 4], [5 7; 6 8] A1 = AxisArray(A1data, Axis{:Row}([:First, :Second]), Axis{:Col}([:A, :B])) A2 = AxisArray(A2data, Axis{:Row}([:Third, :Fourth]), Axis{:Col}([:A, :B])) @test isa(cat(A1, A2, dims=1), AxisArray) -@test vcat(A1, A2) == cat(A1, A2, dims=1) +@test @inferred(vcat(A1, A2)) == cat(A1, A2, dims=1) @test cat(A1, A2, dims=1) == AxisArray(vcat(A1data, A2data), Axis{:Row}([:First, :Second, :Third, :Fourth]), Axis{:Col}([:A, :B])) A2 = AxisArray(A2data, Axis{:Row}([:First, :Second]), Axis{:Col}([:C, :D])) @test isa(cat(A1, A2, dims=2), AxisArray) -@test hcat(A1, A2) == cat(A1, A2, dims=2) +@test @inferred(hcat(A1, A2)) == cat(A1, A2, dims=2) @test cat(A1, A2, dims=2) == AxisArray(hcat(A1data, A2data), Axis{:Row}([:First, :Second]), Axis{:Col}([:A, :B, :C, :D]))