99#include < executorch/kernels/optimized/vec/functional.h>
1010#include < executorch/kernels/optimized/vec/vec.h>
1111#include < executorch/kernels/portable/cpu/scalar_utils.h>
12+ #include < executorch/kernels/portable/cpu/util/broadcast_util.h>
1213#include < executorch/runtime/kernel/kernel_includes.h>
1314#include < executorch/runtime/platform/assert.h>
1415
@@ -26,6 +27,58 @@ Tensor& opt_le_tensor_out(
2627 Tensor& out) {
2728 (void )ctx;
2829
30+ ScalarType a_type = a.scalar_type ();
31+ ScalarType b_type = b.scalar_type ();
32+ ScalarType out_type = out.scalar_type ();
33+
34+ if (a.numel () == 1 || b.numel () == 1 ) {
35+ const Tensor* tensor;
36+ const Tensor* scalar;
37+ ScalarType tensor_type;
38+ ScalarType scalar_type;
39+ if (a.numel () == 1 ) {
40+ tensor = &b;
41+ tensor_type = b_type;
42+ scalar = &a;
43+ scalar_type = a_type;
44+ } else {
45+ tensor = &a;
46+ tensor_type = a_type;
47+ scalar = &b;
48+ scalar_type = b_type;
49+ }
50+ ET_KERNEL_CHECK (
51+ ctx,
52+ resize_to_broadcast_target_size (a, b, out) == Error::Ok,
53+ InvalidArgument,
54+ out);
55+
56+ constexpr auto name = " le.Tensor_out" ;
57+
58+ ET_SWITCH_REALB_TYPES (tensor_type, ctx, name, CTYPE, [&]() {
59+ ET_SWITCH_REALB_TYPES (scalar_type, ctx, name, CTYPE_SCALAR, [&]() {
60+ CTYPE_SCALAR scalar_val = *scalar->const_data_ptr <CTYPE_SCALAR>();
61+ CTYPE scalar_casted = static_cast <CTYPE>(scalar_val);
62+
63+ using Vec = executorch::vec::Vectorized<CTYPE>;
64+ if (a.numel () == 1 ) {
65+ executorch::vec::map<CTYPE>(
66+ [scalar_casted](Vec x) { return Vec (scalar_casted).le (x); },
67+ out.mutable_data_ptr <CTYPE>(),
68+ tensor->const_data_ptr <CTYPE>(),
69+ out.numel ());
70+ } else {
71+ executorch::vec::map<CTYPE>(
72+ [scalar_casted](Vec x) { return x.le (Vec (scalar_casted)); },
73+ out.mutable_data_ptr <CTYPE>(),
74+ tensor->const_data_ptr <CTYPE>(),
75+ out.numel ());
76+ }
77+ });
78+ });
79+ return out;
80+ }
81+
2982 ET_KERNEL_CHECK (ctx, tensors_have_same_shape (a, b), InvalidArgument, out);
3083
3184 // Resize for dynamic shape
@@ -37,10 +90,6 @@ Tensor& opt_le_tensor_out(
3790 out,
3891 " Failed to resize output tensor." );
3992
40- ScalarType a_type = a.scalar_type ();
41- ScalarType b_type = b.scalar_type ();
42- ScalarType out_type = out.scalar_type ();
43-
4493 if (a_type == b_type && a_type == out_type) {
4594 ET_SWITCH_REAL_TYPES_AND (
4695 Bool, out_type, ctx, " le.Tensor_out" , CTYPE, [&]() {
0 commit comments