Skip to content

Commit 36d342c

Browse files
authored
Use dtype agnostic op_cat implementation, add op_cat testcases
Differential Revision: D80193957 Pull Request resolved: #13397
1 parent 377f474 commit 36d342c

File tree

3 files changed

+159
-21
lines changed

3 files changed

+159
-21
lines changed

backends/cadence/hifi/operators/op_cat.cpp

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -126,34 +126,30 @@ Tensor& cat_out(
126126
const size_t outer = getLeadingDims(out, dim);
127127
const size_t dim_stride = getTrailingDims(out, dim);
128128
const size_t ninputs = tensors.size();
129+
const size_t element_size = out.element_size();
130+
char* out_ptr = static_cast<char*>(out.mutable_data_ptr());
129131

130-
const auto out_type = out.scalar_type();
131-
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
132-
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
133-
for (size_t i = 0; i < outer; ++i) {
134-
for (size_t j = 0; j < ninputs; ++j) {
135-
const auto in_type = tensors[j].scalar_type();
136-
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
137-
if (tensors[j].numel() == 0) {
138-
return;
139-
}
140-
size_t inner = tensors[j].size(dim) * dim_stride;
141-
const CTYPE_IN* const in_ptr =
142-
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
143-
144-
for (size_t k = 0; k < inner; ++k) {
145-
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
146-
}
147-
out_ptr += inner;
148-
});
132+
for (size_t i = 0; i < outer; ++i) {
133+
for (size_t j = 0; j < ninputs; ++j) {
134+
if (tensors[j].numel() == 0) {
135+
continue;
149136
}
137+
size_t inner_elements = tensors[j].size(dim) * dim_stride;
138+
size_t contiguous_bytes = inner_elements * element_size;
139+
140+
const char* const in_ptr =
141+
static_cast<const char*>(tensors[j].const_data_ptr()) +
142+
i * contiguous_bytes;
143+
144+
std::memcpy(out_ptr, in_ptr, contiguous_bytes);
145+
out_ptr += contiguous_bytes;
150146
}
151-
});
147+
}
152148

153149
return out;
154150
}
155151

156152
} // namespace native
157153
} // namespace HiFi
158154
} // namespace impl
159-
} // namespace cadence
155+
} // namespace cadence

backends/cadence/hifi/operators/operators.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ void quantized_conv_per_tensor_out(
122122
bool channel_last,
123123
::executorch::aten::Tensor& out);
124124

125+
::executorch::aten::Tensor& cat_out(
126+
::executorch::runtime::KernelRuntimeContext& ctx,
127+
::executorch::aten::ArrayRef<::executorch::aten::Tensor> tensors,
128+
int64_t dim,
129+
::executorch::aten::Tensor& out);
130+
125131
} // namespace native
126132
} // namespace HiFi
127133
} // namespace impl
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
#include <sys/times.h>
11+
#include <xtensa/sim.h>
12+
13+
#include <executorch/kernels/test/TestUtil.h>
14+
#include <executorch/runtime/core/error.h>
15+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
16+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
17+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
18+
#include <executorch/runtime/platform/runtime.h>
19+
20+
#include <executorch/backends/cadence/hifi/operators/operators.h>
21+
22+
namespace cadence {
23+
namespace impl {
24+
namespace HiFi {
25+
namespace native {
26+
namespace {
27+
28+
using ::executorch::aten::ArrayRef;
29+
using ::executorch::aten::ScalarType;
30+
using ::executorch::aten::Tensor;
31+
using ::executorch::aten::TensorImpl;
32+
using ::executorch::runtime::Error;
33+
using ::executorch::runtime::KernelRuntimeContext;
34+
using ::executorch::runtime::runtime_init;
35+
using ::executorch::runtime::testing::TensorFactory;
36+
37+
class HiFiCatTest : public OperatorTest {
38+
public:
39+
protected:
40+
Tensor& cat_out(ArrayRef<Tensor> tensors, int64_t dim, Tensor& out) {
41+
return ::cadence::impl::HiFi::native::cat_out(context_, tensors, dim, out);
42+
}
43+
};
44+
45+
TEST_F(HiFiCatTest, FloatCatDim0Test) {
46+
TensorFactory<ScalarType::Float> tf;
47+
Tensor a = tf.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
48+
Tensor b = tf.make({1, 3}, {7.0, 8.0, 9.0});
49+
Tensor c = tf.make({2, 3}, {10.0, 11.0, 12.0, 13.0, 14.0, 15.0});
50+
51+
Tensor expected = tf.make(
52+
{5, 3},
53+
{1.0,
54+
2.0,
55+
3.0,
56+
4.0,
57+
5.0,
58+
6.0,
59+
7.0,
60+
8.0,
61+
9.0,
62+
10.0,
63+
11.0,
64+
12.0,
65+
13.0,
66+
14.0,
67+
15.0});
68+
69+
Tensor out = tf.zeros({5, 3});
70+
std::vector<Tensor> tensors = {a, b, c};
71+
72+
cat_out(ArrayRef<Tensor>(tensors.data(), tensors.size()), 0, out);
73+
EXPECT_TENSOR_EQ(out, expected);
74+
}
75+
76+
TEST_F(HiFiCatTest, FloatCatDim1Test) {
77+
TensorFactory<ScalarType::Float> tf;
78+
Tensor a = tf.make({2, 2}, {1.0, 2.0, 3.0, 4.0});
79+
Tensor b = tf.make({2, 1}, {5.0, 6.0});
80+
Tensor c = tf.make({2, 3}, {7.0, 8.0, 9.0, 10.0, 11.0, 12.0});
81+
82+
Tensor expected = tf.make(
83+
{2, 6}, {1.0, 2.0, 5.0, 7.0, 8.0, 9.0, 3.0, 4.0, 6.0, 10.0, 11.0, 12.0});
84+
85+
Tensor out = tf.zeros({2, 6});
86+
std::vector<Tensor> tensors = {a, b, c};
87+
88+
cat_out(ArrayRef<Tensor>(tensors.data(), tensors.size()), 1, out);
89+
EXPECT_TENSOR_EQ(out, expected);
90+
}
91+
92+
TEST_F(HiFiCatTest, IntCatDim0Test) {
93+
TensorFactory<ScalarType::Int> tf;
94+
Tensor a = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
95+
Tensor b = tf.make({1, 3}, {7, 8, 9});
96+
97+
Tensor expected = tf.make({3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
98+
99+
Tensor out = tf.zeros({3, 3});
100+
std::vector<Tensor> tensors = {a, b};
101+
cat_out(ArrayRef<Tensor>(tensors.data(), tensors.size()), 0, out);
102+
EXPECT_TENSOR_EQ(out, expected);
103+
}
104+
105+
TEST_F(HiFiCatTest, SingleTensorTest) {
106+
TensorFactory<ScalarType::Float> tf;
107+
Tensor a = tf.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
108+
Tensor expected = tf.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
109+
110+
Tensor out = tf.zeros({2, 3});
111+
std::vector<Tensor> tensors = {a};
112+
cat_out(ArrayRef<Tensor>(tensors.data(), tensors.size()), 0, out);
113+
EXPECT_TENSOR_EQ(out, expected);
114+
}
115+
116+
TEST_F(HiFiCatTest, ThreeDimensionalCatTest) {
117+
TensorFactory<ScalarType::Float> tf;
118+
Tensor a = tf.make({2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0});
119+
Tensor b = tf.make({2, 2, 1}, {9.0, 10.0, 11.0, 12.0});
120+
121+
Tensor expected = tf.make(
122+
{2, 2, 3},
123+
{1.0, 2.0, 9.0, 3.0, 4.0, 10.0, 5.0, 6.0, 11.0, 7.0, 8.0, 12.0});
124+
125+
Tensor out = tf.zeros({2, 2, 3});
126+
std::vector<Tensor> tensors = {a, b};
127+
128+
cat_out(ArrayRef<Tensor>(tensors.data(), tensors.size()), 2, out);
129+
EXPECT_TENSOR_EQ(out, expected);
130+
}
131+
132+
} // namespace
133+
} // namespace native
134+
} // namespace HiFi
135+
} // namespace impl
136+
} // namespace cadence

0 commit comments

Comments
 (0)