From e0b07d87aff78c1b6817e51ac8910f5e0fbc71ff Mon Sep 17 00:00:00 2001
From: Prashant Rawat
Date: Fri, 1 Aug 2025 19:13:34 -0700
Subject: [PATCH] Extend op_add for complex dtype (#12977)

Summary:
Add complex dtype support for op_add. The current support for complex
dtypes requires the input and output tensors to have the same dtype;
mixed dtypes will be supported in the future. The optimized add op
additionally does not support broadcasting for complex inputs.

Reviewed By: manuelcandales

Differential Revision: D79091064
---
 kernels/optimized/cpu/op_add.cpp        | 28 ++++++++++++-
 kernels/optimized/cpu/op_add_sub_impl.h | 29 ++++++++++++++
 kernels/portable/cpu/op_add.cpp         | 53 ++++++++++++++++++-------
 kernels/test/op_add_test.cpp            | 43 ++++++++++++++++++++
 4 files changed, 137 insertions(+), 16 deletions(-)

diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp
index 97bdb0a0d5e..88b102b5650 100644
--- a/kernels/optimized/cpu/op_add.cpp
+++ b/kernels/optimized/cpu/op_add.cpp
@@ -33,7 +33,33 @@ Tensor& opt_add_out(
   ScalarType out_type = out.scalar_type();
 
   if (b.numel() == 1) {
-    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
+    if (executorch::runtime::isComplexType(a_type) ||
+        executorch::runtime::isComplexType(b_type) ||
+        executorch::runtime::isComplexType(out_type)) {
+      // TODO: The current support for complex dtype enforces that input and
+      // output tensors have the same dtype. Support mixed dtypes in the future.
+      ET_KERNEL_CHECK(
+          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
+      ET_KERNEL_CHECK(
+          ctx,
+          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+          InvalidArgument,
+          out);
+
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
+        CTYPE alpha_val = utils::scalar_to<CTYPE>(alpha);
+        CTYPE b_val = *b.const_data_ptr<CTYPE>();
+
+        using Vec = at::vec::Vectorized<CTYPE>;
+        at::vec::map(
+            [alpha_val, b_val](Vec x) { return x + Vec(alpha_val * b_val); },
+            out.mutable_data_ptr<CTYPE>(),
+            a.const_data_ptr<CTYPE>(),
+            out.numel());
+      });
+      return out;
+    } else if (
+        a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
         a_type != ScalarType::BFloat16) {
       ET_KERNEL_CHECK(
           ctx,
diff --git a/kernels/optimized/cpu/op_add_sub_impl.h b/kernels/optimized/cpu/op_add_sub_impl.h
index 2dd865b294d..3fc22d88a63 100644
--- a/kernels/optimized/cpu/op_add_sub_impl.h
+++ b/kernels/optimized/cpu/op_add_sub_impl.h
@@ -85,6 +85,35 @@ Tensor& opt_add_sub_out_impl(
   ScalarType out_type = out.scalar_type();
 
   auto selected_optimized_path = select_optimized_path(a, b, out);
+
+  if (executorch::runtime::isComplexType(a_type) ||
+      executorch::runtime::isComplexType(b_type) ||
+      executorch::runtime::isComplexType(out_type)) {
+    // TODO: The current implementation for complex dtypes enforces that the
+    // input and output tensors have the same dtype and shape. Handle mixed
+    // dtypes and broadcasting in the future.
+    ET_KERNEL_CHECK(
+        ctx,
+        a_type == b_type && a_type == out_type &&
+            selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d,
+        InvalidArgument,
+        out);
+    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
+      CTYPE alpha_val = torch::executor::native::utils::scalar_to<CTYPE>(alpha);
+      if constexpr (is_sub) {
+        alpha_val = -alpha_val;
+      }
+      using Vec = at::vec::Vectorized<CTYPE>;
+      at::vec::map2(
+          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
+          out.mutable_data_ptr<CTYPE>(),
+          a.const_data_ptr<CTYPE>(),
+          b.const_data_ptr<CTYPE>(),
+          out.numel());
+    });
+    return out;
+  }
+
   if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
     // Resize for dynamic shape
     auto error = resize_tensor(out, a.sizes());
diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp
index e10534cd233..122b2a2c97e 100644
--- a/kernels/portable/cpu/op_add.cpp
+++ b/kernels/portable/cpu/op_add.cpp
@@ -50,24 +50,47 @@ Tensor& add_out(
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "add.out";
 
-  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    CTYPE_COMPUTE val_alpha;
+  if (executorch::runtime::isComplexType(a.scalar_type()) ||
+      executorch::runtime::isComplexType(b.scalar_type()) ||
+      executorch::runtime::isComplexType(out.scalar_type())) {
+    // TODO: The current support for complex dtype enforces that input and
+    // output tensors have the same dtype. Support mixed dtypes in the future.
     ET_KERNEL_CHECK(
-        ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBBF16>(
-        [val_alpha](const auto val_a, const auto val_b) {
-          return val_a + val_alpha * val_b;
-        },
         ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
+        a.scalar_type() == b.scalar_type() &&
+            a.scalar_type() == out.scalar_type(),
+        InvalidArgument,
         out);
-  });
+    ET_SWITCH_COMPLEXH_TYPES(out.scalar_type(), ctx, op_name, CTYPE, [&]() {
+      CTYPE val_alpha = utils::scalar_to<CTYPE>(alpha);
+      apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
+          [val_alpha](const CTYPE val_a, const CTYPE val_b) {
+            return val_a + val_alpha * val_b;
+          },
+          a,
+          b,
+          out);
+    });
+  } else {
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      CTYPE_COMPUTE val_alpha;
+      ET_KERNEL_CHECK(
+          ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
+      utils::apply_bitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::REALHBBF16>(
+          [val_alpha](const auto val_a, const auto val_b) {
+            return val_a + val_alpha * val_b;
+          },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          b,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out);
+    });
+  }
 
   return out;
 }
diff --git a/kernels/test/op_add_test.cpp b/kernels/test/op_add_test.cpp
index 8af693e1b3e..c081b6dd3cc 100644
--- a/kernels/test/op_add_test.cpp
+++ b/kernels/test/op_add_test.cpp
@@ -89,6 +89,45 @@ class OpAddOutKernelTest : public OperatorTest {
 #undef ENUMERATE_TEST_ENTRY
   }
 
+  template <typename CTYPE, ScalarType DTYPE>
+  void test_add_complex_dtype() {
+    TensorFactory<DTYPE> tf;
+
+    // Both inputs have the same shape
+    Tensor x_0 = tf.make({2}, {CTYPE(1, 2.1), CTYPE(3.1, 4)});
+    Tensor y_0 = tf.make({2}, {CTYPE(5.2, 6.3), CTYPE(7, 8.9)});
+    // Destination for the sum.
+    Tensor out = tf.full({2}, CTYPE{0, 0});
+    // Add two tensors.
+    op_add_out(
+        x_0,
+        y_0,
+        /*alpha=*/1,
+        out);
+    Tensor expected_0 = tf.make({2}, {CTYPE(6.2, 8.4), CTYPE(10.1, 12.9)});
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_EQ(out, expected_0);
+
+    // Other tensor has numel() = 1
+    Tensor y_1 = tf.make({1}, {CTYPE(2, 3)});
+    // Add two tensors.
+    op_add_out(
+        x_0,
+        y_1,
+        /*alpha=*/2,
+        out);
+    Tensor expected_1 = tf.make({2}, {CTYPE(5, 8.1), CTYPE(7.1, 10)});
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_EQ(out, expected_1);
+  }
+
+  void test_add_enumerate_complex_types() {
+#define RUN_COMPLEX_TEST(ctype, dtype) \
+  test_add_complex_dtype<ctype, ScalarType::dtype>();
+    ET_FORALL_COMPLEXH_TYPES(RUN_COMPLEX_TEST);
+#undef RUN_COMPLEX_TEST
+  }
+
   // Common testing for adding two floating point Tensors.
   template <ScalarType DTYPE>
   void test_floating_point_add_out() {
@@ -293,6 +332,10 @@ TEST_F(OpAddOutKernelTest, AllRealDtypesSupported) {
   test_add_enumerate_a_types();
 }
 
+TEST_F(OpAddOutKernelTest, ComplexTensors) {
+  test_add_enumerate_complex_types();
+}
+
 TEST_F(OpAddOutKernelTest, FloatTensors) {
   test_floating_point_add_out<ScalarType::Float>();
 }
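Editor's note: below is a minimal standalone sketch, not part of the patch, illustrating the elementwise semantics the new complex paths implement (out = a + alpha * b) using std::complex in place of ExecuTorch tensors and at::vec. The helper name add_alpha_sketch is hypothetical; the values mirror the second test case above, with the numel() == 1 operand pre-broadcast by hand.

#include <complex>
#include <cstddef>
#include <iostream>
#include <vector>

// Scalar version of the kernel's inner loop: out[i] = a[i] + alpha * b[i].
// This mirrors the lambdas passed to at::vec::map/map2 in the patch, minus
// the SIMD vectorization and tensor plumbing.
template <typename T>
void add_alpha_sketch(
    std::vector<std::complex<T>>& out,
    const std::vector<std::complex<T>>& a,
    const std::vector<std::complex<T>>& b,
    std::complex<T> alpha) {
  for (std::size_t i = 0; i < out.size(); ++i) {
    out[i] = a[i] + alpha * b[i];
  }
}

int main() {
  // x_0 from the test; y_1 = (2 + 3i) replicated to both elements; alpha = 2.
  std::vector<std::complex<float>> a = {{1.0f, 2.1f}, {3.1f, 4.0f}};
  std::vector<std::complex<float>> b = {{2.0f, 3.0f}, {2.0f, 3.0f}};
  std::vector<std::complex<float>> out(2);
  add_alpha_sketch(out, a, b, {2.0f, 0.0f});
  // Prints (5 + 8.1i) and (7.1 + 10i), matching expected_1 in the test.
  for (const auto& v : out) {
    std::cout << v.real() << " + " << v.imag() << "i\n";
  }
  return 0;
}

In the actual kernels the same complex multiply-add comes from at::vec::Vectorized<CTYPE>, which is why alpha is broadcast into a vector lane as Vec(alpha_val) before the fused expression.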