Skip to content

Commit 542ea56

Browse files
authored
Fix bug in le broadcast for single element broadcast (#11922)
Summary: Really comparison ops cannot be handled in the same way as other binary ops for broadcasting because output tensor dtype is different than input tensor dtype. As a result we just have to fall back to portable. Even the current vectorized impl for le, ge etc. assumes that the output type of compare is same as input type. That might actually be a bug. A potential way to handle this maybe via vectorized compare natively supporting binary output vector Test Plan: Tests added which fail before and passes after Reviewers: Subscribers: Tasks: Tags: ### Summary [PLEASE REMOVE] See [CONTRIBUTING.md's Pull Requests](https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#pull-requests) for ExecuTorch PR guidelines. [PLEASE REMOVE] If this PR closes an issue, please add a `Fixes #<issue-id>` line. [PLEASE REMOVE] If this PR introduces a fix or feature that should be the upcoming release notes, please add a "Release notes: <area>" label. For a list of available release notes labels, check out [CONTRIBUTING.md's Pull Requests](https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#pull-requests). ### Test plan [PLEASE REMOVE] How did you test this PR? Please write down any manual commands you used and note down tests that you have written if applicable.
1 parent f1b5947 commit 542ea56

File tree

2 files changed

+67
-49
lines changed

2 files changed

+67
-49
lines changed

kernels/optimized/cpu/op_le.cpp

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -30,57 +30,8 @@ Tensor& opt_le_tensor_out(
3030
(void)ctx;
3131

3232
ScalarType a_type = a.scalar_type();
33-
ScalarType b_type = b.scalar_type();
3433
ScalarType out_type = out.scalar_type();
3534

36-
if (a.numel() == 1 || b.numel() == 1) {
37-
const Tensor* tensor;
38-
const Tensor* scalar;
39-
ScalarType tensor_type;
40-
ScalarType scalar_type;
41-
if (a.numel() == 1) {
42-
tensor = &b;
43-
tensor_type = b_type;
44-
scalar = &a;
45-
scalar_type = a_type;
46-
} else {
47-
tensor = &a;
48-
tensor_type = a_type;
49-
scalar = &b;
50-
scalar_type = b_type;
51-
}
52-
ET_KERNEL_CHECK(
53-
ctx,
54-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
55-
InvalidArgument,
56-
out);
57-
58-
constexpr auto name = "le.Tensor_out";
59-
60-
ET_SWITCH_REALB_TYPES(tensor_type, ctx, name, CTYPE, [&]() {
61-
ET_SWITCH_REALB_TYPES(scalar_type, ctx, name, CTYPE_SCALAR, [&]() {
62-
CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();
63-
CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);
64-
65-
using Vec = at::vec::Vectorized<CTYPE>;
66-
if (a.numel() == 1) {
67-
at::vec::map<CTYPE>(
68-
[scalar_casted](Vec x) { return Vec(scalar_casted).le(x); },
69-
out.mutable_data_ptr<CTYPE>(),
70-
tensor->const_data_ptr<CTYPE>(),
71-
out.numel());
72-
} else {
73-
at::vec::map<CTYPE>(
74-
[scalar_casted](Vec x) { return x.le(Vec(scalar_casted)); },
75-
out.mutable_data_ptr<CTYPE>(),
76-
tensor->const_data_ptr<CTYPE>(),
77-
out.numel());
78-
}
79-
});
80-
});
81-
return out;
82-
}
83-
8435
// Check for optimized broadcast paths
8536
auto selected_optimized_path = select_optimized_path(a, b, out);
8637
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {

kernels/test/op_le_test.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,3 +1112,70 @@ TEST_F(OpLeTensorOutTest, Broadcast22dBy1dReverseTest) {
11121112

11131113
EXPECT_TENSOR_EQ(out, tf_bool.make({3, 4}, expected_data));
11141114
}
1115+
1116+
// Regression test for le.Tensor_out broadcasting against a single-element
// tensor: the output dtype (Bool) differs from the input dtype (Int), so the
// single-element fast path must not assume output type == input type.
TEST_F(OpLeTensorOutTest, MonotonicIncreasingVsScalarBroadcastTest) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tf_bool;

  // lhs is the 1D monotonically increasing tensor [0, 1, 2, ..., 63].
  std::vector<int32_t> lhs_data;
  for (int i = 0; i < 64; ++i) {
    lhs_data.push_back(i);
  }
  Tensor lhs = tf.make({64}, lhs_data);

  // Element type backing a Bool TensorFactory, used to build expected data.
  using ctype =
      executorch::runtime::testing::internal::ScalarTypeToCppTypeWrapper<
          ScalarType::Bool>::ctype;

  // Compare against a single-element 2D rhs {1, 1} for several threshold
  // values; expected[i] == (i <= threshold) in every case.
  for (int32_t threshold : {2, 4, 10, 32}) {
    Tensor rhs = tf.make({1, 1}, {threshold});
    Tensor out = tf_bool.zeros({1, 64});

    op_le_tensor_out(lhs, rhs, out);

    std::vector<ctype> expected_data;
    for (int i = 0; i < 64; ++i) {
      expected_data.push_back(i <= threshold);
    }
    EXPECT_TENSOR_EQ(out, tf_bool.make({1, 64}, expected_data));
  }
}

0 commit comments

Comments
 (0)