Commit 2f782bf

Extend op_add for complex dtype

Differential Revision: D79091064
Pull Request resolved: #12977

1 parent: e19e4a7

File tree: 4 files changed (+137, -16 lines)


kernels/optimized/cpu/op_add.cpp
Lines changed: 27 additions & 1 deletion

@@ -33,7 +33,33 @@ Tensor& opt_add_out(
   ScalarType out_type = out.scalar_type();

   if (b.numel() == 1) {
-    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
+    if (executorch::runtime::isComplexType(a_type) ||
+        executorch::runtime::isComplexType(b_type) ||
+        executorch::runtime::isComplexType(out_type)) {
+      // TODO: The current support for complex dtype enforces that input and
+      // output tensors have the same dtype. Support mixed dtypes in the future.
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+      ET_KERNEL_CHECK(
+          ctx,
+          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+          InvalidArgument,
+          out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
+        CTYPE alpha_val = utils::scalar_to<CTYPE>(alpha);
+        CTYPE b_val = *b.const_data_ptr<CTYPE>();
+
+        using Vec = at::vec::Vectorized<CTYPE>;
+        at::vec::map<CTYPE>(
+            [alpha_val, b_val](Vec x) { return x + Vec(alpha_val * b_val); },
+            out.mutable_data_ptr<CTYPE>(),
+            a.const_data_ptr<CTYPE>(),
+            out.numel());
+      });
+      return out;
+    } else if (
+        a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
         a_type != ScalarType::BFloat16) {
       ET_KERNEL_CHECK(
           ctx,
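In this scalar-broadcast branch the kernel computes out[i] = a[i] + alpha * b for every element, reading the single b value once before dispatching the vectorized map. A minimal standalone sketch of that arithmetic, using std::complex<float> in place of the kernel's at::vec::Vectorized machinery (the function and parameter names are illustrative, not ExecuTorch API):

#include <complex>
#include <cstddef>

// Reference loop for the vectorized path above:
// out[i] = a[i] + alpha * b, with b broadcast from a 1-element tensor.
void add_scalar_b(
    const std::complex<float>* a,
    std::complex<float> b,
    std::complex<float> alpha,
    std::complex<float>* out,
    std::size_t n) {
  // Product computed once; the kernel folds it into the vectorized lambda.
  const std::complex<float> alpha_b = alpha * b;
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = a[i] + alpha_b;
  }
}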

kernels/optimized/cpu/op_add_sub_impl.h
Lines changed: 29 additions & 0 deletions

@@ -85,6 +85,35 @@ Tensor& opt_add_sub_out_impl(
   ScalarType out_type = out.scalar_type();

   auto selected_optimized_path = select_optimized_path(a, b, out);
+
+  if (executorch::runtime::isComplexType(a_type) ||
+      executorch::runtime::isComplexType(b_type) ||
+      executorch::runtime::isComplexType(out_type)) {
+    // TODO: The current implementation for complex dtypes enforces that the
+    // inputs and output tensors have same dtype and shape. Handle mixed dtypes
+    // and broadcasting in the future.
+    ET_KERNEL_CHECK(
+        ctx,
+        a_type == b_type && a_type == out_type &&
+            selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d,
+        InvalidArgument,
+        out);
+    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
+      CTYPE alpha_val = torch::executor::native::utils::scalar_to<CTYPE>(alpha);
+      if constexpr (is_sub) {
+        alpha_val = -alpha_val;
+      }
+      using Vec = at::vec::Vectorized<CTYPE>;
+      at::vec::map2<CTYPE>(
+          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
+          out.mutable_data_ptr<CTYPE>(),
+          a.const_data_ptr<CTYPE>(),
+          b.const_data_ptr<CTYPE>(),
+          out.numel());
+    });
+    return out;
+  }
+
   if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
     // Resize for dynamic shape
     auto error = resize_tensor(out, a.sizes());
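This shared impl folds add and sub into one loop by flipping the sign of alpha when is_sub is set, since a - alpha * b == a + (-alpha) * b. A standalone sketch of that design choice (std::complex stand-in; the template and names are illustrative, not the ExecuTorch impl):

#include <complex>
#include <cstddef>

// Shared add/sub loop: out[i] = a[i] + s * alpha * b[i], with s = -1 for sub.
template <bool kIsSub, typename T>
void add_sub_1d(const T* a, const T* b, T alpha, T* out, std::size_t n) {
  if constexpr (kIsSub) {
    alpha = -alpha; // a - alpha*b == a + (-alpha)*b
  }
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = a[i] + alpha * b[i];
  }
}

Instantiating add_sub_1d<false, std::complex<float>> gives add.out semantics and add_sub_1d<true, std::complex<float>> gives sub.out, mirroring the is_sub switch above.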

kernels/portable/cpu/op_add.cpp
Lines changed: 38 additions & 15 deletions

@@ -50,24 +50,47 @@ Tensor& add_out(
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "add.out";

-  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    CTYPE_COMPUTE val_alpha;
+  if (executorch::runtime::isComplexType(a.scalar_type()) ||
+      executorch::runtime::isComplexType(b.scalar_type()) ||
+      executorch::runtime::isComplexType(out.scalar_type())) {
+    // TODO: The current support for complex dtype enforces that input and
+    // output tensors have the same dtype. Support mixed dtypes in the future.
     ET_KERNEL_CHECK(
-        ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBBF16>(
-        [val_alpha](const auto val_a, const auto val_b) {
-          return val_a + val_alpha * val_b;
-        },
         ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
+        a.scalar_type() == b.scalar_type() &&
+            a.scalar_type() == out.scalar_type(),
+        InvalidArgument,
         out);
-  });
+    ET_SWITCH_COMPLEXH_TYPES(out.scalar_type(), ctx, op_name, CTYPE, [&]() {
+      CTYPE val_alpha = utils::scalar_to<CTYPE>(alpha);
+      apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
+          [val_alpha](const CTYPE val_a, const CTYPE val_b) {
+            return val_a + val_alpha * val_b;
+          },
+          a,
+          b,
+          out);
+    });
+  } else {
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      CTYPE_COMPUTE val_alpha;
+      ET_KERNEL_CHECK(
+          ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
+      utils::apply_bitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::REALHBBF16>(
+          [val_alpha](const auto val_a, const auto val_b) {
+            return val_a + val_alpha * val_b;
+          },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          b,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out);
+    });
+  }

   return out;
 }
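The portable branch evaluates the same lambda per element pair over same-shape, same-dtype buffers. A rough sketch of that element-wise application (plain loop with a std::complex stand-in; apply_binary here is a hypothetical helper, not the ExecuTorch one):

#include <complex>
#include <cstddef>

// out[i] = f(a[i], b[i]) over same-shape buffers.
template <typename T, typename F>
void apply_binary(const T* a, const T* b, T* out, std::size_t n, F f) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = f(a[i], b[i]);
  }
}

With f = [alpha](C x, C y) { return x + alpha * y; } for C = std::complex<float>, this reproduces the complex add.out result checked against expected_0 in the tests below.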

kernels/test/op_add_test.cpp
Lines changed: 43 additions & 0 deletions

@@ -89,6 +89,45 @@ class OpAddOutKernelTest : public OperatorTest {
 #undef ENUMERATE_TEST_ENTRY
   }

+  template <typename CTYPE, ScalarType DTYPE>
+  void test_add_complex_dtype() {
+    TensorFactory<DTYPE> tf;
+
+    // Both inputs have the same shape
+    Tensor x_0 = tf.make({2}, {CTYPE(1, 2.1), CTYPE(3.1, 4)});
+    Tensor y_0 = tf.make({2}, {CTYPE(5.2, 6.3), CTYPE(7, 8.9)});
+    // Destination for the sum.
+    Tensor out = tf.full({2}, CTYPE{0, 0});
+    // Add two tensors.
+    op_add_out(
+        x_0,
+        y_0,
+        /*alpha=*/1,
+        out);
+    Tensor expected_0 = tf.make({2}, {CTYPE(6.2, 8.4), CTYPE(10.1, 12.9)});
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_EQ(out, expected_0);
+
+    // Other tensor has numel() = 1
+    Tensor y_1 = tf.make({1}, {CTYPE(2, 3)});
+    // Add two tensors.
+    op_add_out(
+        x_0,
+        y_1,
+        /*alpha=*/2,
+        out);
+    Tensor expected_1 = tf.make({2}, {CTYPE(5, 8.1), CTYPE(7.1, 10)});
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_EQ(out, expected_1);
+  }
+
+  void test_add_enumerate_complex_types() {
+#define RUN_COMPLEX_TEST(ctype, dtype) \
+  test_add_complex_dtype<ctype, ScalarType::dtype>();
+    ET_FORALL_COMPLEXH_TYPES(RUN_COMPLEX_TEST);
+#undef RUN_COMPLEX_TEST
+  }
+
   // Common testing for adding two floating point Tensors.
   template <ScalarType DTYPE>
   void test_floating_point_add_out() {

@@ -293,6 +332,10 @@ TEST_F(OpAddOutKernelTest, AllRealDtypesSupported) {
   test_add_enumerate_a_types();
 }

+TEST_F(OpAddOutKernelTest, ComplexTensors) {
+  test_add_enumerate_complex_types();
+}
+
 TEST_F(OpAddOutKernelTest, FloatTensors) {
   test_floating_point_add_out<ScalarType::Float>();
 }
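The expected values follow from complex arithmetic. In the broadcast case, alpha = 2 and y_1 = 2 + 3i give alpha * y_1 = 4 + 6i, so (1 + 2.1i) + (4 + 6i) = 5 + 8.1i and (3.1 + 4i) + (4 + 6i) = 7.1 + 10i, matching expected_1. A tiny standalone check of that arithmetic (tolerance comparison rather than exact equality, to be safe with floating-point rounding):

#include <cassert>
#include <complex>

int main() {
  using C = std::complex<double>;
  const C alpha_y1 = C(2, 0) * C(2, 3); // alpha * y_1 = 4 + 6i
  assert(std::abs(C(1, 2.1) + alpha_y1 - C(5, 8.1)) < 1e-12);
  assert(std::abs(C(3.1, 4) + alpha_y1 - C(7.1, 10)) < 1e-12);
  return 0;
}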
