diff --git a/kernels/optimized/cpu/binary_ops.h b/kernels/optimized/cpu/binary_ops.h index b86f35be387..f59c9fd5d76 100644 --- a/kernels/optimized/cpu/binary_ops.h +++ b/kernels/optimized/cpu/binary_ops.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include namespace torch { @@ -235,7 +236,8 @@ Tensor& handle_broadcast_elementwise( const Tensor& a, const Tensor& b, Tensor& out, - const ElementwiseOptimizedPath selected_optimized_path) { + const ElementwiseOptimizedPath selected_optimized_path, + const executorch::aten::optional& alpha = {}) { if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastLastDim) || (selected_optimized_path == diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp index f35c8bf594f..dbf828e5882 100644 --- a/kernels/optimized/cpu/op_add.cpp +++ b/kernels/optimized/cpu/op_add.cpp @@ -140,40 +140,41 @@ Tensor& opt_add_out( out.numel()); }); } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) { - const Tensor* lhs; - const Tensor* rhs; - if (selected_optimized_path == - ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) { - lhs = &b; - rhs = &a; - } else { - // Catch failure to update logic when adding new broadcasting possibility. - ET_DCHECK( - selected_optimized_path == - ElementwiseOptimizedPath::kBroadcast2dBy1d); - lhs = &a; - rhs = &b; - } - auto error = resize_tensor(out, lhs->sizes()); - ET_KERNEL_CHECK_MSG( - ctx, - error == Error::Ok, - InvalidArgument, - out, - "Failed to resize output tensor."); ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() { CTYPE alpha_val; - ET_KERNEL_CHECK( - ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); - + ET_KERNEL_CHECK_MSG( + ctx, + utils::extract_scalar(alpha, &alpha_val), + InvalidArgument, + out, + "Failed to extract scalar alpha."); using Vec = executorch::vec::Vectorized; - executorch::vec::broadcasting_map_2d_by_1d( - [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; }, - out.mutable_data_ptr(), - lhs->const_data_ptr(), - rhs->const_data_ptr(), - lhs->sizes()[lhs->dim() - 2], - lhs->sizes()[lhs->dim() - 1]); + Vec alpha_val_vec(alpha_val); + if (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments || + selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments || + selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) { + // Reason we swap out args here is because handle_broadcast_elementwise + // handles this selected_optimized_path option a bit differently. + // This should really be resolved in handle_broadcast_elementwise. + // However, the current blocker is that handle_broadcast_elementwise + // tries to be agnostic of op. This should be fixed, likely by moving + // lambda creation to handle_broadcast_elementwise and it be aware of + // which op is being executed. + auto add_lambda = [&alpha_val_vec](auto x, auto y) { + return y + alpha_val_vec * x; + }; + return torch::executor::handle_broadcast_elementwise( + ctx, add_lambda, a, b, out, selected_optimized_path, alpha); + } else { + auto add_lambda = [&alpha_val_vec](auto x, auto y) { + return x + alpha_val_vec * y; + }; + return torch::executor::handle_broadcast_elementwise( + ctx, add_lambda, a, b, out, selected_optimized_path, alpha); + } }); } else { ScalarType common_type = diff --git a/kernels/test/op_add_test.cpp b/kernels/test/op_add_test.cpp index f91249a96c3..c84341aa9b1 100644 --- a/kernels/test/op_add_test.cpp +++ b/kernels/test/op_add_test.cpp @@ -112,6 +112,125 @@ class OpAddOutKernelTest : public OperatorTest { // tests. EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.5, 3.5, 5.75, 10.125})); } + + template + void test_broadcast_3D() { + TensorFactory tf_a; + + Tensor a = + tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7}); + + // Destination for output of mul. + Tensor out = + tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Tensor expected = tf_a.make( + {2, 2, 3}, /*data=*/{3, 5, 7, 6, 8, 10, 12, 14, 16, 15, 17, 19}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected); + expected = tf_a.make( + {2, 2, 3}, + /*data=*/{3.5, 6, 8.5, 8, 10.5, 13, 15.5, 18, 20.5, 20, 22.5, 25}); + EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.5, out), expected); + } + + template + void test_broadcast_4D() { + TensorFactory tf_a; + + Tensor a = tf_a.make( + {2, 2, 3, 5}, + /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60}); + Tensor b = tf_a.make( + {2, 1, 3, 5}, + /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}); + + // Destination for output of mul. + Tensor out = tf_a.zeros({2, 2, 3, 5}); + Tensor expected = tf_a.make( + {2, 2, 3, 5}, + /*data=*/{2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, + 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, + 62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected); + EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected); + + b = tf_a.make( + {2, 2, 1, 5}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}); + out = tf_a.zeros({2, 2, 3, 5}); + expected = tf_a.make( + {2, 2, 3, 5}, + /*data=*/{2, 4, 6, 8, 10, 7, 9, 11, 13, 15, 12, 14, 16, 18, 20, + 22, 24, 26, 28, 30, 27, 29, 31, 33, 35, 32, 34, 36, 38, 40, + 42, 44, 46, 48, 50, 47, 49, 51, 53, 55, 52, 54, 56, 58, 60, + 62, 64, 66, 68, 70, 67, 69, 71, 73, 75, 72, 74, 76, 78, 80}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected); + EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected); + } + + template + void test_broadcast_last_dim() { + TensorFactory tf_a; + + Tensor a = + tf_a.make({4, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Tensor b = tf_a.make({4, 1}, /*data=*/{2, 3, 4, 5}); + + // Destination for output of mul. + Tensor out = tf_a.zeros({4, 3}); + Tensor expected = + tf_a.make({4, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected); + EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected); + + a = tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + b = tf_a.make({2, 2, 1}, /*data=*/{2, 3, 4, 5}); + + // Destination for output of mul. + out = tf_a.zeros({2, 2, 3}); + expected = tf_a.make( + {2, 2, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected); + EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected); + + a = tf_a.make( + {2, 2, 3, 5}, + /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60}); + b = tf_a.make( + {2, 2, 3, 1}, + /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + // Destination for output of mul. + out = tf_a.zeros({2, 2, 3, 5}); + expected = tf_a.make( + {2, 2, 3, 5}, + /*data=*/{2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, + 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36, + 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54, + 56, 57, 58, 59, 60, 62, 63, 64, 65, 66, 68, 69, 70, 71, 72}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected); + EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected); + } }; class OpAddScalarOutKernelTest : public OperatorTest { @@ -371,6 +490,23 @@ TEST_F(OpAddOutKernelTest, BroadcastOneElementRank0Tensor) { EXPECT_TENSOR_EQ(out, ret); } +TEST_F(OpAddOutKernelTest, BroadcastNDTest) { + // Test 3D tensors + test_broadcast_3D(); + test_broadcast_3D(); + test_broadcast_3D(); + + // Test 4D tensors + test_broadcast_4D(); + test_broadcast_4D(); + test_broadcast_4D(); + + // Test broadcasting on the last dimension + test_broadcast_last_dim(); + test_broadcast_last_dim(); + test_broadcast_last_dim(); +} + // // Death Tests // diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp index 5e7b0a4efe4..4d148aefd0c 100644 --- a/kernels/test/op_mul_test.cpp +++ b/kernels/test/op_mul_test.cpp @@ -417,16 +417,6 @@ TEST_F(OpMulOutTest, BroadcastA2BTest) { test_broadcast_a2b(); test_broadcast_a2b(); test_broadcast_a2b(); - - // Test 3D tensors - test_broadcast_3D(); - test_broadcast_3D(); - test_broadcast_3D(); - - // Test 4D tensors - test_broadcast_4D(); - test_broadcast_4D(); - test_broadcast_4D(); } // Broadcast tensor a's size to tensor b's size