|
13 | 13 | #include <executorch/kernels/portable/cpu/util/broadcast_util.h> |
14 | 14 | #include <executorch/kernels/portable/cpu/util/dtype_util.h> |
15 | 15 | #include <executorch/runtime/kernel/kernel_runtime_context.h> |
| 16 | +#include <executorch/runtime/kernel/thread_parallel_interface.h> |
16 | 17 |
|
17 | 18 | #include <array> |
18 | 19 | #include <utility> |
@@ -94,17 +95,28 @@ inline void apply_elementwise_fn( |
94 | 95 | char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr()); |
95 | 96 | const auto out_element_size = out.element_size(); |
96 | 97 |
|
97 | | - for (const auto& indexes : |
98 | | - BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) { |
99 | | - std::array<CTYPE_COMMON, kNumInputs> loaded_inputs; |
100 | | - for (const auto idx : c10::irange(kNumInputs)) { |
101 | | - const auto& input_info = inputs_info[idx]; |
102 | | - loaded_inputs[idx] = input_info.load_to_common( |
103 | | - &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]); |
104 | | - } |
105 | | - auto result = std::apply(compute_fun, loaded_inputs); |
106 | | - store_common_to_out(result, &data_out[indexes[0] * out_element_size]); |
107 | | - } |
| 98 | + ::executorch::extension::parallel_for( |
| 99 | + 0, |
| 100 | + out.numel(), |
| 101 | + ::executorch::extension::internal::GRAIN_SIZE, |
| 102 | + [&](const auto begin, const auto end) { |
| 103 | + const auto range = |
| 104 | + BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...); |
| 105 | + auto begin_it = range.begin(); |
| 106 | + begin_it += begin; |
| 107 | + for (; (*begin_it)[0] < end; ++begin_it) { |
| 108 | + const auto& indexes = *begin_it; |
| 109 | + std::array<CTYPE_COMMON, kNumInputs> loaded_inputs; |
| 110 | + for (const auto idx : c10::irange(kNumInputs)) { |
| 111 | + const auto& input_info = inputs_info[idx]; |
| 112 | + loaded_inputs[idx] = input_info.load_to_common( |
| 113 | + &input_info |
| 114 | + .data_ptr[indexes[idx + 1] * input_info.element_size]); |
| 115 | + } |
| 116 | + auto result = std::apply(compute_fun, loaded_inputs); |
| 117 | + store_common_to_out(result, &data_out[indexes[0] * out_element_size]); |
| 118 | + } |
| 119 | + }); |
108 | 120 | } |
109 | 121 | } // namespace internal |
110 | 122 |
|
|
0 commit comments