Skip to content

Commit 889e463

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Fix more shape tests for 0-D tensor
Summary: Special handling ops for 0-D tensor Reviewed By: manuelcandales Differential Revision: D48323627 fbshipit-source-id: 06b3bdfd55f9a5e8569168875b902402b1f96c66
1 parent 3c2adcb commit 889e463

11 files changed

+78
-75
lines changed

kernels/portable/cpu/op_amax.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ void check_preconditions(
2828
Tensor& out) {
2929
ET_CHECK_SAME_DTYPE2(in, out);
3030
check_dim_list_is_valid(in, dim_list);
31-
for (const auto& d : dim_list) {
32-
ET_CHECK_NON_ZERO_DIM_SIZE(d, in);
31+
if (in.dim() != 0) {
32+
for (const auto& d : dim_list) {
33+
ET_CHECK_NON_ZERO_DIM_SIZE(d, in);
34+
}
3335
}
3436
ET_CHECK_MSG(
3537
out.dim() == compute_reduced_out_dim(in, dim_list, keepdim),

kernels/portable/cpu/op_amin.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ void check_preconditions(
2828
Tensor& out) {
2929
ET_CHECK_SAME_DTYPE2(in, out);
3030
check_dim_list_is_valid(in, dim_list);
31-
for (const auto& d : dim_list) {
32-
ET_CHECK_NON_ZERO_DIM_SIZE(d, in);
31+
if (in.dim() != 0) {
32+
for (const auto& d : dim_list) {
33+
ET_CHECK_NON_ZERO_DIM_SIZE(d, in);
34+
}
3335
}
3436
ET_CHECK_MSG(
3537
out.dim() == compute_reduced_out_dim(in, dim_list, keepdim),

kernels/portable/cpu/op_argmax.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ void check_preconditions(
2727
optional<int64_t> dim,
2828
bool keepdim,
2929
Tensor& out) {
30+
if (in.dim() == 0) {
31+
if (dim.has_value()) {
32+
ET_CHECK(dim.value() == 0 || dim.value() == -1);
33+
}
34+
return;
35+
}
3036
if (dim.has_value()) {
3137
ET_CHECK_VALID_DIM(dim.value(), in.dim());
3238
ET_CHECK_NON_ZERO_DIM_SIZE(dim.value(), in);

kernels/portable/cpu/op_argmin.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ void check_preconditions(
2727
optional<int64_t> dim,
2828
bool keepdim,
2929
Tensor& out) {
30+
if (in.dim() == 0) {
31+
if (dim.has_value()) {
32+
ET_CHECK(dim.value() == 0 || dim.value() == -1);
33+
}
34+
return;
35+
}
3036
if (dim.has_value()) {
3137
ET_CHECK_VALID_DIM(dim.value(), in.dim());
3238
ET_CHECK_NON_ZERO_DIM_SIZE(dim.value(), in);

kernels/portable/cpu/op_index_select.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,15 @@ void check_index_select_args(
3232
const Tensor& index,
3333
Tensor& output) {
3434
// Check dim. The dim planed to be selected on shall exist in input
35-
ET_CHECK_MSG(
36-
dim >= 0 && dim < input.dim(),
37-
"dim %" PRId64 " out of range [0,%zd)",
38-
dim,
39-
input.dim());
35+
if (input.dim() == 0) {
36+
ET_CHECK(dim == 0);
37+
} else {
38+
ET_CHECK_MSG(
39+
dim >= 0 && dim < input.dim(),
40+
"dim %" PRId64 " out of range [0,%zd)",
41+
dim,
42+
input.dim());
43+
}
4044

4145
// Input output should have the same dim
4246
ET_CHECK_MSG(
@@ -62,7 +66,10 @@ void check_index_select_args(
6266
// Index should be a 1-D LongTensor, check if any index is out of bound
6367
ET_CHECK_MSG(
6468
index.scalar_type() == ScalarType::Long, "index scalar_type not long");
65-
ET_CHECK_MSG(index.dim() == 1, "index.dim() %zd != 1", index.dim());
69+
ET_CHECK_MSG(
70+
index.dim() == 1 || index.dim() == 0,
71+
"index.dim() %zd != 1 or 0",
72+
index.dim());
6673

6774
const int64_t* src = index.mutable_data_ptr<int64_t>();
6875
for (auto i = 1; i < index.numel(); i++) {

kernels/portable/cpu/op_log_softmax.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ void check_preconditions(
4040
// Ensure in has value
4141
ET_CHECK_MSG(in.numel() > 0, "in.numel() %zd <= 0", in.numel());
4242
// Ensure dim is valid
43-
ET_CHECK_VALID_DIM(dim, in.dim());
43+
if (in.dim() == 0) {
44+
ET_CHECK_MSG(dim == 0 || dim == -1, "dim must be 0 or -1 for 0-D tensor");
45+
} else {
46+
ET_CHECK_VALID_DIM(dim, in.dim());
47+
}
4448
ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(in);
4549
ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(out);
4650
}

kernels/portable/cpu/op_select_copy.cpp

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,32 +27,38 @@ void check_and_update_select_copy_int_out_args(
2727
int64_t dim,
2828
int64_t index,
2929
Tensor output) {
30-
// Support python-style negative indexing. E.g., for the shape {2, 3, 4},
31-
// dim = -1 would refer to dim[2], dim = -2 would refer to dim[1], and so on.
32-
33-
// The dim planed to be selected on shall exist in input
34-
ET_CHECK_MSG(
35-
dim >= -input.dim() && dim < input.dim(),
36-
"dim %" PRId64 " out of range [-%zd,%zd)",
37-
dim,
38-
input.dim(),
39-
input.dim());
40-
41-
// The index shall be valid in the given dimenson
42-
ET_CHECK_MSG(
43-
index >= -input.size(dim) && index < input.size(dim),
44-
"index %" PRId64 " out of range [-%zd,%zd) at input.size( %" PRId64 ")",
45-
index,
46-
input.size(dim),
47-
input.size(dim),
48-
dim);
30+
if (input.dim() == 0) {
31+
ET_CHECK(dim == 0 || dim == -1);
32+
} else {
33+
// Support python-style negative indexing. E.g., for the shape {2, 3, 4},
34+
// dim = -1 would refer to dim[2], dim = -2 would refer to dim[1], and so
35+
// on.
36+
37+
// The dim planed to be selected on shall exist in input
38+
ET_CHECK_MSG(
39+
dim >= -input.dim() && dim < input.dim(),
40+
"dim %" PRId64 " out of range [-%zd,%zd)",
41+
dim,
42+
input.dim(),
43+
input.dim());
44+
45+
// Support python-style negative indexing
46+
if (dim < 0) {
47+
dim += input.dim();
48+
}
4949

50-
// Support python-style negative indexing
51-
if (dim < 0) {
52-
dim += input.dim();
53-
}
54-
if (index < 0) {
55-
index += input.size(dim);
50+
// The index shall be valid in the given dimenson
51+
ET_CHECK_MSG(
52+
index >= -input.size(dim) && index < input.size(dim),
53+
"index %" PRId64 " out of range [-%zd,%zd) at input.size( %" PRId64 ")",
54+
index,
55+
input.size(dim),
56+
input.size(dim),
57+
dim);
58+
59+
if (index < 0) {
60+
index += input.size(dim);
61+
}
5662
}
5763

5864
// Input dtype shall match the output dtype.
@@ -71,7 +77,7 @@ void check_and_update_select_copy_int_out_args(
7177
// - output.size(i) shall equal to input.size(i) if i < dim,
7278
// - output.size(i) shall equal to input.size(i+1) if i >= dim
7379

74-
for (size_t d = 0; d < input.dim() - 1; d++) {
80+
for (ssize_t d = 0; d < input.dim() - 1; d++) {
7581
if (d < dim) {
7682
ET_CHECK_MSG(
7783
input.size(d) == output.size(d),

kernels/portable/cpu/op_softmax.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ void check_preconditions(
3737
"in.dim() %zd!= out.dim() %zd",
3838
in.dim(),
3939
out.dim());
40-
// Ensure in has value
41-
ET_CHECK_MSG(in.numel() > 0, "in.numel() %zd <= 0", in.numel());
4240
// Ensure dim is valid
43-
ET_CHECK_VALID_DIM(dim, in.dim());
41+
if (in.dim() == 0) {
42+
ET_CHECK(dim == 0 || dim == -1);
43+
} else {
44+
ET_CHECK_VALID_DIM(dim, in.dim());
45+
}
4446
ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(in);
4547
ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(out);
4648
}

kernels/portable/cpu/op_squeeze_copy.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ Tensor& squeeze_copy_dim_out(
8282
ET_CHECK_SAME_DTYPE2(self, out);
8383

8484
// A valid dim must be in [-self.dim(), self.dim())
85+
if (self.dim() == 0 && dim == -1) {
86+
dim = 0;
87+
}
8588
ET_CHECK_MSG(
8689
(self.dim() == 0 && dim == 0) || (dim >= -self.dim() && dim < self.dim()),
8790
"dim %" PRId64 " out of range [-%zd,%zd)",

kernels/test/op_log_softmax_test.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -93,24 +93,6 @@ TEST(OpLogSoftmaxOutTest, AllDtypesSupported) {
9393
// for those types.
9494
}
9595

96-
TEST(OpLogSoftmaxOutTest, EmptyInputOrEmptyOutTensorDies) {
97-
if (SupportedFeatures::get()->is_aten) {
98-
GTEST_SKIP() << "ATen currently supports empty input or out";
99-
}
100-
101-
TensorFactory<ScalarType::Float> tff;
102-
103-
Tensor x = tff.make({2, 2, 0}, {});
104-
105-
// Make an empty out tensor and demonstrate that it's empty.
106-
Tensor out = tff.make({2, 2, 0}, {});
107-
108-
EXPECT_EQ(out.numel(), 0);
109-
110-
ET_EXPECT_KERNEL_FAILURE(
111-
op_log_softmax_out(x, /*dim=*/1, /*half_to_float*/ false, out));
112-
}
113-
11496
TEST(OpLogSoftmaxOutTest, MismatchedDimensionsDies) {
11597
if (SupportedFeatures::get()->is_aten) {
11698
GTEST_SKIP() << "ATen currently supports mismatched dimensions";

0 commit comments

Comments
 (0)