@@ -27,24 +27,25 @@ Tensor& opt_le_tensor_out(
27
27
const Tensor& a,
28
28
const Tensor& b,
29
29
Tensor& out) {
30
- (void )ctx;
31
-
32
30
ScalarType a_type = a.scalar_type ();
33
31
ScalarType out_type = out.scalar_type ();
34
32
33
+ ET_KERNEL_CHECK (
34
+ ctx, tensors_have_same_dim_order (a, b, out), InvalidArgument, out);
35
+
36
+ ET_KERNEL_CHECK (
37
+ ctx,
38
+ resize_to_broadcast_target_size (a, b, out) == Error::Ok,
39
+ InvalidArgument,
40
+ out);
41
+
42
+ // @lint-ignore CLANGTIDY facebook-hte-CArray
43
+ static constexpr const char op_name[] = " le.Tensor_out" ;
44
+
35
45
// Check for optimized broadcast paths
36
46
auto selected_optimized_path = select_optimized_path (a, b, out);
37
47
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d ) {
38
- // Resize for dynamic shape
39
- auto error = resize_to_broadcast_target_size (a, b, out);
40
- ET_KERNEL_CHECK_MSG (
41
- ctx,
42
- error == Error::Ok,
43
- InvalidArgument,
44
- out,
45
- " Failed to resize output tensor." );
46
-
47
- ET_SWITCH_REALB_TYPES (a_type, ctx, " le.Tensor_out" , CTYPE, [&]() {
48
+ ET_SWITCH_REALB_TYPES (a_type, ctx, op_name, CTYPE, [&]() {
48
49
using Vec = at::vec::Vectorized<CTYPE>;
49
50
at::vec::map2<CTYPE>(
50
51
[](Vec x, Vec y) { return x.le (y); },
@@ -55,16 +56,13 @@ Tensor& opt_le_tensor_out(
55
56
});
56
57
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone ) {
57
58
// Handle optimized broadcast cases
58
- ET_SWITCH_REALB_TYPES (out_type, ctx, " le.Tensor_out " , CTYPE, [&]() {
59
+ ET_SWITCH_REALB_TYPES (out_type, ctx, op_name , CTYPE, [&]() {
59
60
auto le_lambda = [](auto x, auto y) { return x.le (y); };
60
61
torch::executor::handle_broadcast_elementwise<CTYPE>(
61
62
ctx, le_lambda, a, b, out, selected_optimized_path);
62
63
});
63
64
} else {
64
- // @lint-ignore CLANGTIDY facebook-hte-CArray
65
- static constexpr const char op_name[] = " le.Tensor_out" ;
66
- return internal::comparison_tensor_out<std::less_equal, op_name>(
67
- ctx, a, b, out);
65
+ internal::comparison_tensor_out<std::less_equal, op_name>(ctx, a, b, out);
68
66
}
69
67
70
68
return out;
@@ -75,66 +73,37 @@ Tensor& opt_le_scalar_out(
75
73
const Tensor& a,
76
74
const Scalar& b,
77
75
Tensor& out) {
78
- (void )ctx;
79
-
80
- // Resize for dynamic shape
81
- auto error = resize_tensor (out, a.sizes ());
82
- ET_KERNEL_CHECK_MSG (
83
- ctx,
84
- error == Error::Ok,
85
- InvalidArgument,
86
- out,
87
- " Failed to resize output tensor." );
88
-
89
76
ScalarType a_type = a.scalar_type ();
90
77
ScalarType b_type = utils::get_scalar_dtype (b);
91
- ScalarType common_type = promoteTypes (a_type, b_type );
78
+ ScalarType common_type = utils::promote_type_with_scalar (a_type, b );
92
79
ScalarType out_type = out.scalar_type ();
93
80
94
- if (a_type == common_type && a_type == out_type) {
95
- ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " le.Scalar_out" , CTYPE, [&]() {
96
- ET_SWITCH_REAL_TYPES_AND (
97
- Bool, b_type, ctx, " le.Scalar_out" , CTYPE_B, [&]() {
98
- CTYPE_B b_val = 0 ;
99
- ET_EXTRACT_SCALAR (b, b_val);
100
- CTYPE b_casted = static_cast <CTYPE>(b_val);
101
- using Vec = at::vec::Vectorized<CTYPE>;
102
- at::vec::map<CTYPE>(
103
- [b_casted](Vec x) { return x.le (Vec (b_casted)); },
104
- out.mutable_data_ptr <CTYPE>(),
105
- a.const_data_ptr <CTYPE>(),
106
- a.numel ());
107
- });
81
+ ET_KERNEL_CHECK (
82
+ ctx, tensors_have_same_dim_order (a, out), InvalidArgument, out);
83
+
84
+ ET_KERNEL_CHECK (
85
+ ctx, resize_tensor (out, a.sizes ()) == Error::Ok, InvalidArgument, out);
86
+
87
+ // @lint-ignore CLANGTIDY facebook-hte-CArray
88
+ static constexpr const char op_name[] = " le.Scalar_out" ;
89
+
90
+ if (a_type == common_type && a_type == out_type &&
91
+ a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
92
+ ET_SWITCH_REALB_TYPES (a_type, ctx, op_name, CTYPE, [&]() {
93
+ ET_SWITCH_REALB_TYPES (b_type, ctx, op_name, CTYPE_B, [&]() {
94
+ CTYPE_B b_val = 0 ;
95
+ ET_EXTRACT_SCALAR (b, b_val);
96
+ CTYPE b_casted = static_cast <CTYPE>(b_val);
97
+ using Vec = at::vec::Vectorized<CTYPE>;
98
+ at::vec::map<CTYPE>(
99
+ [b_casted](Vec x) { return x.le (Vec (b_casted)); },
100
+ out.mutable_data_ptr <CTYPE>(),
101
+ a.const_data_ptr <CTYPE>(),
102
+ a.numel ());
103
+ });
108
104
});
109
105
} else {
110
- ET_SWITCH_REAL_TYPES_AND (
111
- Bool, a_type, ctx, " le.Scalar_out" , CTYPE_A, [&]() {
112
- ET_SWITCH_REAL_TYPES_AND (
113
- Bool, b_type, ctx, " le.Scalar_out" , CTYPE_B, [&]() {
114
- ET_SWITCH_REAL_TYPES_AND (
115
- Bool, common_type, ctx, " le.Scalar_out" , CTYPE_IN, [&]() {
116
- ET_SWITCH_REAL_TYPES_AND (
117
- Bool,
118
- out_type,
119
- ctx,
120
- " le.Scalar_out" ,
121
- CTYPE_OUT,
122
- [&]() {
123
- CTYPE_B b_val = 0 ;
124
- ET_EXTRACT_SCALAR (b, b_val);
125
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
126
- const size_t n = a.numel ();
127
- const CTYPE_A* a_data = a.const_data_ptr <CTYPE_A>();
128
- CTYPE_OUT* out_data =
129
- out.mutable_data_ptr <CTYPE_OUT>();
130
- for (auto i = 0 ; i < n; ++i) {
131
- out_data[i] = static_cast <CTYPE_OUT>(
132
- static_cast <CTYPE_IN>(a_data[i]) <= b_casted);
133
- }
134
- });
135
- });
136
- });
137
- });
106
+ internal::comparison_scalar_out<std::less_equal, op_name>(ctx, a, b, out);
138
107
}
139
108
140
109
return out;
0 commit comments