@@ -27,24 +27,25 @@ Tensor& opt_le_tensor_out(
2727 const Tensor& a,
2828 const Tensor& b,
2929 Tensor& out) {
30- (void )ctx;
31-
3230 ScalarType a_type = a.scalar_type ();
3331 ScalarType out_type = out.scalar_type ();
3432
33+ ET_KERNEL_CHECK (
34+ ctx, tensors_have_same_dim_order (a, b, out), InvalidArgument, out);
35+
36+ ET_KERNEL_CHECK (
37+ ctx,
38+ resize_to_broadcast_target_size (a, b, out) == Error::Ok,
39+ InvalidArgument,
40+ out);
41+
42+ // @lint-ignore CLANGTIDY facebook-hte-CArray
43+ static constexpr const char op_name[] = " le.Tensor_out" ;
44+
3545 // Check for optimized broadcast paths
3646 auto selected_optimized_path = select_optimized_path (a, b, out);
3747 if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d ) {
38- // Resize for dynamic shape
39- auto error = resize_to_broadcast_target_size (a, b, out);
40- ET_KERNEL_CHECK_MSG (
41- ctx,
42- error == Error::Ok,
43- InvalidArgument,
44- out,
45- " Failed to resize output tensor." );
46-
47- ET_SWITCH_REALB_TYPES (a_type, ctx, " le.Tensor_out" , CTYPE, [&]() {
48+ ET_SWITCH_REALB_TYPES (a_type, ctx, op_name, CTYPE, [&]() {
4849 using Vec = at::vec::Vectorized<CTYPE>;
4950 at::vec::map2<CTYPE>(
5051 [](Vec x, Vec y) { return x.le (y); },
@@ -55,16 +56,13 @@ Tensor& opt_le_tensor_out(
5556 });
5657 } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone ) {
5758 // Handle optimized broadcast cases
58- ET_SWITCH_REALB_TYPES (out_type, ctx, " le.Tensor_out " , CTYPE, [&]() {
59+ ET_SWITCH_REALB_TYPES (out_type, ctx, op_name , CTYPE, [&]() {
5960 auto le_lambda = [](auto x, auto y) { return x.le (y); };
6061 torch::executor::handle_broadcast_elementwise<CTYPE>(
6162 ctx, le_lambda, a, b, out, selected_optimized_path);
6263 });
6364 } else {
64- // @lint-ignore CLANGTIDY facebook-hte-CArray
65- static constexpr const char op_name[] = " le.Tensor_out" ;
66- return internal::comparison_tensor_out<std::less_equal, op_name>(
67- ctx, a, b, out);
65+ internal::comparison_tensor_out<std::less_equal, op_name>(ctx, a, b, out);
6866 }
6967
7068 return out;
@@ -75,66 +73,37 @@ Tensor& opt_le_scalar_out(
7573 const Tensor& a,
7674 const Scalar& b,
7775 Tensor& out) {
78- (void )ctx;
79-
80- // Resize for dynamic shape
81- auto error = resize_tensor (out, a.sizes ());
82- ET_KERNEL_CHECK_MSG (
83- ctx,
84- error == Error::Ok,
85- InvalidArgument,
86- out,
87- " Failed to resize output tensor." );
88-
8976 ScalarType a_type = a.scalar_type ();
9077 ScalarType b_type = utils::get_scalar_dtype (b);
91- ScalarType common_type = promoteTypes (a_type, b_type );
78+ ScalarType common_type = utils::promote_type_with_scalar (a_type, b );
9279 ScalarType out_type = out.scalar_type ();
9380
94- if (a_type == common_type && a_type == out_type) {
95- ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " le.Scalar_out" , CTYPE, [&]() {
96- ET_SWITCH_REAL_TYPES_AND (
97- Bool, b_type, ctx, " le.Scalar_out" , CTYPE_B, [&]() {
98- CTYPE_B b_val = 0 ;
99- ET_EXTRACT_SCALAR (b, b_val);
100- CTYPE b_casted = static_cast <CTYPE>(b_val);
101- using Vec = at::vec::Vectorized<CTYPE>;
102- at::vec::map<CTYPE>(
103- [b_casted](Vec x) { return x.le (Vec (b_casted)); },
104- out.mutable_data_ptr <CTYPE>(),
105- a.const_data_ptr <CTYPE>(),
106- a.numel ());
107- });
81+ ET_KERNEL_CHECK (
82+ ctx, tensors_have_same_dim_order (a, out), InvalidArgument, out);
83+
84+ ET_KERNEL_CHECK (
85+ ctx, resize_tensor (out, a.sizes ()) == Error::Ok, InvalidArgument, out);
86+
87+ // @lint-ignore CLANGTIDY facebook-hte-CArray
88+ static constexpr const char op_name[] = " le.Scalar_out" ;
89+
90+ if (a_type == common_type && a_type == out_type &&
91+ a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
92+ ET_SWITCH_REALB_TYPES (a_type, ctx, op_name, CTYPE, [&]() {
93+ ET_SWITCH_REALB_TYPES (b_type, ctx, op_name, CTYPE_B, [&]() {
94+ CTYPE_B b_val = 0 ;
95+ ET_EXTRACT_SCALAR (b, b_val);
96+ CTYPE b_casted = static_cast <CTYPE>(b_val);
97+ using Vec = at::vec::Vectorized<CTYPE>;
98+ at::vec::map<CTYPE>(
99+ [b_casted](Vec x) { return x.le (Vec (b_casted)); },
100+ out.mutable_data_ptr <CTYPE>(),
101+ a.const_data_ptr <CTYPE>(),
102+ a.numel ());
103+ });
108104 });
109105 } else {
110- ET_SWITCH_REAL_TYPES_AND (
111- Bool, a_type, ctx, " le.Scalar_out" , CTYPE_A, [&]() {
112- ET_SWITCH_REAL_TYPES_AND (
113- Bool, b_type, ctx, " le.Scalar_out" , CTYPE_B, [&]() {
114- ET_SWITCH_REAL_TYPES_AND (
115- Bool, common_type, ctx, " le.Scalar_out" , CTYPE_IN, [&]() {
116- ET_SWITCH_REAL_TYPES_AND (
117- Bool,
118- out_type,
119- ctx,
120- " le.Scalar_out" ,
121- CTYPE_OUT,
122- [&]() {
123- CTYPE_B b_val = 0 ;
124- ET_EXTRACT_SCALAR (b, b_val);
125- CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
126- const size_t n = a.numel ();
127- const CTYPE_A* a_data = a.const_data_ptr <CTYPE_A>();
128- CTYPE_OUT* out_data =
129- out.mutable_data_ptr <CTYPE_OUT>();
130- for (auto i = 0 ; i < n; ++i) {
131- out_data[i] = static_cast <CTYPE_OUT>(
132- static_cast <CTYPE_IN>(a_data[i]) <= b_casted);
133- }
134- });
135- });
136- });
137- });
106+ internal::comparison_scalar_out<std::less_equal, op_name>(ctx, a, b, out);
138107 }
139108
140109 return out;
0 commit comments