
Commit 4647b75

xiaolil1 and gujinghui authored
enhance index_select for large index (#4108) (#4170)
Co-authored-by: Jinghui <[email protected]>
1 parent 8294aa3 commit 4647b75

2 files changed (+138, −36 lines)

csrc/gpu/aten/operators/Indexing.cpp

Lines changed: 72 additions & 36 deletions
@@ -50,6 +50,40 @@ void indexSelect(
     const Tensor& src,
     int dim,
     const Tensor& indices) {
+  IPEX_DISPATCH_INDEX_TYPES(indices.scalar_type(), "indexSelect", [&] {
+    TensorInfo<index_t, int64_t> indices_info =
+        tensorInfoIfScalar(getTensorInfo<index_t, int64_t>(indices));
+    indices_info.collapseDims();
+
+    TensorInfo<scalar_t, int64_t> dst_info =
+        tensorInfoIfScalar(getTensorInfo<scalar_t, int64_t>(dst));
+    TensorInfo<scalar_t, int64_t> src_info =
+        tensorInfoIfScalar(getTensorInfo<scalar_t, int64_t>(src.contiguous()));
+    int new_indexing_dim = src_info.collapseDims(dim);
+
+    if (dst.is_contiguous() && indices.is_contiguous())
+      _index_select_kernel<
+          decltype(src_info),
+          decltype(dst_info),
+          decltype(indices_info),
+          /* TrivialOffCal */ true>(
+          src_info, dst_info, indices_info, new_indexing_dim);
+    else
+      _index_select_kernel<
+          decltype(src_info),
+          decltype(dst_info),
+          decltype(indices_info),
+          /* TrivialOffCal */ false>(
+          src_info, dst_info, indices_info, new_indexing_dim);
+  });
+  return;
+}
+template <typename scalar_t>
+void index_select_impl(
+    const Tensor& dst,
+    const Tensor& src,
+    int dim,
+    const Tensor& indices) {
   at::assert_no_internal_overlap(dst);
   at::assert_no_overlap(dst, src);
   at::assert_no_overlap(dst, indices);
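
The first hunk above factors the kernel launch out of the old monolithic function: everything from the IPEX_DISPATCH_INDEX_TYPES dispatch through the _index_select_kernel call now forms the whole body of indexSelect, while the argument checks and output resizing that used to precede it move into the new index_select_impl wrapper, whose chunking logic appears in the next hunk.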
@@ -86,45 +120,47 @@ void indexSelect(
       src.scalar_type() == dst.scalar_type(),
       "index_select(): Source and result must have the same scalar type");
 
-  IPEX_DISPATCH_INDEX_TYPES(indices.scalar_type(), "indexSelect", [&] {
-    TensorInfo<index_t, int64_t> indices_info =
-        tensorInfoIfScalar(getTensorInfo<index_t, int64_t>(indices));
-    indices_info.collapseDims();
+  auto new_size = src.sizes().vec();
 
-    auto new_size = src.sizes().vec();
+  if (src.dim() > 0) {
+    new_size[dim] = indices.numel();
+  }
 
-    if (src.dim() > 0) {
-      new_size[dim] = indices.numel();
-    }
+  at::native::resize_output(dst, new_size);
 
-    at::native::resize_output(dst, new_size);
+  ptrdiff_t dst_num_elem = dst.numel();
+  if (dst_num_elem == 0) {
+    return;
+  }
 
-    ptrdiff_t dst_num_elem = dst.numel();
-    if (dst_num_elem == 0) {
-      return;
+  if (!canUse32BitIndexMath(dst)) {
+    auto MaxInt32 = std::numeric_limits<int32_t>::max();
+    int32_t iter_number = (dst_num_elem + MaxInt32 - 1) / MaxInt32;
+    int64_t slice_offset = indices.numel() / iter_number;
+
+    int64_t start_id = 0;
+    int64_t end_id = 0;
+    for (int32_t i = 0; i < iter_number; i++) {
+      start_id = 0 + slice_offset * i;
+      end_id = start_id + slice_offset;
+      if (end_id <= indices.numel()) {
+        indexSelect<scalar_t>(
+            dst.slice(dim, start_id, end_id),
+            src,
+            dim,
+            indices.slice(0, start_id, end_id));
+      } else {
+        indexSelect<scalar_t>(
+            dst.slice(dim, start_id), src, dim, indices.slice(0, start_id));
+      }
     }
-
-    TensorInfo<scalar_t, int64_t> dst_info =
-        tensorInfoIfScalar(getTensorInfo<scalar_t, int64_t>(dst));
-    TensorInfo<scalar_t, int64_t> src_info =
-        tensorInfoIfScalar(getTensorInfo<scalar_t, int64_t>(src.contiguous()));
-    int new_indexing_dim = src_info.collapseDims(dim);
-
-    if (dst.is_contiguous() && indices.is_contiguous())
-      _index_select_kernel<
-          decltype(src_info),
-          decltype(dst_info),
-          decltype(indices_info),
-          /* TrivialOffCal */ true>(
-          src_info, dst_info, indices_info, new_indexing_dim);
-    else
-      _index_select_kernel<
-          decltype(src_info),
-          decltype(dst_info),
-          decltype(indices_info),
-          /* TrivialOffCal */ false>(
-          src_info, dst_info, indices_info, new_indexing_dim);
-  });
+    if (end_id < indices.numel()) {
+      indexSelect<scalar_t>(
+          dst.slice(dim, end_id), src, dim, indices.slice(0, end_id));
+    }
+  } else {
+    indexSelect<scalar_t>(dst, src, dim, indices);
+  }
   return;
 }
 
@@ -1439,8 +1475,8 @@ Tensor& index_select_out(
       at::ScalarType::BFloat16,
       at::ScalarType::Bool,
       self.scalar_type(),
-      "indexSelect",
-      [=]() { impl::indexSelect<scalar_t>(out, self, dim, index); });
+      "index_select_impl",
+      [=]() { impl::index_select_impl<scalar_t>(out, self, dim, index); });
   return out;
 }
 
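The net effect of the hunks above: when canUse32BitIndexMath(dst) fails, index_select_impl splits the index tensor into iter_number roughly equal slices so that each partial output fits 32-bit index math, issues one indexSelect call per slice, and makes a final tail call for any remainder the integer division leaves behind. A minimal Python sketch of the same arithmetic in plain PyTorch, for illustration only — chunked_index_select and max_elems are hypothetical names, not part of this commit:

    import math
    import torch

    INT32_MAX = 2**31 - 1  # same bound as std::numeric_limits<int32_t>::max()

    def chunked_index_select(src, dim, index, max_elems=INT32_MAX):
        # Mirror of the C++ above: split `index` so that each partial
        # output stays within `max_elems` elements.
        out_numel = (src.numel() // src.size(dim)) * index.numel()
        iter_number = max(1, math.ceil(out_numel / max_elems))
        slice_offset = index.numel() // iter_number
        parts = []
        for i in range(iter_number):
            start_id = slice_offset * i
            end_id = start_id + slice_offset
            if end_id <= index.numel():
                parts.append(src.index_select(dim, index[start_id:end_id]))
            else:  # kept only to mirror the C++ control flow
                parts.append(src.index_select(dim, index[start_id:]))
        if iter_number * slice_offset < index.numel():
            # tail left over by the integer division above
            parts.append(src.index_select(dim, index[iter_number * slice_offset:]))
        return torch.cat(parts, dim=dim)

    # Equivalence check; a tiny max_elems forces several chunks.
    src = torch.rand(1000, 16)
    idx = torch.randint(0, 1000, (4567,))
    assert torch.equal(
        chunked_index_select(src, 0, idx, max_elems=10_000),
        src.index_select(0, idx),
    )

Because each chunk is an ordinary index_select over a slice of the index, the concatenated result is identical to the unchunked call, which is what the new tests below assert.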

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+import torch
+from torch.testing._internal.common_utils import TestCase, IS_WINDOWS
+import intel_extension_for_pytorch  # noqa F401
+import pytest
+
+
+class TestTorchMethod(TestCase):
+    @pytest.mark.skipif(
+        IS_WINDOWS, reason="Memory allocated by this case exceeds what Windows provides."
+    )
+    @pytest.mark.skipif(
+        not torch.xpu.has_2d_block_array(),
+        reason="Memory allocated by this case exceeds what ATSM provides.",
+    )
+    def test_index_select_large_1(self, dtype=torch.float):
+        torch.xpu.synchronize()
+        torch.xpu.empty_cache()
+        src = torch.rand((256000, 128))
+
+        index = torch.randint(0, src.size(0), (32000000,))
+        dst_cpu = src.index_select(0, index)
+        # print("dst_cpu = ", dst_cpu)
+
+        src_xpu = src.to("xpu")
+        index_xpu = index.to("xpu")
+        dst_xpu = src_xpu.index_select(0, index_xpu)
+
+        # print("dst_xpu = ", dst_xpu)
+        # print("diff = ", torch.max(abs(dst_xpu.cpu()-dst_cpu)))
+
+        self.assertEqual(dst_cpu, dst_xpu.cpu())
+        del src_xpu
+        del index_xpu
+        del dst_xpu
+        torch.xpu.synchronize()
+        torch.xpu.empty_cache()
+
+    @pytest.mark.skipif(
+        IS_WINDOWS, reason="Memory allocated by this case exceeds what Windows provides."
+    )
+    @pytest.mark.skipif(
+        not torch.xpu.has_2d_block_array(),
+        reason="Memory allocated by this case exceeds what ATSM provides.",
+    )
+    def test_index_select_large_2(self, dtype=torch.float):
+        torch.xpu.synchronize()
+        torch.xpu.empty_cache()
+        src = torch.rand((512000, 128))
+
+        index = torch.randint(0, src.size(0), (20002185,))
+        dst_cpu = src.index_select(0, index)
+        # print("dst_cpu = ", dst_cpu)
+
+        src_xpu = src.to("xpu")
+        index_xpu = index.to("xpu")
+        dst_xpu = src_xpu.index_select(0, index_xpu)
+
+        # print("dst_xpu = ", dst_xpu)
+        # print("diff = ", torch.max(abs(dst_xpu.cpu()-dst_cpu)))
+
+        self.assertEqual(dst_cpu, dst_xpu.cpu())
+        del src_xpu
+        del index_xpu
+        del dst_xpu
+        torch.xpu.synchronize()
+        torch.xpu.empty_cache()
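
The two index sizes appear chosen to cover both branches of the chunked path; a quick check of the arithmetic (the 128-column width and the int32 bound come straight from the test and the C++ above):

    INT32_MAX = 2**31 - 1

    # Test 1: dst is (32000000, 128) float32.
    out1 = 32_000_000 * 128          # 4_096_000_000 elements, ~15.3 GiB
    assert out1 > INT32_MAX          # forces the chunked 64-bit path
    # iter_number = ceil(out1 / INT32_MAX) = 2 and 32_000_000 is even,
    # so the index splits into two equal slices with no tail.

    # Test 2: dst is (20002185, 128) float32.
    out2 = 20_002_185 * 128          # 2_560_279_680 elements
    assert out2 > INT32_MAX
    # iter_number = 2, slice_offset = 20_002_185 // 2 = 10_001_092, and
    # 2 * 10_001_092 = 20_002_184 < 20_002_185, so the odd index count
    # leaves a one-element tail that exercises the final slice branch.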
