Skip to content

Commit 8f724e0

Browse files
Clean up optimized op_le
Differential Revision: D81199585 Pull Request resolved: #13765
1 parent 537d30b commit 8f724e0

File tree

2 files changed

+46
-77
lines changed

2 files changed

+46
-77
lines changed

kernels/optimized/cpu/op_le.cpp

Lines changed: 40 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,25 @@ Tensor& opt_le_tensor_out(
2727
const Tensor& a,
2828
const Tensor& b,
2929
Tensor& out) {
30-
(void)ctx;
31-
3230
ScalarType a_type = a.scalar_type();
3331
ScalarType out_type = out.scalar_type();
3432

33+
ET_KERNEL_CHECK(
34+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
35+
36+
ET_KERNEL_CHECK(
37+
ctx,
38+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
39+
InvalidArgument,
40+
out);
41+
42+
// @lint-ignore CLANGTIDY facebook-hte-CArray
43+
static constexpr const char op_name[] = "le.Tensor_out";
44+
3545
// Check for optimized broadcast paths
3646
auto selected_optimized_path = select_optimized_path(a, b, out);
3747
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
38-
// Resize for dynamic shape
39-
auto error = resize_to_broadcast_target_size(a, b, out);
40-
ET_KERNEL_CHECK_MSG(
41-
ctx,
42-
error == Error::Ok,
43-
InvalidArgument,
44-
out,
45-
"Failed to resize output tensor.");
46-
47-
ET_SWITCH_REALB_TYPES(a_type, ctx, "le.Tensor_out", CTYPE, [&]() {
48+
ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
4849
using Vec = at::vec::Vectorized<CTYPE>;
4950
at::vec::map2<CTYPE>(
5051
[](Vec x, Vec y) { return x.le(y); },
@@ -55,16 +56,13 @@ Tensor& opt_le_tensor_out(
5556
});
5657
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
5758
// Handle optimized broadcast cases
58-
ET_SWITCH_REALB_TYPES(out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
59+
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
5960
auto le_lambda = [](auto x, auto y) { return x.le(y); };
6061
torch::executor::handle_broadcast_elementwise<CTYPE>(
6162
ctx, le_lambda, a, b, out, selected_optimized_path);
6263
});
6364
} else {
64-
// @lint-ignore CLANGTIDY facebook-hte-CArray
65-
static constexpr const char op_name[] = "le.Tensor_out";
66-
return internal::comparison_tensor_out<std::less_equal, op_name>(
67-
ctx, a, b, out);
65+
internal::comparison_tensor_out<std::less_equal, op_name>(ctx, a, b, out);
6866
}
6967

7068
return out;
@@ -75,66 +73,37 @@ Tensor& opt_le_scalar_out(
7573
const Tensor& a,
7674
const Scalar& b,
7775
Tensor& out) {
78-
(void)ctx;
79-
80-
// Resize for dynamic shape
81-
auto error = resize_tensor(out, a.sizes());
82-
ET_KERNEL_CHECK_MSG(
83-
ctx,
84-
error == Error::Ok,
85-
InvalidArgument,
86-
out,
87-
"Failed to resize output tensor.");
88-
8976
ScalarType a_type = a.scalar_type();
9077
ScalarType b_type = utils::get_scalar_dtype(b);
91-
ScalarType common_type = promoteTypes(a_type, b_type);
78+
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
9279
ScalarType out_type = out.scalar_type();
9380

94-
if (a_type == common_type && a_type == out_type) {
95-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "le.Scalar_out", CTYPE, [&]() {
96-
ET_SWITCH_REAL_TYPES_AND(
97-
Bool, b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
98-
CTYPE_B b_val = 0;
99-
ET_EXTRACT_SCALAR(b, b_val);
100-
CTYPE b_casted = static_cast<CTYPE>(b_val);
101-
using Vec = at::vec::Vectorized<CTYPE>;
102-
at::vec::map<CTYPE>(
103-
[b_casted](Vec x) { return x.le(Vec(b_casted)); },
104-
out.mutable_data_ptr<CTYPE>(),
105-
a.const_data_ptr<CTYPE>(),
106-
a.numel());
107-
});
81+
ET_KERNEL_CHECK(
82+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
83+
84+
ET_KERNEL_CHECK(
85+
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
86+
87+
// @lint-ignore CLANGTIDY facebook-hte-CArray
88+
static constexpr const char op_name[] = "le.Scalar_out";
89+
90+
if (a_type == common_type && a_type == out_type &&
91+
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
92+
ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
93+
ET_SWITCH_REALB_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
94+
CTYPE_B b_val = 0;
95+
ET_EXTRACT_SCALAR(b, b_val);
96+
CTYPE b_casted = static_cast<CTYPE>(b_val);
97+
using Vec = at::vec::Vectorized<CTYPE>;
98+
at::vec::map<CTYPE>(
99+
[b_casted](Vec x) { return x.le(Vec(b_casted)); },
100+
out.mutable_data_ptr<CTYPE>(),
101+
a.const_data_ptr<CTYPE>(),
102+
a.numel());
103+
});
108104
});
109105
} else {
110-
ET_SWITCH_REAL_TYPES_AND(
111-
Bool, a_type, ctx, "le.Scalar_out", CTYPE_A, [&]() {
112-
ET_SWITCH_REAL_TYPES_AND(
113-
Bool, b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
114-
ET_SWITCH_REAL_TYPES_AND(
115-
Bool, common_type, ctx, "le.Scalar_out", CTYPE_IN, [&]() {
116-
ET_SWITCH_REAL_TYPES_AND(
117-
Bool,
118-
out_type,
119-
ctx,
120-
"le.Scalar_out",
121-
CTYPE_OUT,
122-
[&]() {
123-
CTYPE_B b_val = 0;
124-
ET_EXTRACT_SCALAR(b, b_val);
125-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
126-
const size_t n = a.numel();
127-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
128-
CTYPE_OUT* out_data =
129-
out.mutable_data_ptr<CTYPE_OUT>();
130-
for (auto i = 0; i < n; ++i) {
131-
out_data[i] = static_cast<CTYPE_OUT>(
132-
static_cast<CTYPE_IN>(a_data[i]) <= b_casted);
133-
}
134-
});
135-
});
136-
});
137-
});
106+
internal::comparison_scalar_out<std::less_equal, op_name>(ctx, a, b, out);
138107
}
139108

140109
return out;

kernels/test/op_le_test.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ TEST_F(OpLeScalarOutTest, AllRealInputBoolOutputSupport) {
6767
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
6868
test_le_scalar_out<ScalarType::dtype_in, ScalarType::dtype_out>();
6969

70-
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
71-
ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
70+
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
71+
ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
7272
test_le_scalar_out<ScalarType::dtype_in, ScalarType::Bool>();
7373

74-
ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES)
74+
ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES)
7575

7676
#undef TEST_FORALL_OUT_TYPES
7777
#undef TEST_ENTRY
@@ -124,11 +124,11 @@ TEST_F(OpLeTensorOutTest, AllDtypesSupported) {
124124
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
125125
test_dtype<ScalarType::dtype_in, ScalarType::dtype_out>();
126126

127-
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
128-
ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
127+
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in) \
128+
ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
129129
test_dtype<ScalarType::dtype_in, ScalarType::Bool>();
130130

131-
ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES);
131+
ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES);
132132

133133
#undef TEST_FORALL_OUT_TYPES
134134
#undef TEST_ENTRY

0 commit comments

Comments
 (0)