@@ -7,7 +7,7 @@
  */
 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
-#include <iostream>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>

 namespace torch {
 namespace executor {
@@ -58,15 +58,31 @@ Tensor& opt_where_out( |
       const bool* const data_cond = cond.const_data_ptr<bool>();
       CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();
       if (any_is_broadcasted) {
-        for (const auto [out_index, a_index, b_index, cond_index] :
-             BroadcastIndexesRange<3>(out, a, b, cond)) {
-          data_out[out_index] =
-              data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
-        }
+        executorch::extension::parallel_for(
+            0,
+            out_numel,
+            ::executorch::extension::internal::GRAIN_SIZE,
+            [&](const auto begin, const auto end) {
+              auto range = BroadcastIndexesRange<3>(out, a, b, cond);
+              auto begin_it = range.begin();
+              begin_it += begin;
+              for (; (*begin_it)[0] < end; ++begin_it) {
+                const auto [out_index, a_index, b_index, cond_index] =
+                    *begin_it;
+                data_out[out_index] =
+                    data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
+              }
+            });
       } else {
-        for (const auto i : c10::irange(out_numel)) {
-          data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
-        }
+        executorch::extension::parallel_for(
+            0,
+            out_numel,
+            ::executorch::extension::internal::GRAIN_SIZE,
+            [&](const auto begin, const auto end) {
+              for (const auto i : c10::irange(begin, end)) {
+                data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
+              }
+            });
       }
     });
   } else {
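Note on the pattern above: as the two new call sites show, executorch::extension::parallel_for is given a begin index, an end index, a grain size, and a callable that receives sub-ranges [begin, end) of the overall [0, out_numel) range. The non-broadcast branch can index everything by the flat position i, while the broadcast branch must re-derive per-input indexes, so it advances a BroadcastIndexesRange iterator to its chunk's start instead of indexing directly. Below is a minimal, self-contained sketch of that chunking contract only; run_chunked is a hypothetical serial stand-in for the real threaded parallel_for, and the toy buffers are made up for illustration.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Serial stand-in mimicking the calling convention used in the diff:
// the callable is invoked with disjoint [begin, end) chunks covering [0, n).
// (Hypothetical helper, not the real executorch::extension::parallel_for.)
template <typename Fn>
void run_chunked(int64_t begin, int64_t end, int64_t grain, const Fn& fn) {
  for (int64_t i = begin; i < end; i += grain) {
    fn(i, std::min(i + grain, end));
  }
}

int main() {
  // Toy stand-ins for data_cond / data_a / data_b / data_out.
  std::vector<bool> cond{true, false, true, false, true, false};
  std::vector<float> a{1, 2, 3, 4, 5, 6};
  std::vector<float> b{10, 20, 30, 40, 50, 60};
  std::vector<float> out(a.size());

  // Mirrors the non-broadcast branch of the diff: each chunk writes only its
  // own slice of out, so chunks could safely run on different threads.
  run_chunked(
      0, static_cast<int64_t>(out.size()), /*grain=*/2,
      [&](int64_t begin, int64_t end) {
        for (int64_t i = begin; i < end; ++i) {
          out[i] = cond[i] ? a[i] : b[i];
        }
      });

  for (float v : out) {
    std::printf("%g ", v);  // expected: 1 20 3 40 5 60
  }
  std::printf("\n");
  return 0;
}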