@@ -27,32 +27,38 @@ void check_and_update_select_copy_int_out_args(
27
27
int64_t dim,
28
28
int64_t index,
29
29
Tensor output) {
30
- // Support python-style negative indexing. E.g., for the shape {2, 3, 4},
31
- // dim = -1 would refer to dim[2], dim = -2 would refer to dim[1], and so on.
32
-
33
- // The dim planed to be selected on shall exist in input
34
- ET_CHECK_MSG (
35
- dim >= -input. dim () && dim < input. dim (),
36
- " dim % " PRId64 " out of range [-%zd,%zd) " ,
37
- dim,
38
- input. dim (),
39
- input.dim ());
40
-
41
- // The index shall be valid in the given dimenson
42
- ET_CHECK_MSG (
43
- index >= - input.size ( dim) && index < input. size (dim),
44
- " index % " PRId64 " out of range [-%zd,%zd) at input.size( % " PRId64 " ) " ,
45
- index,
46
- input. size (dim),
47
- input.size ( dim),
48
- dim);
30
+ if (input. dim () == 0 ) {
31
+ ET_CHECK ( dim == 0 || dim == - 1 );
32
+ } else {
33
+ // Support python-style negative indexing. E.g., for the shape {2, 3, 4},
34
+ // dim = -1 would refer to dim[2], dim = -2 would refer to dim[1], and so
35
+ // on.
36
+
37
+ // The dim planed to be selected on shall exist in input
38
+ ET_CHECK_MSG (
39
+ dim >= - input.dim () && dim < input. dim (),
40
+ " dim % " PRId64 " out of range [-%zd,%zd) " ,
41
+ dim,
42
+ input. dim (),
43
+ input.dim ());
44
+
45
+ // Support python-style negative indexing
46
+ if (dim < 0 ) {
47
+ dim += input.dim ();
48
+ }
49
49
50
- // Support python-style negative indexing
51
- if (dim < 0 ) {
52
- dim += input.dim ();
53
- }
54
- if (index < 0 ) {
55
- index += input.size (dim);
50
+ // The index shall be valid in the given dimenson
51
+ ET_CHECK_MSG (
52
+ index >= -input.size (dim) && index < input.size (dim),
53
+ " index %" PRId64 " out of range [-%zd,%zd) at input.size( %" PRId64 " )" ,
54
+ index,
55
+ input.size (dim),
56
+ input.size (dim),
57
+ dim);
58
+
59
+ if (index < 0 ) {
60
+ index += input.size (dim);
61
+ }
56
62
}
57
63
58
64
// Input dtype shall match the output dtype.
@@ -71,7 +77,7 @@ void check_and_update_select_copy_int_out_args(
71
77
// - output.size(i) shall equal to input.size(i) if i < dim,
72
78
// - output.size(i) shall equal to input.size(i+1) if i >= dim
73
79
74
- for (size_t d = 0 ; d < input.dim () - 1 ; d++) {
80
+ for (ssize_t d = 0 ; d < input.dim () - 1 ; d++) {
75
81
if (d < dim) {
76
82
ET_CHECK_MSG (
77
83
input.size (d) == output.size (d),
0 commit comments