
Commit 13e9e03

pssrawat authored and facebook-github-bot committed
Extend cat op for complex dtype (pytorch#12894)
Summary: Need complex cat op for live translation

Differential Revision: D78934592
1 parent 1c72e0e commit 13e9e03

File tree

2 files changed: +111 -20 lines

kernels/portable/cpu/op_cat.cpp

Lines changed: 51 additions & 20 deletions
@@ -56,27 +56,58 @@ Tensor& cat_out(
   const size_t ninputs = tensors.size();
 
   const auto out_type = out.scalar_type();
-  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
-    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
-    for (size_t i = 0; i < outer; ++i) {
-      for (size_t j = 0; j < ninputs; ++j) {
-        const auto in_type = tensors[j].scalar_type();
-        ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
-          if (tensors[j].numel() == 0) {
-            return;
-          }
-          size_t inner = tensors[j].size(dim) * dim_stride;
-          const CTYPE_IN* const in_ptr =
-              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
-
-          for (size_t k = 0; k < inner; ++k) {
-            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
-          }
-          out_ptr += inner;
-        });
-      }
+  const bool out_is_complex =
+      executorch::runtime::isComplexType(out.scalar_type());
+
+  if (out_is_complex) {
+    // All the input tensors and output must have same dtype
+    for (size_t i = 0; i < ninputs; ++i) {
+      const auto in_type = tensors[i].scalar_type();
+      ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
     }
-  });
+    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
+      CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
+          if (tensors[j].numel() == 0) {
+            return;
+          }
+          size_t inner = tensors[j].size(dim) * dim_stride;
+          const CTYPE* const in_ptr =
+              tensors[j].const_data_ptr<CTYPE>() + i * inner;
+          memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
+          out_ptr += inner;
+
+        }
+      }
+    });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
+      CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
+          const auto in_type = tensors[j].scalar_type();
+          ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+            if (tensors[j].numel() == 0) {
+              return;
+            }
+            size_t inner = tensors[j].size(dim) * dim_stride;
+            const CTYPE_IN* const in_ptr =
+                tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
+
+            if (sizeof(CTYPE_IN) == sizeof(CTYPE_OUT)) {
+              memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE_IN));
+            } else {
+              for (size_t k = 0; k < inner; ++k) {
+                out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
+              }
+            }
+            out_ptr += inner;
+          });
+        }
+      }
+    });
+  }
 
   return out;
 }
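
Note: the new complex branch above skips the per-element static_cast used by the real/float path. Once every input is checked to share the output's dtype, each contiguous slice along the cat dimension is copied with a single memcpy, which is valid because complex element types are trivially copyable. Below is a minimal, self-contained sketch of that copy pattern for two inputs concatenated along dim 1; the function name, the two-input signature, and the row-major layout are assumptions made for illustration, not the ExecuTorch API.

```cpp
// Illustrative sketch only: concatenate two row-major {rows, cols} buffers of
// std::complex<float> along dim 1, mirroring the outer/inner/memcpy structure
// of the complex branch in op_cat.cpp. Names and signature are hypothetical.
#include <complex>
#include <cstring>
#include <vector>

using cfloat = std::complex<float>;

void concat_complex_dim1(
    const std::vector<cfloat>& a, size_t cols_a,
    const std::vector<cfloat>& b, size_t cols_b,
    size_t rows,
    std::vector<cfloat>& out) {
  out.resize(rows * (cols_a + cols_b));
  cfloat* out_ptr = out.data();
  for (size_t i = 0; i < rows; ++i) {
    // "inner" slice of input a for this outer index: cols_a contiguous values.
    std::memcpy(out_ptr, a.data() + i * cols_a, cols_a * sizeof(cfloat));
    out_ptr += cols_a;
    // Followed by the matching slice of input b.
    std::memcpy(out_ptr, b.data() + i * cols_b, cols_b * sizeof(cfloat));
    out_ptr += cols_b;
  }
}

int main() {
  std::vector<cfloat> x = {{0.01f, 2.03f}, {4.05f, 6.07f},
                           {0.11f, 2.13f}, {4.15f, 6.17f}};  // shape {2, 2}
  std::vector<cfloat> y = {{0.21f, 2.23f}, {4.25f, 6.27f},
                           {0.31f, 2.33f}, {4.35f, 6.37f}};  // shape {2, 2}
  std::vector<cfloat> out;
  concat_complex_dim1(x, /*cols_a=*/2, y, /*cols_b=*/2, /*rows=*/2, out);
  // out now holds the {2, 4} row-major concatenation, matching the dim=1 test.
  return 0;
}
```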

kernels/test/op_cat_test.cpp

Lines changed: 60 additions & 0 deletions
@@ -73,6 +73,59 @@ class OpCatOutTest : public OperatorTest {
         tf.make({2, 4}, {1.5, -2.0, 3.25, 10.0, 4.0, -5.5, 6.5, 20.0});
     EXPECT_TENSOR_EQ(out, expected);
   }
+
+  template <typename CTYPE, ScalarType DTYPE>
+  void test_complex_dtype() {
+    TensorFactory<DTYPE> tf;
+    Tensor x = tf.make(
+        {2, 2},
+        {CTYPE(0.01, 2.03),
+         CTYPE(4.05, 6.07),
+         CTYPE(0.11, 2.13),
+         CTYPE(4.15, 6.17)});
+    Tensor y = tf.make(
+        {2, 2},
+        {
+            CTYPE(0.21, 2.23),
+            CTYPE(4.25, 6.27),
+            CTYPE(0.31, 2.33),
+            CTYPE(4.35, 6.37),
+        });
+
+    std::vector<Tensor> inputs = {x, y};
+
+    // Concatenate along dim[0].
+    Tensor out_0 = tf.full({4, 2}, CTYPE{0, 0});
+    Tensor ret_0 = op_cat_out(
+        ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/0, out_0);
+    Tensor expected_0 = tf.make(
+        {4, 2},
+        {CTYPE(0.01, 2.03),
+         CTYPE(4.05, 6.07),
+         CTYPE(0.11, 2.13),
+         CTYPE(4.15, 6.17),
+         CTYPE(0.21, 2.23),
+         CTYPE(4.25, 6.27),
+         CTYPE(0.31, 2.33),
+         CTYPE(4.35, 6.37)});
+    EXPECT_TENSOR_EQ(out_0, expected_0);
+
+    // Concatenate along dim[1].
+    Tensor out_1 = tf.full({2, 4}, CTYPE{0, 0});
+    Tensor ret_1 = op_cat_out(
+        ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/1, out_1);
+    Tensor expected_1 = tf.make(
+        {2, 4},
+        {CTYPE(0.01, 2.03),
+         CTYPE(4.05, 6.07),
+         CTYPE(0.21, 2.23),
+         CTYPE(4.25, 6.27),
+         CTYPE(0.11, 2.13),
+         CTYPE(4.15, 6.17),
+         CTYPE(0.31, 2.33),
+         CTYPE(4.35, 6.37)});
+    EXPECT_TENSOR_EQ(out_1, expected_1);
+  }
 };
 
 TEST_F(OpCatOutTest, SmokeDim1) {
@@ -133,6 +186,13 @@ TEST_F(OpCatOutTest, SixteenBitFloatSupport) {
   test_16bit_dtype<ScalarType::BFloat16>();
 }
 
+TEST_F(OpCatOutTest, ComplexSupport) {
+#define RUN_COMPLEX_TEST(ctype, dtype) \
+  test_complex_dtype<ctype, ScalarType::dtype>();
+  ET_FORALL_COMPLEXH_TYPES(RUN_COMPLEX_TEST);
+#undef RUN_COMPLEX_TEST
+}
+
 TEST_F(OpCatOutTest, NegativeDims) {
   TensorFactory<ScalarType::Int> tf;
 
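
Note: the ComplexSupport test registers the templated body once per complex dtype through an X-macro. ET_FORALL_COMPLEXH_TYPES expands the locally defined RUN_COMPLEX_TEST for each supported complex type, and the macro is then undefined. Below is a small stand-alone illustration of that pattern; FORALL_COMPLEX_TYPES and its type list are hypothetical stand-ins, not the actual ExecuTorch macro.

```cpp
// Hypothetical stand-in for ET_FORALL_COMPLEXH_TYPES: invokes _(ctype, name)
// once per entry in the list.
#include <complex>
#include <cstdio>

#define FORALL_COMPLEX_TYPES(_)          \
  _(std::complex<float>, ComplexFloat)   \
  _(std::complex<double>, ComplexDouble)

template <typename CTYPE>
void run_complex_case(const char* dtype_name) {
  // A real test body would build CTYPE tensors and call the operator.
  std::printf("would run cat test for dtype %s\n", dtype_name);
}

int main() {
  // Same shape as the test: define a per-type callback, expand it for every
  // listed type, then undefine it.
#define RUN_COMPLEX_TEST(ctype, dtype) run_complex_case<ctype>(#dtype);
  FORALL_COMPLEX_TYPES(RUN_COMPLEX_TEST)
#undef RUN_COMPLEX_TEST
  return 0;
}
```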
