Skip to content

Commit 2fa44ea

Browse files
Enable fast path for negative indices (#15622)
The fast path was broken for negative indices (see #15285). Because of this, #15366 disabled the fast path when the index tensor had negative indices. In this PR we fix the bug and re-enable the fast path for negative indices. Fixes #15285. Differential Revision: D86351194
1 parent 2bb8055 commit 2fa44ea

File tree

2 files changed

+61
-20
lines changed

2 files changed

+61
-20
lines changed

kernels/portable/cpu/op_index.cpp

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,6 @@ bool check_fast_path_conditions(
4949
if (index.dim() != 1) {
5050
return false;
5151
}
52-
53-
// Fast path only supports non-negative indices.
54-
if (ix_type == ScalarType::Int) {
55-
const int32_t* const data = index.const_data_ptr<int32_t>();
56-
if (std::any_of(data, data + index.numel(), [](const auto x) {
57-
return x < 0;
58-
})) {
59-
return false;
60-
}
61-
} else { // ScalarType::Long
62-
const int64_t* const data = index.const_data_ptr<int64_t>();
63-
if (std::any_of(data, data + index.numel(), [](const auto x) {
64-
return x < 0;
65-
})) {
66-
return false;
67-
}
68-
}
6952
}
7053
}
7154

@@ -96,8 +79,10 @@ bool check_fast_path_args(
9679
Long, Int, index.scalar_type(), ctx, "index.Tensor", CTYPE, [&]() {
9780
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
9881
for (const auto i : c10::irange(index.numel())) {
99-
if (index_arr[i] < 0 ||
100-
index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
82+
CTYPE index_val = index_arr[i];
83+
CTYPE dim_size = static_cast<CTYPE>(in.size(dim));
84+
index_val = index_val < 0 ? index_val + dim_size : index_val;
85+
if (index_val < 0 || index_val >= dim_size) {
10186
ET_LOG(
10287
Error,
10388
"Index %" PRId64
@@ -189,11 +174,14 @@ Tensor& fast_path(
189174

190175
ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
191176
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
177+
CTYPE dim_size = static_cast<CTYPE>(in.size(dim));
192178
for (const auto i : c10::irange(leading_dims)) {
193179
const char* src = in_data + i * in_dim_length * length_per_step;
194180
char* dest = out_data + i * out_dim_length * length_per_step;
195181
for (const auto j : c10::irange(out_dim_length)) {
196-
const char* copy_src = src + index_arr[j] * length_per_step;
182+
auto index_val =
183+
index_arr[j] < 0 ? index_arr[j] + dim_size : index_arr[j];
184+
const char* copy_src = src + index_val * length_per_step;
197185
char* copy_dest = dest + j * length_per_step;
198186
memcpy(copy_dest, copy_src, length_per_step);
199187
}

kernels/test/op_index_test.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,3 +947,56 @@ TEST_F(OpIndexTensorOutTest, FastPathEmptyInput) {
947947

948948
EXPECT_TENSOR_EQ(out, expected);
949949
}
950+
951+
TEST_F(OpIndexTensorOutTest, FastPathNegativeIndex) {
952+
TensorFactory<ScalarType::Float> tf;
953+
TensorFactory<ScalarType::Long> tfl;
954+
955+
// clang-format off
956+
Tensor x = tf.make(
957+
{2, 3, 4},
958+
{
959+
// [0, :, :]
960+
1., 2., 3., 4., // [0, 0, :]
961+
5., 6., 7., 8., // [0, 1, :]
962+
9., 10., 11., 12., // [0, 2, :]
963+
964+
// [1, :, :]
965+
-1., -2., -3., -4., // [1, 0, :]
966+
-5., -6., -7., -8., // [1, 1, :]
967+
-9., -10., -11., -12., // [1, 2, :]
968+
});
969+
// clang-format on
970+
971+
// Use negative indices in the first dimension: -1, 0, -2
972+
std::array<optional<Tensor>, 3> indices = {
973+
optional<Tensor>(tfl.make({3}, {-1, 0, -2})),
974+
optional<Tensor>(),
975+
optional<Tensor>()};
976+
977+
Tensor out = tf.zeros({3, 3, 4});
978+
// clang-format off
979+
Tensor expected = tf.make(
980+
{3, 3, 4},
981+
{
982+
// [1, :, :]
983+
-1., -2., -3., -4., // [1, 0, :]
984+
-5., -6., -7., -8., // [1, 1, :]
985+
-9., -10., -11., -12., // [1, 2, :]
986+
987+
// [0, :, :]
988+
1., 2., 3., 4., // [0, 0, :]
989+
5., 6., 7., 8., // [0, 1, :]
990+
9., 10., 11., 12., // [0, 2, :]
991+
992+
// [0, :, :] again (since -2 wraps to 0)
993+
1., 2., 3., 4., // [0, 0, :]
994+
5., 6., 7., 8., // [0, 1, :]
995+
9., 10., 11., 12., // [0, 2, :]
996+
});
997+
// clang-format on
998+
999+
op_index_tensor_out(x, indices, out);
1000+
1001+
EXPECT_TENSOR_EQ(out, expected);
1002+
}

0 commit comments

Comments
 (0)