diff --git a/kernels/optimized/cpu/op_sub.cpp b/kernels/optimized/cpu/op_sub.cpp index ce82e49cc27..51ff4fbd571 100644 --- a/kernels/optimized/cpu/op_sub.cpp +++ b/kernels/optimized/cpu/op_sub.cpp @@ -134,8 +134,8 @@ Tensor& opt_sub_out( } }); }); + return out; } - return out; } auto selected_optimized_path = select_optimized_path(a, b, out); diff --git a/kernels/test/op_sub_test.cpp b/kernels/test/op_sub_test.cpp index 886adaf2e9d..f0285bc85e9 100644 --- a/kernels/test/op_sub_test.cpp +++ b/kernels/test/op_sub_test.cpp @@ -107,6 +107,27 @@ class OpSubOutTest : public OperatorTest { #undef ENUMERATE_TEST_ENTRY } + + template + void test_broadcast_rank1_scalar() { + TensorFactory tf; + + Tensor a = tf.make({2, 1, 3}, {2, 3, 4, 5, 6, 7}); + Tensor b = tf.make({1}, {2}); + + // Destination for the broadcasting sub. Follow the broadcasting rules in + // https://fburl.com/n9wl4d0o + Tensor out = tf.zeros({2, 1, 3}); + + op_sub_out(a, b, 1, out); + + Tensor ret = tf.make({2, 1, 3}, {0, 1, 2, 3, 4, 5}); + EXPECT_TENSOR_EQ(out, ret); + + op_sub_out(b, a, 1, out); + ret = tf.make({2, 1, 3}, {0, -1, -2, -3, -4, -5}); + EXPECT_TENSOR_EQ(out, ret); + } }; class OpSubScalarOutTest : public OperatorTest { @@ -171,19 +192,8 @@ TEST_F(OpSubOutTest, BroadcastSupported2) { } TEST_F(OpSubOutTest, BroadcastScalarSupported1) { - TensorFactory tf; - - Tensor a = tf.make({2, 1, 3}, {2, 3, 4, 5, 6, 7}); - Tensor b = tf.make({1}, {2}); - - // Destination for the broadcasting div. Follow the broadcasting rules in - // https://fburl.com/n9wl4d0o - Tensor out = tf.zeros({2, 1, 3}); - - op_sub_out(a, b, 1, out); - - Tensor ret = tf.make({2, 1, 3}, {0, 1, 2, 3, 4, 5}); - EXPECT_TENSOR_EQ(out, ret); + test_broadcast_rank1_scalar(); + test_broadcast_rank1_scalar(); } TEST_F(OpSubOutTest, BroadcastScalarSupported2) {