Skip to content

Commit 753c4ad

Browse files
pssrawat authored and facebook-github-bot committed
Extend cat op for complex dtype (pytorch#12894)
Summary: Need complex cat op for live translation. Differential Revision: D78934592
1 parent 5338708 commit 753c4ad

File tree

2 files changed

+104
-15
lines changed

2 files changed

+104
-15
lines changed

kernels/portable/cpu/op_cat.cpp

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,57 @@ Tensor& cat_out(
5656
const size_t ninputs = tensors.size();
5757

5858
const auto out_type = out.scalar_type();
59-
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
60-
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
61-
for (size_t i = 0; i < outer; ++i) {
62-
for (size_t j = 0; j < ninputs; ++j) {
63-
const auto in_type = tensors[j].scalar_type();
64-
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
59+
const bool out_is_complex =
60+
executorch::runtime::isComplexType(out.scalar_type());
61+
62+
if (out_is_complex) {
63+
// All the input tensors and output must have same dtype
64+
for (size_t i = 0; i < ninputs; ++i) {
65+
const auto in_type = tensors[i].scalar_type();
66+
ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
67+
}
68+
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
69+
CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
70+
for (size_t i = 0; i < outer; ++i) {
71+
for (size_t j = 0; j < ninputs; ++j) {
6572
if (tensors[j].numel() == 0) {
6673
return;
6774
}
6875
size_t inner = tensors[j].size(dim) * dim_stride;
69-
const CTYPE_IN* const in_ptr =
70-
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
71-
72-
for (size_t k = 0; k < inner; ++k) {
73-
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
74-
}
76+
const CTYPE* const in_ptr =
77+
tensors[j].const_data_ptr<CTYPE>() + i * inner;
78+
memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
7579
out_ptr += inner;
76-
});
80+
}
7781
}
78-
}
79-
});
82+
});
83+
} else {
84+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
85+
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
86+
for (size_t i = 0; i < outer; ++i) {
87+
for (size_t j = 0; j < ninputs; ++j) {
88+
const auto in_type = tensors[j].scalar_type();
89+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
90+
if (tensors[j].numel() == 0) {
91+
return;
92+
}
93+
size_t inner = tensors[j].size(dim) * dim_stride;
94+
const CTYPE_IN* const in_ptr =
95+
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
96+
97+
if (sizeof(CTYPE_IN) == sizeof(CTYPE_OUT)) {
98+
memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE_IN));
99+
} else {
100+
for (size_t k = 0; k < inner; ++k) {
101+
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
102+
}
103+
}
104+
out_ptr += inner;
105+
});
106+
}
107+
}
108+
});
109+
}
80110

81111
return out;
82112
}

kernels/test/op_cat_test.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,58 @@ class OpCatOutTest : public OperatorTest {
7373
tf.make({2, 4}, {1.5, -2.0, 3.25, 10.0, 4.0, -5.5, 6.5, 20.0});
7474
EXPECT_TENSOR_EQ(out, expected);
7575
}
76+
77+
template <typename CTYPE, ScalarType DTYPE>
78+
void test_complex_dtype() {
79+
TensorFactory<DTYPE> tf;
80+
Tensor x = tf.make(
81+
{2, 2},
82+
{CTYPE(0.01, 2.03),
83+
CTYPE(4.05, 6.07),
84+
CTYPE(0.11, 2.13),
85+
CTYPE(4.15, 6.17)});
86+
Tensor y = tf.make(
87+
{2, 2},
88+
{CTYPE(0.21, 2.23),
89+
CTYPE(4.25, 6.27),
90+
CTYPE(0.31, 2.33),
91+
CTYPE(4.35, 6.37)});
92+
93+
std::vector<Tensor> inputs = {x, y};
94+
95+
// Concatenate along dim[0].
96+
Tensor out_0 = tf.full({4, 2}, CTYPE{0, 0});
97+
Tensor ret_0 = op_cat_out(
98+
ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/0, out_0);
99+
Tensor expected_0 = tf.make(
100+
{4, 2},
101+
{CTYPE(0.01, 2.03),
102+
CTYPE(4.05, 6.07),
103+
CTYPE(0.11, 2.13),
104+
CTYPE(4.15, 6.17),
105+
CTYPE(0.21, 2.23),
106+
CTYPE(4.25, 6.27),
107+
CTYPE(0.31, 2.33),
108+
CTYPE(4.35, 6.37)});
109+
110+
EXPECT_TENSOR_EQ(out_0, expected_0);
111+
112+
// Concatenate along dim[1].
113+
Tensor out_1 = tf.full({2, 4}, CTYPE{0, 0});
114+
Tensor ret_1 = op_cat_out(
115+
ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/1, out_1);
116+
Tensor expected_1 = tf.make(
117+
{2, 4},
118+
{CTYPE(0.01, 2.03),
119+
CTYPE(4.05, 6.07),
120+
CTYPE(0.21, 2.23),
121+
CTYPE(4.25, 6.27),
122+
CTYPE(0.11, 2.13),
123+
CTYPE(4.15, 6.17),
124+
CTYPE(0.31, 2.33),
125+
CTYPE(4.35, 6.37)});
126+
EXPECT_TENSOR_EQ(out_1, expected_1);
127+
}
76128
};
77129

78130
TEST_F(OpCatOutTest, SmokeDim1) {
@@ -133,6 +185,13 @@ TEST_F(OpCatOutTest, SixteenBitFloatSupport) {
133185
test_16bit_dtype<ScalarType::BFloat16>();
134186
}
135187

188+
TEST_F(OpCatOutTest, ComplexSupport) {
189+
#define RUN_COMPLEX_TEST(ctype, dtype) \
190+
test_complex_dtype<ctype, ScalarType::dtype>();
191+
ET_FORALL_COMPLEXH_TYPES(RUN_COMPLEX_TEST);
192+
#undef RUN_COMPLEX_TEST
193+
}
194+
136195
TEST_F(OpCatOutTest, NegativeDims) {
137196
TensorFactory<ScalarType::Int> tf;
138197

0 commit comments

Comments
 (0)