| 
7 | 7 |  */  | 
8 | 8 | #include <executorch/kernels/portable/cpu/util/elementwise_util.h>  | 
9 | 9 | #include <executorch/runtime/kernel/kernel_includes.h>  | 
10 |  | -#include <iostream>  | 
 | 10 | +#include <executorch/runtime/kernel/thread_parallel_interface.h>  | 
11 | 11 | 
 
  | 
12 | 12 | namespace torch {  | 
13 | 13 | namespace executor {  | 
@@ -58,15 +58,31 @@ Tensor& opt_where_out(  | 
58 | 58 |       const bool* const data_cond = cond.const_data_ptr<bool>();  | 
59 | 59 |       CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();  | 
60 | 60 |       if (any_is_broadcasted) {  | 
61 |  | -        for (const auto [out_index, a_index, b_index, cond_index] :  | 
62 |  | -             BroadcastIndexesRange<3>(out, a, b, cond)) {  | 
63 |  | -          data_out[out_index] =  | 
64 |  | -              data_cond[cond_index] ? data_a[a_index] : data_b[b_index];  | 
65 |  | -        }  | 
 | 61 | +        executorch::extension::parallel_for(  | 
 | 62 | +            0,  | 
 | 63 | +            out_numel,  | 
 | 64 | +            ::executorch::extension::internal::GRAIN_SIZE,  | 
 | 65 | +            [&](const auto begin, const auto end) {  | 
 | 66 | +              auto range = BroadcastIndexesRange<3>(out, a, b, cond);  | 
 | 67 | +              auto begin_it = range.begin();  | 
 | 68 | +              begin_it += begin;  | 
 | 69 | +              for (; (*begin_it)[0] < end; ++begin_it) {  | 
 | 70 | +                const auto [out_index, a_index, b_index, cond_index] =  | 
 | 71 | +                    *begin_it;  | 
 | 72 | +                data_out[out_index] =  | 
 | 73 | +                    data_cond[cond_index] ? data_a[a_index] : data_b[b_index];  | 
 | 74 | +              }  | 
 | 75 | +            });  | 
66 | 76 |       } else {  | 
67 |  | -        for (const auto i : c10::irange(out_numel)) {  | 
68 |  | -          data_out[i] = data_cond[i] ? data_a[i] : data_b[i];  | 
69 |  | -        }  | 
 | 77 | +        executorch::extension::parallel_for(  | 
 | 78 | +            0,  | 
 | 79 | +            out_numel,  | 
 | 80 | +            ::executorch::extension::internal::GRAIN_SIZE,  | 
 | 81 | +            [&](const auto begin, const auto end) {  | 
 | 82 | +              for (const auto i : c10::irange(begin, end)) {  | 
 | 83 | +                data_out[i] = data_cond[i] ? data_a[i] : data_b[i];  | 
 | 84 | +              }  | 
 | 85 | +            });  | 
70 | 86 |       }  | 
71 | 87 |     });  | 
72 | 88 |   } else {  | 
 | 
0 commit comments