Skip to content

Commit 3c2adcb

Browse files
kirklandsign authored and
facebook-github-bot committed
Special handling for 0-D tensor
Summary: `ET_CHECK_VALID_DIM` checks whether a tensor has a valid dimension. However, for 0-D tensors, some behaviors are different. Accessing size/stride for dim=0 is invalid, because the valid dimension range is [-0, 0), which is empty. ``` torch.tensor(2).size(dim=0) ``` However, some ops allow accessing dim=0 or -1 for a 0-D tensor ``` torch.mean(torch.tensor(2, dtype=float), dim=-1) ``` Therefore, in the reduce_util helper functions and in ops, we need to handle that case specially. We also want to revisit this check for the 0-D tensor case in other ops. Reviewed By: manuelcandales Differential Revision: D48319644 fbshipit-source-id: 0a394ab4caccbcbb7868ccd371276e9dd047f054
1 parent 7108fb3 commit 3c2adcb

File tree

6 files changed

+46
-18
lines changed

6 files changed

+46
-18
lines changed

kernels/portable/cpu/op_max.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ void check_preconditions(
3737
max_indices.scalar_type() == ScalarType::Long,
3838
"dtype of the max_indices Tensor expected to be be long.");
3939
// Ensure dim is valid
40-
ET_CHECK_VALID_DIM(dim, in.dim());
41-
ET_CHECK_NON_ZERO_DIM_SIZE(dim, in);
40+
if (in.dim() == 0) {
41+
ET_CHECK(dim == 0 || dim == -1);
42+
} else {
43+
ET_CHECK_VALID_DIM(dim, in.dim());
44+
ET_CHECK_NON_ZERO_DIM_SIZE(dim, in);
45+
}
4246
const auto expected_dim = compute_reduced_out_dim(in, dim, keepdim);
4347
ET_CHECK_MSG(
4448
max.dim() == expected_dim && max_indices.dim() == expected_dim,

kernels/portable/cpu/op_min.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ void check_preconditions(
3737
min_indices.scalar_type() == ScalarType::Long,
3838
"dtype of the min_indices Tensor expected to be be long.");
3939
// Ensure dim is valid
40-
ET_CHECK_VALID_DIM(dim, in.dim());
41-
ET_CHECK_NON_ZERO_DIM_SIZE(dim, in);
40+
if (in.dim() == 0) {
41+
ET_CHECK(dim == 0 || dim == -1);
42+
} else {
43+
ET_CHECK_VALID_DIM(dim, in.dim());
44+
ET_CHECK_NON_ZERO_DIM_SIZE(dim, in);
45+
}
4246
const auto expected_dim = compute_reduced_out_dim(in, dim, keepdim);
4347
ET_CHECK_MSG(
4448
min.dim() == expected_dim && min_indices.dim() == expected_dim,

kernels/portable/cpu/op_transpose_copy.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@ void check_preconditions(
3232
Tensor& out) {
3333
auto a_dim = a.dim();
3434
ET_CHECK_MSG(
35-
a_dim > 0 && a_dim == out.dim(), "invalid rank of tensor a: %zd", a_dim);
35+
a_dim >= 0 && a_dim == out.dim(), "invalid rank of tensor a: %zd", a_dim);
36+
if (a_dim == 0) {
37+
ET_CHECK(dim0 == 0 || dim0 == -1);
38+
ET_CHECK(dim1 == 0 || dim1 == -1);
39+
return;
40+
}
3641
ET_CHECK_MSG(
3742
dim0 >= 0 && dim0 < a_dim,
3843
"dim0: %" PRId64 " out of bounds [0,%zd)",

kernels/portable/cpu/util/reduce_util.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,19 +222,19 @@ size_t compute_reduced_out_size(
222222
if (dim.has_value()) {
223223
const auto dim_val = dim.value();
224224
const size_t non_neg_dim = _normalize_non_neg_d(dim_val, in_dim);
225-
for (size_t i = 0; i < non_neg_dim; ++i) {
225+
for (ssize_t i = 0; i < non_neg_dim; ++i) {
226226
sizes_arr[i] = in.size(i);
227227
}
228228
if (keepdim) {
229229
sizes_arr[non_neg_dim] = 1;
230-
for (size_t i = non_neg_dim + 1; i < in_dim; ++i) {
230+
for (ssize_t i = non_neg_dim + 1; i < in_dim; ++i) {
231231
sizes_arr[i] = in.size(i);
232232
}
233233
} else {
234-
for (size_t i = non_neg_dim; i < in_dim - 1; ++i) {
234+
for (ssize_t i = non_neg_dim; i < in_dim - 1; ++i) {
235235
sizes_arr[i] = in.size(i + 1);
236236
}
237-
out_dim = in_dim - 1;
237+
out_dim = in_dim == 0 ? 0 : in_dim - 1;
238238
}
239239
} else {
240240
if (keepdim) {

kernels/portable/cpu/util/reduce_util.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ void apply_over_dim(
200200
if (in.dim() != 0) {
201201
ET_CHECK_VALID_DIM(dim.value(), in.dim());
202202
} else {
203+
// Special handling for 0-D tensor; 0 or -1 is valid for PyTorch code
204+
// `torch.mean(torch.tensor(2, dtype=float), dim=-1)`
203205
ET_CHECK(dim.value() == 0 || dim.value() == -1);
204206
fn(in.numel(), 1, 0);
205207
return;
@@ -243,7 +245,11 @@ void apply_over_dim(
243245
const int64_t start = 0,
244246
const int64_t end = -1) {
245247
if (dim.has_value()) {
246-
ET_CHECK_VALID_DIM(dim.value(), in.dim());
248+
if (in.dim() != 0) {
249+
ET_CHECK_VALID_DIM(dim.value(), in.dim());
250+
} else {
251+
ET_CHECK(dim.value() == 0 || dim.value() == -1);
252+
}
247253
}
248254
ET_CHECK_MSG(
249255
out_ix < get_out_numel(in, dim),
@@ -255,10 +261,10 @@ void apply_over_dim(
255261
}
256262

257263
const size_t iter_length = get_reduced_dim_product(in, dim);
258-
ET_CHECK_VALID_IX(start, iter_length);
259-
ET_CHECK_VALID_IX(end, iter_length);
260-
const size_t ustart = ET_NORMALIZE_IX(start, iter_length);
261-
const size_t uend = ET_NORMALIZE_IX(end, iter_length);
264+
const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
265+
const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
266+
const size_t ustart = std::max(normalized_start, size_t(0));
267+
const size_t uend = std::min(normalized_end, iter_length - 1);
262268

263269
// If dim is null, iterate over the entire tensor
264270
if (!dim.has_value()) {
@@ -273,8 +279,12 @@ void apply_over_dim(
273279
// Compute non-negative dimension value from dim value
274280
const size_t d = ET_NORMALIZE_IX(dim.value(), in.dim());
275281

276-
apply_on_flat_and_dim_ix_with_stride_and_base(
277-
fn, in.strides()[d], base, ustart, uend);
282+
if (in.dim() == 0) {
283+
fn(base, ustart);
284+
} else {
285+
apply_on_flat_and_dim_ix_with_stride_and_base(
286+
fn, in.strides()[d], base, ustart, uend);
287+
}
278288
}
279289

280290
/**
@@ -370,7 +380,11 @@ std::tuple<CTYPE_OUT, long> map_reduce_over_dim(
370380
const exec_aten::optional<int64_t>& dim,
371381
const size_t out_ix) {
372382
if (dim.has_value()) {
373-
ET_CHECK_VALID_DIM(dim.value(), in.dim());
383+
if (in.dim() != 0) {
384+
ET_CHECK_VALID_DIM(dim.value(), in.dim());
385+
} else {
386+
ET_CHECK(dim.value() == 0 || dim.value() == -1);
387+
}
374388
}
375389

376390
ET_CHECK_MSG(

kernels/portable/cpu/util/test/reduce_test.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,8 @@ TEST(ReduceUtilTest, ApplyOnZeroDimTensorOverDim) {
438438
TensorFactory<ScalarType::Long> tf;
439439

440440
Tensor in = tf.ones({});
441-
ET_EXPECT_DEATH(_apply_over_dim(in, 0), "");
441+
_apply_over_dim(in, 0);
442+
EXPECT_TENSOR_EQ(in, tf.make({}, {0}));
442443
}
443444

444445
TEST(ReduceUtilTest, ApplyOnZeroDimTensorOverDimListNull) {

0 commit comments

Comments
 (0)