Skip to content

Commit a358620

Browse files
[ET][Portable][Build Size] Introduce comparison op pattern. Binary ops: eq, ge, gt, le, lt, ne
- ge: 1.1 M -> 15 K (40x reduction!) - gt: 1.1 M -> 15 K - le: 1.1 M -> 15 K - lt: 1.1 M -> 15 K - ne: 511 K -> 15 K - eq: 504 K -> 15 K Differential Revision: [D63914950](https://our.internmc.facebook.com/intern/diff/D63914950/) [ghstack-poisoned]
1 parent 07dd865 commit a358620

File tree

9 files changed

+199
-537
lines changed

9 files changed

+199
-537
lines changed

kernels/portable/cpu/op_eq.cpp

Lines changed: 5 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -6,111 +6,28 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
11-
#include <executorch/kernels/portable/cpu/util/functional_util.h>
12-
#include <executorch/runtime/kernel/kernel_includes.h>
13-
#include <executorch/runtime/platform/assert.h>
9+
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1410

1511
namespace torch {
1612
namespace executor {
1713
namespace native {
1814

19-
using Tensor = exec_aten::Tensor;
20-
using ScalarType = exec_aten::ScalarType;
21-
2215
Tensor& eq_tensor_out(
2316
KernelRuntimeContext& ctx,
2417
const Tensor& a,
2518
const Tensor& b,
2619
Tensor& out) {
27-
ET_KERNEL_CHECK(
28-
ctx,
29-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
30-
InvalidArgument,
31-
out);
32-
33-
ScalarType a_type = a.scalar_type();
34-
ScalarType b_type = b.scalar_type();
35-
ScalarType out_type = out.scalar_type();
36-
37-
ET_KERNEL_CHECK(
38-
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
39-
40-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "eq.Scalar_out", CTYPE_A, [&]() {
41-
ET_SWITCH_REAL_TYPES_AND(
42-
Bool, b_type, ctx, "eq.Scalar_out", CTYPE_B, [&]() {
43-
using CTYPE_IN =
44-
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
45-
ET_DCHECK(
46-
CppTypeToScalarType<CTYPE_IN>::value ==
47-
promoteTypes(a_type, b_type));
48-
ET_SWITCH_REAL_TYPES_AND(
49-
Bool, out_type, ctx, "eq.Scalar_out", CTYPE_OUT, [&]() {
50-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
51-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
52-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
53-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
54-
bool value = a_casted == b_casted;
55-
return static_cast<CTYPE_OUT>(value);
56-
},
57-
a,
58-
b,
59-
out);
60-
});
61-
});
62-
});
63-
64-
return out;
20+
static constexpr const char op_name[] = "eq.Tensor_out";
21+
return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
6522
}
6623

6724
Tensor& eq_scalar_out(
6825
KernelRuntimeContext& ctx,
6926
const Tensor& a,
7027
const Scalar& b,
7128
Tensor& out) {
72-
(void)ctx;
73-
74-
// Resize for dynamic shape
75-
ET_KERNEL_CHECK_MSG(
76-
ctx,
77-
resize_tensor(out, a.sizes()) == Error::Ok,
78-
InvalidArgument,
79-
out,
80-
"Failed to resize output tensor.");
81-
82-
ScalarType a_type = a.scalar_type();
83-
ScalarType b_type = utils::get_scalar_dtype(b);
84-
ScalarType out_type = out.scalar_type();
85-
86-
ET_KERNEL_CHECK(
87-
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
88-
89-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "eq.Scalar_out", CTYPE_A, [&]() {
90-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "eq.Scalar_out", CTYPE_B, [&]() {
91-
using CTYPE_IN =
92-
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
93-
ET_DCHECK(
94-
CppTypeToScalarType<CTYPE_IN>::value == promoteTypes(a_type, b_type));
95-
ET_SWITCH_REAL_TYPES_AND(
96-
Bool, out_type, ctx, "eq.Scalar_out", CTYPE_OUT, [&]() {
97-
CTYPE_B val_b = 0;
98-
utils::extract_scalar(b, &val_b);
99-
apply_unary_map_fn(
100-
[val_b](const CTYPE_A val_a) {
101-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
102-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
103-
bool value = a_casted == b_casted;
104-
return static_cast<CTYPE_OUT>(value);
105-
},
106-
a.const_data_ptr<CTYPE_A>(),
107-
out.mutable_data_ptr<CTYPE_OUT>(),
108-
out.numel());
109-
});
110-
});
111-
});
112-
113-
return out;
29+
static constexpr const char op_name[] = "eq.Scalar_out";
30+
return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
11431
}
11532

11633
} // namespace native

kernels/portable/cpu/op_ge.cpp

Lines changed: 5 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -6,112 +6,28 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
11-
#include <executorch/kernels/portable/cpu/util/functional_util.h>
12-
#include <executorch/runtime/kernel/kernel_includes.h>
13-
#include <executorch/runtime/platform/assert.h>
9+
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1410

1511
namespace torch {
1612
namespace executor {
1713
namespace native {
1814

19-
using Tensor = exec_aten::Tensor;
20-
using ScalarType = exec_aten::ScalarType;
21-
2215
Tensor& ge_tensor_out(
2316
KernelRuntimeContext& ctx,
2417
const Tensor& a,
2518
const Tensor& b,
2619
Tensor& out) {
27-
// Determine output size and resize for dynamic shapes
28-
ET_KERNEL_CHECK(
29-
ctx,
30-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
31-
InvalidArgument,
32-
out);
33-
34-
ET_KERNEL_CHECK(
35-
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
36-
37-
ScalarType a_type = a.scalar_type();
38-
ScalarType b_type = b.scalar_type();
39-
ScalarType out_type = out.scalar_type();
40-
41-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "ge.Tensor_out", CTYPE_A, [&]() {
42-
ET_SWITCH_REAL_TYPES_AND(
43-
Bool, b_type, ctx, "ge.Tensor_out", CTYPE_B, [&]() {
44-
using CTYPE_IN =
45-
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
46-
ET_DCHECK(
47-
CppTypeToScalarType<CTYPE_IN>::value ==
48-
promoteTypes(a_type, b_type));
49-
ET_SWITCH_REAL_TYPES_AND(
50-
Bool, out_type, ctx, "ge.Tensor_out", CTYPE_OUT, [&]() {
51-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
52-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
53-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
54-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
55-
bool value = a_casted >= b_casted;
56-
return static_cast<CTYPE_OUT>(value);
57-
},
58-
a,
59-
b,
60-
out);
61-
});
62-
});
63-
});
64-
65-
return out;
20+
static constexpr const char op_name[] = "ge.Tensor_out";
21+
return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
6622
}
6723

6824
Tensor& ge_scalar_out(
6925
KernelRuntimeContext& ctx,
7026
const Tensor& a,
7127
const Scalar& b,
7228
Tensor& out) {
73-
(void)ctx;
74-
75-
// Resize for dynamic shape
76-
ET_KERNEL_CHECK_MSG(
77-
ctx,
78-
resize_tensor(out, a.sizes()) == Error::Ok,
79-
InvalidArgument,
80-
out,
81-
"Failed to resize output tensor.");
82-
83-
ET_KERNEL_CHECK(
84-
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
85-
86-
ScalarType a_type = a.scalar_type();
87-
ScalarType b_type = utils::get_scalar_dtype(b);
88-
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
89-
ScalarType out_type = out.scalar_type();
90-
91-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "ge.Scalar_out", CTYPE_A, [&]() {
92-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "ge.Scalar_out", CTYPE_B, [&]() {
93-
ET_SWITCH_REAL_TYPES_AND(
94-
Bool, common_type, ctx, "ge.Scalar_out", CTYPE_IN, [&]() {
95-
ET_SWITCH_REAL_TYPES_AND(
96-
Bool, out_type, ctx, "ge.Scalar_out", CTYPE_OUT, [&]() {
97-
CTYPE_B val_b = 0;
98-
utils::extract_scalar(b, &val_b);
99-
apply_unary_map_fn(
100-
[val_b](const CTYPE_A val_a) {
101-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
102-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
103-
bool value = a_casted >= b_casted;
104-
return static_cast<CTYPE_OUT>(value);
105-
},
106-
a.const_data_ptr<CTYPE_A>(),
107-
out.mutable_data_ptr<CTYPE_OUT>(),
108-
out.numel());
109-
});
110-
});
111-
});
112-
});
113-
114-
return out;
29+
static constexpr const char op_name[] = "ge.Scalar_out";
30+
return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
11531
}
11632

11733
} // namespace native

kernels/portable/cpu/op_gt.cpp

Lines changed: 5 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -6,112 +6,28 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
11-
#include <executorch/kernels/portable/cpu/util/functional_util.h>
12-
#include <executorch/runtime/kernel/kernel_includes.h>
13-
#include <executorch/runtime/platform/assert.h>
9+
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1410

1511
namespace torch {
1612
namespace executor {
1713
namespace native {
1814

19-
using Tensor = exec_aten::Tensor;
20-
using ScalarType = exec_aten::ScalarType;
21-
2215
Tensor& gt_tensor_out(
2316
KernelRuntimeContext& ctx,
2417
const Tensor& a,
2518
const Tensor& b,
2619
Tensor& out) {
27-
// Determine output size and resize for dynamic shapes
28-
ET_KERNEL_CHECK(
29-
ctx,
30-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
31-
InvalidArgument,
32-
out);
33-
34-
ET_KERNEL_CHECK(
35-
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
36-
37-
ScalarType a_type = a.scalar_type();
38-
ScalarType b_type = b.scalar_type();
39-
ScalarType out_type = out.scalar_type();
40-
41-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "gt.Tensor_out", CTYPE_A, [&]() {
42-
ET_SWITCH_REAL_TYPES_AND(
43-
Bool, b_type, ctx, "gt.Tensor_out", CTYPE_B, [&]() {
44-
using CTYPE_IN =
45-
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
46-
ET_DCHECK(
47-
CppTypeToScalarType<CTYPE_IN>::value ==
48-
promoteTypes(a_type, b_type));
49-
ET_SWITCH_REAL_TYPES_AND(
50-
Bool, out_type, ctx, "gt.Tensor_out", CTYPE_OUT, [&]() {
51-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
52-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
53-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
54-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
55-
bool value = a_casted > b_casted;
56-
return static_cast<CTYPE_OUT>(value);
57-
},
58-
a,
59-
b,
60-
out);
61-
});
62-
});
63-
});
64-
65-
return out;
20+
static constexpr const char op_name[] = "gt.Tensor_out";
21+
return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
6622
}
6723

6824
Tensor& gt_scalar_out(
6925
KernelRuntimeContext& ctx,
7026
const Tensor& a,
7127
const Scalar& b,
7228
Tensor& out) {
73-
(void)ctx;
74-
75-
// Resize for dynamic shape
76-
ET_KERNEL_CHECK_MSG(
77-
ctx,
78-
resize_tensor(out, a.sizes()) == Error::Ok,
79-
InvalidArgument,
80-
out,
81-
"Failed to resize output tensor.");
82-
83-
ET_KERNEL_CHECK(
84-
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
85-
86-
ScalarType a_type = a.scalar_type();
87-
ScalarType b_type = utils::get_scalar_dtype(b);
88-
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
89-
ScalarType out_type = out.scalar_type();
90-
91-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "gt.Scalar_out", CTYPE_A, [&]() {
92-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "gt.Scalar_out", CTYPE_B, [&]() {
93-
ET_SWITCH_REAL_TYPES_AND(
94-
Bool, common_type, ctx, "gt.Scalar_out", CTYPE_IN, [&]() {
95-
ET_SWITCH_REAL_TYPES_AND(
96-
Bool, out_type, ctx, "gt.Scalar_out", CTYPE_OUT, [&]() {
97-
CTYPE_B val_b = 0;
98-
utils::extract_scalar(b, &val_b);
99-
apply_unary_map_fn(
100-
[val_b](const CTYPE_A val_a) {
101-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
102-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
103-
bool value = a_casted > b_casted;
104-
return static_cast<CTYPE_OUT>(value);
105-
},
106-
a.const_data_ptr<CTYPE_A>(),
107-
out.mutable_data_ptr<CTYPE_OUT>(),
108-
out.numel());
109-
});
110-
});
111-
});
112-
});
113-
114-
return out;
29+
static constexpr const char op_name[] = "gt.Scalar_out";
30+
return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
11531
}
11632

11733
} // namespace native

0 commit comments

Comments
 (0)