1010#include  < executorch/kernels/portable/cpu/util/reduce_util.h> 
1111#include  < executorch/runtime/kernel/kernel_includes.h> 
1212
13+ #include  < optional> 
14+ 
1315namespace  torch  {
1416namespace  executor  {
1517namespace  native  {
@@ -79,6 +81,11 @@ Tensor& any_dims_out(
7981  ScalarType out_type = out.scalar_type ();
8082  constexpr  auto  name = " any.dims_out"  ;
8183
84+   const  bool  in_not_empty = in.numel () > 0 ;
85+   std::optional<MapReduceOverDimListPlan> plan;
86+   if  ((!dim_list.has_value () || !dim_list.value ().empty ()) && in_not_empty) {
87+     plan.emplace (in, dim_list);
88+   }
8289  ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, name, CTYPE_IN, [&] {
8390    ET_SWITCH_TWO_TYPES (Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
8491      CTYPE_OUT* out_data = out.mutable_data_ptr <CTYPE_OUT>();
@@ -91,12 +98,10 @@ Tensor& any_dims_out(
9198      } else  {
9299        for  (const  auto  out_ix : c10::irange (out.numel ())) {
93100          bool  any = false ;
94-           if  (in. numel () >  0 ) {
95-             any = map_reduce_over_dim_list <CTYPE_IN, bool >(
101+           if  (in_not_empty ) {
102+             any = plan-> execute <CTYPE_IN, bool >(
96103                [](CTYPE_IN v) { return  static_cast <bool >(v); },
97104                [](bool  outv, bool  acc) { return  acc || outv; },
98-                 in,
99-                 dim_list,
100105                out_ix);
101106          }
102107          out_data[out_ix] = static_cast <CTYPE_OUT>(any);
0 commit comments