diff --git a/kernels/optimized/cpu/op_sub.cpp b/kernels/optimized/cpu/op_sub.cpp
index 7ee880d9977..489421f1b2d 100644
--- a/kernels/optimized/cpu/op_sub.cpp
+++ b/kernels/optimized/cpu/op_sub.cpp
@@ -15,6 +15,8 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
+#include <executorch/kernels/optimized/cpu/op_add_sub_impl.h>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -138,110 +140,9 @@ Tensor& opt_sub_out(
     }
   }
 
-  auto selected_optimized_path = select_optimized_path(a, b, out);
-  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    auto error = resize_tensor(out, a.sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-
-    ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::map2<CTYPE>(
-          [alpha_val](Vec x, Vec y) { return x - Vec(alpha_val) * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          a.const_data_ptr<CTYPE>(),
-          b.const_data_ptr<CTYPE>(),
-          out.numel());
-    });
-  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-      lhs = &b;
-      rhs = &a;
-    } else {
-      // Catch failure to update logic when subing new broadcasting possibility.
-      ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
-      lhs = &a;
-      rhs = &b;
-    }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    ET_SWITCH_REAL_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      if (selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [alpha_val](Vec x, Vec y) { return y - Vec(alpha_val) * x; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
-      } else {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [alpha_val](Vec x, Vec y) { return x - Vec(alpha_val) * y; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
-      }
-    });
-  } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          SubInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
-  }
-
-  return out;
+  static
constexpr const char op_name[] = "sub.out";
+  return torch::executor::kernels::impl::opt_add_sub_out_impl<true, op_name>(
+      ctx, a, b, alpha, out);
 }
 
 Tensor& opt_sub_scalar_out(
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index 41dde099290..2a66407a5ce 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -90,6 +90,7 @@ _OPTIMIZED_ATEN_OPS = (
         name = "op_sub",
         deps = [
             ":binary_ops",
+            ":add_sub_impl",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
         ],
diff --git a/kernels/test/op_sub_test.cpp b/kernels/test/op_sub_test.cpp
index 39fc9e14925..aafaf688b0d 100644
--- a/kernels/test/op_sub_test.cpp
+++ b/kernels/test/op_sub_test.cpp
@@ -99,6 +99,109 @@ class OpSubOutTest : public OperatorTest {
     EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{0.1, 1.2, 3.4, 7.8}));
   }
 
+  template <ScalarType DTYPE>
+  void test_broadcast_3D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
+
+    // Destination for output of mul.
+    Tensor out =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor expected =
+        tf_a.make({2, 2, 3}, /*data=*/{-1, -1, -1, 2, 2, 2, 2, 2, 2, 5, 5, 5});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_sub_out(a, b, 1.0, out), expected);
+    // b - a * 1.5 output should be
+    expected = tf_a.make(
+        {2, 2, 3},
+        /*data=*/
+        {0.5,
+         0.0,
+         -0.5,
+         -4.0,
+         -4.5,
+         -5.0,
+         -5.5,
+         -6.0,
+         -6.5,
+         -10.0,
+         -10.5,
+         -11.0});
+    EXPECT_TENSOR_CLOSE(op_sub_out(b, a, 1.5, out), expected);
+  }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_4D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    Tensor b = tf_a.make(
+        {2, 1, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+    // Destination for output of mul.
+    Tensor out = tf_a.zeros({2, 2, 3, 5});
+    Tensor expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
+                  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+                  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+                  30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30});
+
+    // Check that it matches the expected output.
+ EXPECT_TENSOR_CLOSE(op_sub_out(a, b, 1.0, out), expected); + expected = tf_a.make( + {2, 2, 3, 5}, + /*data=*/{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, -15, -15, -15, -15, -15, -15, -15, -15, -15, + -15, -15, -15, -15, -15, -15, -15, -15, -15, -15, -15, -15, + -15, -15, -15, -15, -15, -15, -15, -15, -15, -30, -30, -30, + -30, -30, -30, -30, -30, -30, -30, -30, -30, -30, -30, -30}); + EXPECT_TENSOR_CLOSE(op_sub_out(b, a, 1.0, out), expected); + + b = tf_a.make( + {2, 2, 1, 5}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}); + out = tf_a.zeros({2, 2, 3, 5}); + expected = tf_a.make( + {2, 2, 3, 5}, + /*data=*/{0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 15, 15, 15, 15, 15, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 25, 25, 25, 25, 25, 30, 30, 30, 30, 30, + 30, 30, 30, 30, 30, 35, 35, 35, 35, 35, 40, 40, 40, 40, 40}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(op_sub_out(a, b, 1.0, out), expected); + expected = tf_a.make( + {2, 2, 3, 5}, + /*data=*/{-0.5000, -1.0000, -1.5000, -2.0000, -2.5000, + -8.0000, -8.5000, -9.0000, -9.5000, -10.0000, + -15.5000, -16.0000, -16.5000, -17.0000, -17.5000, + + -18.0000, -18.5000, -19.0000, -19.5000, -20.0000, + -25.5000, -26.0000, -26.5000, -27.0000, -27.5000, + -33.0000, -33.5000, -34.0000, -34.5000, -35.0000, + + -35.5000, -36.0000, -36.5000, -37.0000, -37.5000, + -43.0000, -43.5000, -44.0000, -44.5000, -45.0000, + -50.5000, -51.0000, -51.5000, -52.0000, -52.5000, + + -53.0000, -53.5000, -54.0000, -54.5000, -55.0000, + -60.5000, -61.0000, -61.5000, -62.0000, -62.5000, + -68.0000, -68.5000, -69.0000, -69.5000, -70.0000}); + EXPECT_TENSOR_CLOSE(op_sub_out(b, a, 1.5, out), expected); + } + void test_sub_enumerate_a_types() { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_sub_enumerate_b_types(); @@ -237,6 +340,19 @@ TEST_F(OpSubOutTest, BroadcastScalarRank0Supported) { EXPECT_TENSOR_EQ(out, ret); } +TEST_F(OpSubOutTest, 
BroadcastNDTest) {
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Float>();
+  test_broadcast_3D<ScalarType::Half>();
+  // Sub doesnt yet support BFloat16
+  // test_broadcast_3D<ScalarType::BFloat16>();
+
+  // Test 4D tensors
+  test_broadcast_4D<ScalarType::Float>();
+  test_broadcast_4D<ScalarType::Half>();
+  // test_broadcast_4D<ScalarType::BFloat16>();
+}
+
 //
 // Death Tests
 //