 #pragma once
 
 #include <c10/util/irange.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
@@ -121,26 +122,24 @@ inline void apply_bitensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index] :
+         BroadcastIndexesRange<2>(out, a, b)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
     }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
 
@@ -211,31 +210,27 @@ inline void apply_tritensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-      if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index, c_index] :
+         BroadcastIndexesRange<3>(out, a, b, c)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]),
+          load_c_to_common(&data_c[c_index * c_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+      size_t c_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]),
+          load_c_to_common(&data_c[c_linear_index * c_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
     }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]),
-        load_c_to_common(&data_c[c_linear_index * c_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
 
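For context on what this change removes from the hot loop: the old code delinearized each output index into per-dimension coordinates, then re-linearized those coordinates against every broadcasted input, on every iteration. BroadcastIndexesRange<N> packages that mapping behind an iterator that yields one (out_index, input_indexes...) tuple per output element, and the broadcast check is hoisted so the common non-broadcast case stays a flat loop. The standalone sketch below illustrates the underlying index mapping only; broadcast_index is a hypothetical helper, it assumes row-major contiguous layout and equal ranks, and it is not the ExecuTorch implementation.

#include <array>
#include <cstddef>
#include <cstdio>

// Map a linear index into `out` to a linear index into an input whose sizes
// may be 1 along some dimensions; a broadcast (size-1) dimension contributes
// nothing to the input offset, i.e. its coordinate is clamped to 0.
template <std::size_t kDims>
std::size_t broadcast_index(
    std::size_t out_linear,
    const std::array<std::size_t, kDims>& out_sizes,
    const std::array<std::size_t, kDims>& in_sizes) {
  std::size_t in_linear = 0;
  std::size_t in_stride = 1;
  // Walk dimensions innermost-first, peeling one coordinate per iteration.
  for (std::size_t d = kDims; d-- > 0;) {
    const std::size_t coord = out_linear % out_sizes[d];
    out_linear /= out_sizes[d];
    if (in_sizes[d] != 1) { // non-broadcast dim contributes to the offset
      in_linear += coord * in_stride;
    }
    in_stride *= in_sizes[d];
  }
  return in_linear;
}

int main() {
  // out and a are 2x3; b is 1x3 and broadcasts along dim 0.
  const std::array<std::size_t, 2> out_sizes{2, 3}, a_sizes{2, 3},
      b_sizes{1, 3};
  for (std::size_t i = 0; i < 2 * 3; ++i) {
    std::printf(
        "out=%zu a=%zu b=%zu\n",
        i,
        broadcast_index<2>(i, out_sizes, a_sizes),
        broadcast_index<2>(i, out_sizes, b_sizes));
  }
  return 0;
}

Running this prints a == i for the non-broadcast input, while b's index cycles 0, 1, 2 across both rows of the output: the per-element tuple stream that the BroadcastIndexesRange-based loops above consume directly.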