Skip to content

Commit 35c8c8d

Browse files
committed
Update on "[Exutorch] Add broadcast support for le op"
For refactored hf repro requires this to support mask generation Differential Revision: [D76456398](https://our.internmc.facebook.com/intern/diff/D76456398/) [ghstack-poisoned]
1 parent 0e17f38 commit 35c8c8d

File tree

1 file changed

+89
-1
lines changed

1 file changed

+89
-1
lines changed

kernels/test/op_le_test.cpp

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,95 @@ TEST_F(OpLeTensorOutTest, Broadcast1DTo2DShapeTest) {
991991
EXPECT_TENSOR_EQ(out, tf_bool.make({1, 6}, expected_data));
992992
}
993993

994-
TEST_F(OpLeTensorOutTest, Broadcast2dBy1dReverseTest) {
994+
TEST_F(OpLeTensorOutTest, Broadcast2DBy1DShapeTest) {
995+
TensorFactory<ScalarType::Int> tf;
996+
TensorFactory<ScalarType::Bool> tf_bool;
997+
998+
// Test case: (10,) and (6, 1) -> (6, 10)
999+
Tensor a = tf.make({10}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
1000+
Tensor b = tf.make({6, 1}, {2, 4, 6, 8, 10, 12});
1001+
1002+
Tensor out = tf_bool.zeros({6, 10});
1003+
1004+
op_le_tensor_out(a, b, out);
1005+
1006+
// Expected: a[j] <= b[i,0] for all i,j
1007+
// Each row i should be [a[0]<=b[i,0], a[1]<=b[i,0], ..., a[9]<=b[i,0]]
1008+
using ctype =
1009+
executorch::runtime::testing::internal::ScalarTypeToCppTypeWrapper<
1010+
ScalarType::Bool>::ctype;
1011+
std::vector<ctype> expected_data = {
1012+
// Row 0 (b=2): [1,2,3,4,5,6,7,8,9,10] <= 2
1013+
true,
1014+
true,
1015+
false,
1016+
false,
1017+
false,
1018+
false,
1019+
false,
1020+
false,
1021+
false,
1022+
false,
1023+
// Row 1 (b=4): [1,2,3,4,5,6,7,8,9,10] <= 4
1024+
true,
1025+
true,
1026+
true,
1027+
true,
1028+
false,
1029+
false,
1030+
false,
1031+
false,
1032+
false,
1033+
false,
1034+
// Row 2 (b=6): [1,2,3,4,5,6,7,8,9,10] <= 6
1035+
true,
1036+
true,
1037+
true,
1038+
true,
1039+
true,
1040+
true,
1041+
false,
1042+
false,
1043+
false,
1044+
false,
1045+
// Row 3 (b=8): [1,2,3,4,5,6,7,8,9,10] <= 8
1046+
true,
1047+
true,
1048+
true,
1049+
true,
1050+
true,
1051+
true,
1052+
true,
1053+
true,
1054+
false,
1055+
false,
1056+
// Row 4 (b=10): [1,2,3,4,5,6,7,8,9,10] <= 10
1057+
true,
1058+
true,
1059+
true,
1060+
true,
1061+
true,
1062+
true,
1063+
true,
1064+
true,
1065+
true,
1066+
true,
1067+
// Row 5 (b=12): [1,2,3,4,5,6,7,8,9,10] <= 12
1068+
true,
1069+
true,
1070+
true,
1071+
true,
1072+
true,
1073+
true,
1074+
true,
1075+
true,
1076+
true,
1077+
true};
1078+
1079+
EXPECT_TENSOR_EQ(out, tf_bool.make({6, 10}, expected_data));
1080+
}
1081+
1082+
TEST_F(OpLeTensorOutTest, Broadcast22dBy1dReverseTest) {
9951083
TensorFactory<ScalarType::Int> tf;
9961084
TensorFactory<ScalarType::Bool> tf_bool;
9971085

0 commit comments

Comments
 (0)