Skip to content

Commit 0e17f38

Browse files
committed
Update on "[Exutorch] Add broadcast support for le op"
For refactored hf repro requires this to support mask generation Differential Revision: [D76456398](https://our.internmc.facebook.com/intern/diff/D76456398/) [ghstack-poisoned]
1 parent eeb4375 commit 0e17f38

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

kernels/optimized/cpu/op_le.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ Tensor& opt_le_tensor_out(
8585
auto selected_optimized_path = select_optimized_path(a, b, out);
8686
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
8787
// Resize for dynamic shape
88-
auto error = resize_tensor(out, a.sizes());
88+
auto error = resize_to_broadcast_target_size(a, b, out);
8989
ET_KERNEL_CHECK_MSG(
9090
ctx,
9191
error == Error::Ok,

kernels/test/op_le_test.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,35 @@ TEST_F(OpLeTensorOutTest, Broadcast2dBy1dTest) {
962962
EXPECT_TENSOR_EQ(out, tf_bool.make({3, 4}, expected_data));
963963
}
964964

965+
TEST_F(OpLeTensorOutTest, Broadcast1DTo2DShapeTest) {
966+
TensorFactory<ScalarType::Int> tf;
967+
TensorFactory<ScalarType::Bool> tf_bool;
968+
969+
// Test case: (6,) and (1, 6) -> (1, 6)
970+
Tensor a = tf.make({6}, {1, 3, 5, 7, 9, 11});
971+
Tensor b = tf.make({1, 6}, {2, 4, 6, 8, 10, 12});
972+
973+
Tensor out = tf_bool.zeros({1, 6});
974+
975+
op_le_tensor_out(a, b, out);
976+
977+
// Expected: a[i] <= b[0,i] for all i
978+
// [1, 3, 5, 7, 9, 11] <= [2, 4, 6, 8, 10, 12]
979+
using ctype =
980+
executorch::runtime::testing::internal::ScalarTypeToCppTypeWrapper<
981+
ScalarType::Bool>::ctype;
982+
std::vector<ctype> expected_data = {
983+
true, // 1 <= 2
984+
true, // 3 <= 4
985+
true, // 5 <= 6
986+
true, // 7 <= 8
987+
true, // 9 <= 10
988+
true // 11 <= 12
989+
};
990+
991+
EXPECT_TENSOR_EQ(out, tf_bool.make({1, 6}, expected_data));
992+
}
993+
965994
TEST_F(OpLeTensorOutTest, Broadcast2dBy1dReverseTest) {
966995
TensorFactory<ScalarType::Int> tf;
967996
TensorFactory<ScalarType::Bool> tf_bool;

0 commit comments

Comments
 (0)