Skip to content

Commit a9d7265

Browse files
kirklandsign authored and facebook-github-bot committed
fix reduce_util for 0-D tensor
Summary: For a 0-D tensor, we should still be able to use the 0 and -1 dimensions (treat it as a 1-D tensor holding a single number). Reviewed By: manuelcandales. Differential Revision: D48247370. fbshipit-source-id: 26a6a629055e8ac4a67075a479bd801e201982c7
1 parent 1db9e08 commit a9d7265

File tree

3 files changed

+77
-17
lines changed

3 files changed

+77
-17
lines changed

kernels/portable/cpu/util/reduce_util.cpp

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ using Tensor = exec_aten::Tensor;
2121
// Helper Functions
2222
//
2323

// Map a possibly-negative dimension index `d` onto the range [0, in_dim).
// A 0-D tensor is treated as if it were 1-D with a single element, so for
// in_dim == 0 both 0 and -1 normalize to dimension 0. Callers are expected
// to have validated `d` beforehand (see check_dim_list_is_valid).
inline size_t _normalize_non_neg_d(ssize_t d, ssize_t in_dim) {
  // Special-case 0-D: the only legal indices are 0 and -1, both meaning 0.
  if (in_dim == 0 && (d == 0 || d == -1)) {
    return 0;
  }
  // Standard Python-style wrap-around for negative indices.
  return d < 0 ? d + in_dim : d;
}
2435
void check_dim_list_is_valid(
2536
const Tensor& in,
2637
const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list) {
@@ -29,9 +40,14 @@ void check_dim_list_is_valid(
2940
bool dim_exist[kTensorDimensionLimit];
3041
memset(dim_exist, false, sizeof(dim_exist));
3142
for (const auto& d : reduce_dims) {
32-
ET_CHECK_VALID_DIM(d, in.dim());
33-
const size_t non_neg_d = d < 0 ? d + in.dim() : d;
34-
ET_CHECK(non_neg_d < kTensorDimensionLimit);
43+
if (in.dim() == 0) {
44+
ET_CHECK(d == 0 || d == -1);
45+
} else {
46+
ET_CHECK_VALID_DIM(d, in.dim());
47+
}
48+
const size_t non_neg_d = _normalize_non_neg_d(d, in.dim());
49+
ET_CHECK(non_neg_d < kTensorDimensionLimit);
50+
3551
ET_CHECK_MSG(
3652
dim_exist[non_neg_d] == false,
3753
"dim %zd appears multiple times in the list of dims",
@@ -46,7 +62,7 @@ bool check_dim_in_dim_list(
4662
const size_t max_dim,
4763
const exec_aten::ArrayRef<int64_t>& dim_list) {
4864
for (const auto& d : dim_list) {
49-
const size_t non_neg_dim = d < 0 ? d + max_dim : d;
65+
const size_t non_neg_dim = _normalize_non_neg_d(d, max_dim);
5066
if (dim == non_neg_dim) {
5167
return true;
5268
}
@@ -58,14 +74,17 @@ bool check_dim_in_dim_list(
5874
* Returns the product of the sizes of all reduction dims.
5975
*/
6076
size_t get_reduced_dim_product(const Tensor& in, const optional<int64_t>& dim) {
77+
if (in.dim() == 0) {
78+
return 1;
79+
}
6180
size_t dim_product = 1;
6281
if (!dim.has_value()) {
6382
for (size_t i = 0; i < in.dim(); ++i) {
6483
dim_product *= in.size(i);
6584
}
6685
return dim_product;
6786
}
68-
const size_t d = dim.value() < 0 ? dim.value() + in.dim() : dim.value();
87+
const size_t d = _normalize_non_neg_d(dim.value(), in.dim());
6988
return in.size(d);
7089
}
7190

@@ -75,6 +94,9 @@ size_t get_reduced_dim_product(const Tensor& in, const optional<int64_t>& dim) {
7594
size_t get_reduced_dim_product(
7695
const Tensor& in,
7796
const optional<ArrayRef<int64_t>>& dim_list) {
97+
if (in.dim() == 0) {
98+
return 1;
99+
}
78100
size_t dim_product = 1;
79101
const size_t in_dim = in.dim();
80102
if (!dim_list.has_value() || dim_list.value().size() == 0) {
@@ -84,7 +106,7 @@ size_t get_reduced_dim_product(
84106
return dim_product;
85107
}
86108
for (const auto& d : dim_list.value()) {
87-
const size_t non_neg_d = d < 0 ? d + in_dim : d;
109+
const size_t non_neg_d = _normalize_non_neg_d(d, in_dim);
88110
dim_product *= in.size(non_neg_d);
89111
}
90112
return dim_product;
@@ -98,8 +120,12 @@ size_t get_out_numel(const Tensor& in, const optional<int64_t>& dim) {
98120
size_t out_numel = 1;
99121
if (dim.has_value()) {
100122
const auto dim_val = dim.value();
101-
ET_CHECK_VALID_DIM(dim_val, in.dim());
102-
const size_t non_neg_dim = dim_val < 0 ? dim_val + in.dim() : dim_val;
123+
if (in.dim() == 0) {
124+
ET_CHECK(dim_val == 0 || dim_val == -1);
125+
} else {
126+
ET_CHECK_VALID_DIM(dim_val, in.dim());
127+
}
128+
const size_t non_neg_dim = _normalize_non_neg_d(dim_val, in.dim());
103129
for (size_t d = 0; d < in.dim(); ++d) {
104130
if (d != non_neg_dim) {
105131
out_numel *= in.size(d);
@@ -139,8 +165,12 @@ size_t get_init_index(
139165
return 0;
140166
}
141167
const auto dim_val = dim.value();
142-
ET_CHECK_VALID_DIM(dim_val, in.dim());
143-
const size_t non_neg_dim = dim_val < 0 ? dim_val + in.dim() : dim_val;
168+
if (in.dim() == 0) {
169+
ET_CHECK(dim_val == 0 || dim_val == -1);
170+
} else {
171+
ET_CHECK_VALID_DIM(dim_val, in.dim());
172+
}
173+
const size_t non_neg_dim = _normalize_non_neg_d(dim_val, in.dim());
144174
size_t init_ix = 0;
145175
size_t mutable_out_ix = out_ix;
146176
auto strides = in.strides();
@@ -191,7 +221,7 @@ size_t compute_reduced_out_size(
191221

192222
if (dim.has_value()) {
193223
const auto dim_val = dim.value();
194-
const auto non_neg_dim = dim_val < 0 ? dim_val + in_dim : dim_val;
224+
const size_t non_neg_dim = _normalize_non_neg_d(dim_val, in_dim);
195225
for (size_t i = 0; i < non_neg_dim; ++i) {
196226
sizes_arr[i] = in.size(i);
197227
}

kernels/portable/cpu/util/reduce_util.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,12 @@ void apply_over_dim(
197197
return;
198198
}
199199

200-
if (dim.has_value()) {
200+
if (in.dim() != 0) {
201201
ET_CHECK_VALID_DIM(dim.value(), in.dim());
202+
} else {
203+
ET_CHECK(dim.value() == 0 || dim.value() == -1);
204+
fn(in.numel(), 1, 0);
205+
return;
202206
}
203207

204208
if (in.numel() == 0) {
@@ -304,8 +308,8 @@ void apply_over_dim_list(
304308
const size_t ustart = std::max(normalized_start, size_t(0));
305309
const size_t uend = std::min(normalized_end, iter_length - 1);
306310

307-
// If dim_list is null or empty, iterate over the entire tensor
308-
if (!dim_list.has_value() || dim_list.value().size() == 0) {
311+
// If dim_list is null or empty, or in is 0-D, iterate over the entire tensor
312+
if (!dim_list.has_value() || dim_list.value().size() == 0 || in.dim() == 0) {
309313
apply_on_flat_ix_with_stride_and_base(
310314
fn, /*stride=*/1, /*base=*/0, ustart, uend);
311315
return;
@@ -539,7 +543,10 @@ inline ssize_t compute_reduced_out_dim(
539543
const exec_aten::Tensor& in,
540544
const exec_aten::optional<int64_t>& dim,
541545
bool keepdim) {
542-
return (keepdim ? in.dim() : dim.has_value() ? in.dim() - 1 : 0);
546+
return (
547+
keepdim ? in.dim()
548+
: dim.has_value() && in.dim() != 0 ? in.dim() - 1
549+
: 0);
543550
}
544551

545552
inline ssize_t compute_reduced_out_dim(
@@ -548,7 +555,9 @@ inline ssize_t compute_reduced_out_dim(
548555
bool keepdim) {
549556
return (
550557
keepdim ? in.dim()
551-
: dim_list.has_value() && dim_list.value().size() != 0
558+
: dim_list.has_value() && dim_list.value().size() != 0 &&
559+
in.dim() != 0
560+
552561
? in.dim() - dim_list.value().size()
553562
: 0);
554563
}

kernels/portable/cpu/util/test/reduce_test.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,26 @@ TEST(ReduceUtilTest, ApplyOverDimListNull) {
127127
EXPECT_TENSOR_EQ(in, tf.zeros({2, 4, 5, 3}));
128128
}
129129

130+
TEST(ReduceUtilTest, ApplyOverZeroDimListEmpty) {
131+
TensorFactory<ScalarType::Long> tf;
132+
optional<ArrayRef<int64_t>> null_dim_list;
133+
134+
Tensor in = tf.ones({});
135+
_apply_over_dim_list(in, null_dim_list);
136+
EXPECT_TENSOR_EQ(in, tf.zeros({}));
137+
}
138+
139+
TEST(ReduceUtilTest, ApplyOverZeroDim) {
140+
TensorFactory<ScalarType::Long> tf;
141+
optional<ArrayRef<int64_t>> dim_list;
142+
int64_t dim_array_0[1] = {0};
143+
dim_list = optional<ArrayRef<int64_t>>(ArrayRef<int64_t>{dim_array_0, 1});
144+
145+
Tensor in = tf.ones({});
146+
_apply_over_dim_list(in, dim_list);
147+
EXPECT_TENSOR_EQ(in, tf.zeros({}));
148+
}
149+
130150
TEST(ReduceUtilTest, ApplyOverDimListEmpty) {
131151
TensorFactory<ScalarType::Long> tf;
132152
optional<ArrayRef<int64_t>> empty_dim_list{ArrayRef<int64_t>{}};
@@ -446,7 +466,8 @@ TEST(ReduceUtilTest, ApplyOnZeroDimTensorOverDimListNonEmpty) {
446466
optional<ArrayRef<int64_t>>(ArrayRef<int64_t>{dim_array_0, 1});
447467

448468
Tensor in = tf.ones({});
449-
ET_EXPECT_DEATH(_apply_over_dim_list(in, dim_list), "");
469+
_apply_over_dim_list(in, dim_list);
470+
EXPECT_TENSOR_EQ(in, tf.make({}, {0}));
450471
}
451472

452473
TEST(ReduceUtilTest, ApplyOnEmptyTensorOverDim) {

0 commit comments

Comments
 (0)