@@ -21,6 +21,17 @@ using Tensor = exec_aten::Tensor;
21
21
// Helper Functions
22
22
//
23
23
24
+ // Normalize the dimension by adding in_dim if d < 0; for 0-D, clamp to 0
25
+ inline size_t _normalize_non_neg_d (ssize_t d, ssize_t in_dim) {
26
+ if (in_dim == 0 && (d == 0 || d == -1 )) {
27
+ return 0 ;
28
+ }
29
+ if (d < 0 ) {
30
+ return d + in_dim;
31
+ }
32
+ return d;
33
+ }
34
+
24
35
void check_dim_list_is_valid (
25
36
const Tensor& in,
26
37
const exec_aten::optional<exec_aten::ArrayRef<int64_t >>& dim_list) {
@@ -29,9 +40,14 @@ void check_dim_list_is_valid(
29
40
bool dim_exist[kTensorDimensionLimit ];
30
41
memset (dim_exist, false , sizeof (dim_exist));
31
42
for (const auto & d : reduce_dims) {
32
- ET_CHECK_VALID_DIM (d, in.dim ());
33
- const size_t non_neg_d = d < 0 ? d + in.dim () : d;
34
- ET_CHECK (non_neg_d < kTensorDimensionLimit );
43
+ if (in.dim () == 0 ) {
44
+ ET_CHECK (d == 0 || d == -1 );
45
+ } else {
46
+ ET_CHECK_VALID_DIM (d, in.dim ());
47
+ }
48
+ const size_t non_neg_d = _normalize_non_neg_d (d, in.dim ());
49
+ ET_CHECK (non_neg_d < kTensorDimensionLimit && non_neg_d >= 0 );
50
+
35
51
ET_CHECK_MSG (
36
52
dim_exist[non_neg_d] == false ,
37
53
" dim %zd appears multiple times in the list of dims" ,
@@ -46,7 +62,7 @@ bool check_dim_in_dim_list(
46
62
const size_t max_dim,
47
63
const exec_aten::ArrayRef<int64_t >& dim_list) {
48
64
for (const auto & d : dim_list) {
49
- const size_t non_neg_dim = d < 0 ? d + max_dim : d ;
65
+ const size_t non_neg_dim = _normalize_non_neg_d (d, max_dim) ;
50
66
if (dim == non_neg_dim) {
51
67
return true ;
52
68
}
@@ -58,14 +74,17 @@ bool check_dim_in_dim_list(
58
74
* Returns the product of the sizes of all reduction dims.
59
75
*/
60
76
size_t get_reduced_dim_product (const Tensor& in, const optional<int64_t >& dim) {
77
+ if (in.dim () == 0 ) {
78
+ return 1 ;
79
+ }
61
80
size_t dim_product = 1 ;
62
81
if (!dim.has_value ()) {
63
82
for (size_t i = 0 ; i < in.dim (); ++i) {
64
83
dim_product *= in.size (i);
65
84
}
66
85
return dim_product;
67
86
}
68
- const size_t d = dim. value () < 0 ? dim.value () + in.dim () : dim. value ( );
87
+ const size_t d = _normalize_non_neg_d ( dim.value (), in.dim ());
69
88
return in.size (d);
70
89
}
71
90
@@ -75,6 +94,9 @@ size_t get_reduced_dim_product(const Tensor& in, const optional<int64_t>& dim) {
75
94
size_t get_reduced_dim_product (
76
95
const Tensor& in,
77
96
const optional<ArrayRef<int64_t >>& dim_list) {
97
+ if (in.dim () == 0 ) {
98
+ return 1 ;
99
+ }
78
100
size_t dim_product = 1 ;
79
101
const size_t in_dim = in.dim ();
80
102
if (!dim_list.has_value () || dim_list.value ().size () == 0 ) {
@@ -84,7 +106,7 @@ size_t get_reduced_dim_product(
84
106
return dim_product;
85
107
}
86
108
for (const auto & d : dim_list.value ()) {
87
- const size_t non_neg_d = d < 0 ? d + in_dim : d ;
109
+ const size_t non_neg_d = _normalize_non_neg_d (d, in_dim) ;
88
110
dim_product *= in.size (non_neg_d);
89
111
}
90
112
return dim_product;
@@ -98,8 +120,12 @@ size_t get_out_numel(const Tensor& in, const optional<int64_t>& dim) {
98
120
size_t out_numel = 1 ;
99
121
if (dim.has_value ()) {
100
122
const auto dim_val = dim.value ();
101
- ET_CHECK_VALID_DIM (dim_val, in.dim ());
102
- const size_t non_neg_dim = dim_val < 0 ? dim_val + in.dim () : dim_val;
123
+ if (in.dim () == 0 ) {
124
+ ET_CHECK (dim_val == 0 || dim_val == -1 );
125
+ } else {
126
+ ET_CHECK_VALID_DIM (dim_val, in.dim ());
127
+ }
128
+ const size_t non_neg_dim = _normalize_non_neg_d (dim_val, in.dim ());
103
129
for (size_t d = 0 ; d < in.dim (); ++d) {
104
130
if (d != non_neg_dim) {
105
131
out_numel *= in.size (d);
@@ -139,8 +165,12 @@ size_t get_init_index(
139
165
return 0 ;
140
166
}
141
167
const auto dim_val = dim.value ();
142
- ET_CHECK_VALID_DIM (dim_val, in.dim ());
143
- const size_t non_neg_dim = dim_val < 0 ? dim_val + in.dim () : dim_val;
168
+ if (in.dim () == 0 ) {
169
+ ET_CHECK (dim_val == 0 || dim_val == -1 );
170
+ } else {
171
+ ET_CHECK_VALID_DIM (dim_val, in.dim ());
172
+ }
173
+ const size_t non_neg_dim = _normalize_non_neg_d (dim_val, in.dim ());
144
174
size_t init_ix = 0 ;
145
175
size_t mutable_out_ix = out_ix;
146
176
auto strides = in.strides ();
@@ -191,7 +221,7 @@ size_t compute_reduced_out_size(
191
221
192
222
if (dim.has_value ()) {
193
223
const auto dim_val = dim.value ();
194
- const auto non_neg_dim = dim_val < 0 ? dim_val + in_dim : dim_val ;
224
+ const size_t non_neg_dim = _normalize_non_neg_d ( dim_val, in_dim) ;
195
225
for (size_t i = 0 ; i < non_neg_dim; ++i) {
196
226
sizes_arr[i] = in.size (i);
197
227
}
0 commit comments