Skip to content

Commit 78f0e67

Browse files
Enable optimized op_le broadcast against 1 element tensor
Differential Revision: D71652593
Pull Request resolved: #9507
1 parent fb24360 commit 78f0e67

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

kernels/optimized/cpu/op_le.cpp

Lines changed: 53 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
#include <executorch/kernels/optimized/vec/functional.h>
1010
#include <executorch/kernels/optimized/vec/vec.h>
1111
#include <executorch/kernels/portable/cpu/scalar_utils.h>
12+
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1213
#include <executorch/runtime/kernel/kernel_includes.h>
1314
#include <executorch/runtime/platform/assert.h>
1415

@@ -26,6 +27,58 @@ Tensor& opt_le_tensor_out(
2627
Tensor& out) {
2728
(void)ctx;
2829

30+
ScalarType a_type = a.scalar_type();
31+
ScalarType b_type = b.scalar_type();
32+
ScalarType out_type = out.scalar_type();
33+
34+
if (a.numel() == 1 || b.numel() == 1) {
35+
const Tensor* tensor;
36+
const Tensor* scalar;
37+
ScalarType tensor_type;
38+
ScalarType scalar_type;
39+
if (a.numel() == 1) {
40+
tensor = &b;
41+
tensor_type = b_type;
42+
scalar = &a;
43+
scalar_type = a_type;
44+
} else {
45+
tensor = &a;
46+
tensor_type = a_type;
47+
scalar = &b;
48+
scalar_type = b_type;
49+
}
50+
ET_KERNEL_CHECK(
51+
ctx,
52+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
53+
InvalidArgument,
54+
out);
55+
56+
constexpr auto name = "le.Tensor_out";
57+
58+
ET_SWITCH_REALB_TYPES(tensor_type, ctx, name, CTYPE, [&]() {
59+
ET_SWITCH_REALB_TYPES(scalar_type, ctx, name, CTYPE_SCALAR, [&]() {
60+
CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();
61+
CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);
62+
63+
using Vec = executorch::vec::Vectorized<CTYPE>;
64+
if (a.numel() == 1) {
65+
executorch::vec::map<CTYPE>(
66+
[scalar_casted](Vec x) { return Vec(scalar_casted).le(x); },
67+
out.mutable_data_ptr<CTYPE>(),
68+
tensor->const_data_ptr<CTYPE>(),
69+
out.numel());
70+
} else {
71+
executorch::vec::map<CTYPE>(
72+
[scalar_casted](Vec x) { return x.le(Vec(scalar_casted)); },
73+
out.mutable_data_ptr<CTYPE>(),
74+
tensor->const_data_ptr<CTYPE>(),
75+
out.numel());
76+
}
77+
});
78+
});
79+
return out;
80+
}
81+
2982
ET_KERNEL_CHECK(ctx, tensors_have_same_shape(a, b), InvalidArgument, out);
3083

3184
// Resize for dynamic shape
@@ -37,10 +90,6 @@ Tensor& opt_le_tensor_out(
3790
out,
3891
"Failed to resize output tensor.");
3992

40-
ScalarType a_type = a.scalar_type();
41-
ScalarType b_type = b.scalar_type();
42-
ScalarType out_type = out.scalar_type();
43-
4493
if (a_type == b_type && a_type == out_type) {
4594
ET_SWITCH_REAL_TYPES_AND(
4695
Bool, out_type, ctx, "le.Tensor_out", CTYPE, [&]() {

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ _OPTIMIZED_ATEN_OPS = (
4646
name = "op_le",
4747
deps = [
4848
"//executorch/kernels/portable/cpu:scalar_utils",
49+
"//executorch/kernels/portable/cpu/util:broadcast_util",
4950
],
5051
),
5152
op_target(

kernels/test/op_le_test.cpp

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -174,3 +174,15 @@ TEST_F(OpLeTensorOutTest, DynamicOutShapeTest) {
174174
op_le_tensor_out(a, b, out);
175175
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {false, true, true, false}));
176176
}
177+
178+
TEST_F(OpLeTensorOutTest, BroadcastTest) {
179+
TensorFactory<ScalarType::Int> tf;
180+
181+
Tensor a = tf.make(/*sizes=*/{4}, /*data=*/{2, 3, 2, 4});
182+
Tensor b = tf.make({1, 1}, {3});
183+
184+
Tensor out = tf.zeros({1, 4});
185+
186+
op_le_tensor_out(a, b, out);
187+
EXPECT_TENSOR_EQ(out, tf.make({1, 4}, {true, true, true, false}));
188+
}

0 commit comments

Comments
 (0)