Skip to content

Commit 97699f1

Browse files
authored
Extend cat op for complex dtype
Differential Revision: D78934592 Pull Request resolved: #12894
1 parent ec4228c commit 97699f1

File tree

2 files changed

+105
-15
lines changed

2 files changed

+105
-15
lines changed

kernels/portable/cpu/op_cat.cpp

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,58 @@ Tensor& cat_out(
5656
const size_t ninputs = tensors.size();
5757

5858
const auto out_type = out.scalar_type();
59-
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
60-
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
61-
for (size_t i = 0; i < outer; ++i) {
62-
for (size_t j = 0; j < ninputs; ++j) {
63-
const auto in_type = tensors[j].scalar_type();
64-
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
59+
const bool out_is_complex =
60+
executorch::runtime::isComplexType(out.scalar_type());
61+
62+
if (out_is_complex) {
63+
// TODO: The current support for complex dtype enforces that input and
64+
// output tensors have the same dtype. Support mixed dtypes in the future.
65+
for (size_t i = 0; i < ninputs; ++i) {
66+
const auto in_type = tensors[i].scalar_type();
67+
ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
68+
}
69+
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
70+
CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
71+
for (size_t i = 0; i < outer; ++i) {
72+
for (size_t j = 0; j < ninputs; ++j) {
6573
if (tensors[j].numel() == 0) {
6674
return;
6775
}
6876
size_t inner = tensors[j].size(dim) * dim_stride;
69-
const CTYPE_IN* const in_ptr =
70-
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
71-
72-
for (size_t k = 0; k < inner; ++k) {
73-
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
74-
}
77+
const CTYPE* const in_ptr =
78+
tensors[j].const_data_ptr<CTYPE>() + i * inner;
79+
memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
7580
out_ptr += inner;
76-
});
81+
}
7782
}
78-
}
79-
});
83+
});
84+
} else {
85+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
86+
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
87+
for (size_t i = 0; i < outer; ++i) {
88+
for (size_t j = 0; j < ninputs; ++j) {
89+
const auto in_type = tensors[j].scalar_type();
90+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
91+
if (tensors[j].numel() == 0) {
92+
return;
93+
}
94+
size_t inner = tensors[j].size(dim) * dim_stride;
95+
const CTYPE_IN* const in_ptr =
96+
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
97+
98+
if (sizeof(CTYPE_IN) == sizeof(CTYPE_OUT)) {
99+
memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE_IN));
100+
} else {
101+
for (size_t k = 0; k < inner; ++k) {
102+
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
103+
}
104+
}
105+
out_ptr += inner;
106+
});
107+
}
108+
}
109+
});
110+
}
80111

81112
return out;
82113
}

kernels/test/op_cat_test.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,58 @@ class OpCatOutTest : public OperatorTest {
7373
tf.make({2, 4}, {1.5, -2.0, 3.25, 10.0, 4.0, -5.5, 6.5, 20.0});
7474
EXPECT_TENSOR_EQ(out, expected);
7575
}
76+
77+
template <typename CTYPE, ScalarType DTYPE>
78+
void test_complex_dtype() {
79+
TensorFactory<DTYPE> tf;
80+
Tensor x = tf.make(
81+
{2, 2},
82+
{CTYPE(0.01, 2.03),
83+
CTYPE(4.05, 6.07),
84+
CTYPE(0.11, 2.13),
85+
CTYPE(4.15, 6.17)});
86+
Tensor y = tf.make(
87+
{2, 2},
88+
{CTYPE(0.21, 2.23),
89+
CTYPE(4.25, 6.27),
90+
CTYPE(0.31, 2.33),
91+
CTYPE(4.35, 6.37)});
92+
93+
std::vector<Tensor> inputs = {x, y};
94+
95+
// Concatenate along dim[0].
96+
Tensor out_0 = tf.full({4, 2}, CTYPE{0, 0});
97+
Tensor ret_0 = op_cat_out(
98+
ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/0, out_0);
99+
Tensor expected_0 = tf.make(
100+
{4, 2},
101+
{CTYPE(0.01, 2.03),
102+
CTYPE(4.05, 6.07),
103+
CTYPE(0.11, 2.13),
104+
CTYPE(4.15, 6.17),
105+
CTYPE(0.21, 2.23),
106+
CTYPE(4.25, 6.27),
107+
CTYPE(0.31, 2.33),
108+
CTYPE(4.35, 6.37)});
109+
110+
EXPECT_TENSOR_EQ(out_0, expected_0);
111+
112+
// Concatenate along dim[1].
113+
Tensor out_1 = tf.full({2, 4}, CTYPE{0, 0});
114+
Tensor ret_1 = op_cat_out(
115+
ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/1, out_1);
116+
Tensor expected_1 = tf.make(
117+
{2, 4},
118+
{CTYPE(0.01, 2.03),
119+
CTYPE(4.05, 6.07),
120+
CTYPE(0.21, 2.23),
121+
CTYPE(4.25, 6.27),
122+
CTYPE(0.11, 2.13),
123+
CTYPE(4.15, 6.17),
124+
CTYPE(0.31, 2.33),
125+
CTYPE(4.35, 6.37)});
126+
EXPECT_TENSOR_EQ(out_1, expected_1);
127+
}
76128
};
77129

78130
TEST_F(OpCatOutTest, SmokeDim1) {
@@ -133,6 +185,13 @@ TEST_F(OpCatOutTest, SixteenBitFloatSupport) {
133185
test_16bit_dtype<ScalarType::BFloat16>();
134186
}
135187

188+
TEST_F(OpCatOutTest, ComplexSupport) {
189+
#define RUN_COMPLEX_TEST(ctype, dtype) \
190+
test_complex_dtype<ctype, ScalarType::dtype>();
191+
ET_FORALL_COMPLEXH_TYPES(RUN_COMPLEX_TEST);
192+
#undef RUN_COMPLEX_TEST
193+
}
194+
136195
TEST_F(OpCatOutTest, NegativeDims) {
137196
TensorFactory<ScalarType::Int> tf;
138197

0 commit comments

Comments (0)