5
5
#include < algorithm>
6
6
#include < cstdint>
7
7
8
+ #include < c10/util/irange.h>
8
9
#include < executorch/runtime/core/exec_aten/exec_aten.h>
9
10
#include < executorch/runtime/core/exec_aten/util/dim_order_util.h>
10
11
#include < executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -78,7 +79,7 @@ inline bool check_strides(
78
79
// a.strides == (1, 1, 2). We want to sort create a mapping to make the
79
80
// sorted_stride as (2, 1, 1) while sorted_size == (3, 2, 1)
80
81
std::vector<std::int32_t > sorted_idx (sizes.size ());
81
- for (size_t i = 0 ; i < sizes.size (); i++ ) {
82
+ for (const auto i : c10::irange ( sizes.size ()) ) {
82
83
sorted_idx[i] = i;
83
84
}
84
85
std::sort (
@@ -98,7 +99,7 @@ inline bool check_strides(
98
99
// Use the mapping to rearrange the sizes and strides
99
100
std::vector<std::int32_t > sorted_sizes (sizes.size ());
100
101
std::vector<std::int32_t > sorted_strides (sizes.size ());
101
- for (size_t i = 0 ; i < sizes.size (); i++ ) {
102
+ for (const auto i : c10::irange ( sizes.size ()) ) {
102
103
sorted_sizes[i] = sizes[sorted_idx[i]] == 0 ? 1 : sizes[sorted_idx[i]];
103
104
sorted_strides[i] = strides[sorted_idx[i]];
104
105
}
@@ -132,7 +133,7 @@ inline bool check_dim_order(
132
133
}
133
134
size_t gauss_sum = 0 ;
134
135
std::vector<int > count (dim_order.size (), 0 );
135
- for (int i = 0 ; i < dim_order.size (); i++ ) {
136
+ for (const auto i : c10::irange ( dim_order.size ()) ) {
136
137
if (dim_order[i] >= sizes.size ()) {
137
138
return false ;
138
139
}
@@ -378,7 +379,7 @@ class TensorFactory {
378
379
std::vector<executorch::aten::StridesType> contiguous_strides =
379
380
internal::strides_from_dim_order (sizes, contiguous_dim_order);
380
381
381
- for (int32_t i = 0 ; i < input.dim (); i++ ) {
382
+ for (const auto i : c10::irange ( input.dim ()) ) {
382
383
ET_CHECK_MSG (
383
384
input.strides ()[i] == contiguous_strides[i],
384
385
" Input tensor is not contiguous" );
@@ -394,10 +395,10 @@ class TensorFactory {
394
395
std::vector<ctype> channels_last_data (
395
396
N * C * H * W); // Create a new blob with the same total size to contain
396
397
// channels_last data
397
- for (int32_t n = 0 ; n < N; ++n ) {
398
- for (int32_t c = 0 ; c < C; ++c ) {
399
- for (int32_t h = 0 ; h < H; ++h ) {
400
- for (int32_t w = 0 ; w < W; ++w ) {
398
+ for (const auto n : c10::irange (N) ) {
399
+ for (const auto c : c10::irange (C) ) {
400
+ for (const auto h : c10::irange (H) ) {
401
+ for (const auto w : c10::irange (W) ) {
401
402
// Calculate the index in the original blob
402
403
int32_t old_index = ((n * C + c) * H + h) * W + w;
403
404
// Calculate the index in the new blob
@@ -614,7 +615,7 @@ inline void validate_strides(
614
615
}
615
616
}
616
617
// No two dimensions can have same stride value
617
- for (int32_t i = 0 ; i < strides.size (); ++i ) {
618
+ for (const auto i : c10::irange ( strides.size ()) ) {
618
619
for (int32_t j = i + 1 ; j < strides.size (); ++j) {
619
620
if ((sizes[i] == 0 ) || (sizes[j] == 0 ) ||
620
621
((sizes[i] == 1 ) || (sizes[j] == 1 ))) {
@@ -830,7 +831,7 @@ class TensorFactory {
830
831
// given strides is empty.
831
832
if (!sizes.empty () && dim_order.empty ()) {
832
833
default_dim_order.resize (sizes.size (), 1 );
833
- for (size_t i = 0 ; i < sizes.size (); ++i ) {
834
+ for (const auto i : c10::irange ( sizes.size ()) ) {
834
835
default_dim_order[i] = i;
835
836
}
836
837
}
@@ -904,10 +905,10 @@ class TensorFactory {
904
905
std::vector<ctype> channels_last_data (
905
906
N * C * H * W); // Create a new blob with the same total size to contain
906
907
// channels_last data
907
- for (int32_t n = 0 ; n < N; ++n ) {
908
- for (int32_t c = 0 ; c < C; ++c ) {
909
- for (int32_t h = 0 ; h < H; ++h ) {
910
- for (int32_t w = 0 ; w < W; ++w ) {
908
+ for (const auto n : c10::irange (N) ) {
909
+ for (const auto c : c10::irange (C) ) {
910
+ for (const auto h : c10::irange (H) ) {
911
+ for (const auto w : c10::irange (W) ) {
911
912
// Calculate the index in the original blob
912
913
int32_t old_index = ((n * C + c) * H + h) * W + w;
913
914
// Calculate the index in the new blob
0 commit comments