Skip to content

Commit ccef472

Browse files
pssrawat authored and facebook-github-bot committed
Extend cat op for complex dtype (#12894)
Summary: Need complex cat op for live translation. The current support for complex dtype enforces that input and output tensors have the same dtype. Support mixed dtypes in the future. Differential Revision: D78934592
1 parent 8b2ddb2 commit ccef472

File tree

2 files changed

+105
-15
lines changed

2 files changed

+105
-15
lines changed

kernels/portable/cpu/op_cat.cpp

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,58 @@ Tensor& cat_out(
5656
const size_t ninputs = tensors.size();
5757

5858
const auto out_type = out.scalar_type();
59-
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
60-
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
61-
for (size_t i = 0; i < outer; ++i) {
62-
for (size_t j = 0; j < ninputs; ++j) {
63-
const auto in_type = tensors[j].scalar_type();
64-
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
59+
const bool out_is_complex =
60+
executorch::runtime::isComplexType(out.scalar_type());
61+
62+
if (out_is_complex) {
63+
// TODO: The current support for complex dtype enforces that input and
64+
// output tensors have the same dtype. Support mixed dtypes in the future.
65+
for (size_t i = 0; i < ninputs; ++i) {
66+
const auto in_type = tensors[i].scalar_type();
67+
ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
68+
}
69+
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
70+
CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
71+
for (size_t i = 0; i < outer; ++i) {
72+
for (size_t j = 0; j < ninputs; ++j) {
6573
if (tensors[j].numel() == 0) {
6674
return;
6775
}
6876
size_t inner = tensors[j].size(dim) * dim_stride;
69-
const CTYPE_IN* const in_ptr =
70-
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
71-
72-
for (size_t k = 0; k < inner; ++k) {
73-
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
74-
}
77+
const CTYPE* const in_ptr =
78+
tensors[j].const_data_ptr<CTYPE>() + i * inner;
79+
memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
7580
out_ptr += inner;
76-
});
81+
}
7782
}
78-
}
79-
});
83+
});
84+
} else {
85+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
86+
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
87+
for (size_t i = 0; i < outer; ++i) {
88+
for (size_t j = 0; j < ninputs; ++j) {
89+
const auto in_type = tensors[j].scalar_type();
90+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
91+
if (tensors[j].numel() == 0) {
92+
return;
93+
}
94+
size_t inner = tensors[j].size(dim) * dim_stride;
95+
const CTYPE_IN* const in_ptr =
96+
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
97+
98+
if (sizeof(CTYPE_IN) == sizeof(CTYPE_OUT)) {
99+
memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE_IN));
100+
} else {
101+
for (size_t k = 0; k < inner; ++k) {
102+
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
103+
}
104+
}
105+
out_ptr += inner;
106+
});
107+
}
108+
}
109+
});
110+
}
80111

81112
return out;
82113
}

kernels/test/op_cat_test.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,58 @@ class OpCatOutTest : public OperatorTest {
7373
tf.make({2, 4}, {1.5, -2.0, 3.25, 10.0, 4.0, -5.5, 6.5, 20.0});
7474
EXPECT_TENSOR_EQ(out, expected);
7575
}
76+
77+
template <typename CTYPE, ScalarType DTYPE>
78+
void test_complex_dtype() {
79+
TensorFactory<DTYPE> tf;
80+
Tensor x = tf.make(
81+
{2, 2},
82+
{CTYPE(0.01, 2.03),
83+
CTYPE(4.05, 6.07),
84+
CTYPE(0.11, 2.13),
85+
CTYPE(4.15, 6.17)});
86+
Tensor y = tf.make(
87+
{2, 2},
88+
{CTYPE(0.21, 2.23),
89+
CTYPE(4.25, 6.27),
90+
CTYPE(0.31, 2.33),
91+
CTYPE(4.35, 6.37)});
92+
93+
std::vector<Tensor> inputs = {x, y};
94+
95+
// Concatenate along dim[0].
96+
Tensor out_0 = tf.full({4, 2}, CTYPE{0, 0});
97+
Tensor ret_0 = op_cat_out(
98+
ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/0, out_0);
99+
Tensor expected_0 = tf.make(
100+
{4, 2},
101+
{CTYPE(0.01, 2.03),
102+
CTYPE(4.05, 6.07),
103+
CTYPE(0.11, 2.13),
104+
CTYPE(4.15, 6.17),
105+
CTYPE(0.21, 2.23),
106+
CTYPE(4.25, 6.27),
107+
CTYPE(0.31, 2.33),
108+
CTYPE(4.35, 6.37)});
109+
110+
EXPECT_TENSOR_EQ(out_0, expected_0);
111+
112+
// Concatenate along dim[1].
113+
Tensor out_1 = tf.full({2, 4}, CTYPE{0, 0});
114+
Tensor ret_1 = op_cat_out(
115+
ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/1, out_1);
116+
Tensor expected_1 = tf.make(
117+
{2, 4},
118+
{CTYPE(0.01, 2.03),
119+
CTYPE(4.05, 6.07),
120+
CTYPE(0.21, 2.23),
121+
CTYPE(4.25, 6.27),
122+
CTYPE(0.11, 2.13),
123+
CTYPE(4.15, 6.17),
124+
CTYPE(0.31, 2.33),
125+
CTYPE(4.35, 6.37)});
126+
EXPECT_TENSOR_EQ(out_1, expected_1);
127+
}
76128
};
77129

78130
TEST_F(OpCatOutTest, SmokeDim1) {
@@ -133,6 +185,13 @@ TEST_F(OpCatOutTest, SixteenBitFloatSupport) {
133185
test_16bit_dtype<ScalarType::BFloat16>();
134186
}
135187

188+
// Runs the complex-dtype cat test once per supported complex type
// (ComplexHalf/ComplexFloat/ComplexDouble, as enumerated by the
// ET_FORALL_COMPLEXH_TYPES macro).
TEST_F(OpCatOutTest, ComplexSupport) {
#define RUN_COMPLEX_TEST(ctype, dtype) \
  test_complex_dtype<ctype, ScalarType::dtype>();
  // Macro expands RUN_COMPLEX_TEST for every complex dtype pair.
  ET_FORALL_COMPLEXH_TYPES(RUN_COMPLEX_TEST);
#undef RUN_COMPLEX_TEST
}
194+
136195
TEST_F(OpCatOutTest, NegativeDims) {
137196
TensorFactory<ScalarType::Int> tf;
138197

0 commit comments

Comments (0)