@@ -50,6 +50,40 @@ void indexSelect(
5050 const Tensor& src,
5151 int dim,
5252 const Tensor& indices) {
53+ IPEX_DISPATCH_INDEX_TYPES (indices.scalar_type (), " indexSelect" , [&] {
54+ TensorInfo<index_t , int64_t > indices_info =
55+ tensorInfoIfScalar (getTensorInfo<index_t , int64_t >(indices));
56+ indices_info.collapseDims ();
57+
58+ TensorInfo<scalar_t , int64_t > dst_info =
59+ tensorInfoIfScalar (getTensorInfo<scalar_t , int64_t >(dst));
60+ TensorInfo<scalar_t , int64_t > src_info =
61+ tensorInfoIfScalar (getTensorInfo<scalar_t , int64_t >(src.contiguous ()));
62+ int new_indexing_dim = src_info.collapseDims (dim);
63+
64+ if (dst.is_contiguous () && indices.is_contiguous ())
65+ _index_select_kernel<
66+ decltype (src_info),
67+ decltype (dst_info),
68+ decltype (indices_info),
69+ /* TrivialOffCal */ true >(
70+ src_info, dst_info, indices_info, new_indexing_dim);
71+ else
72+ _index_select_kernel<
73+ decltype (src_info),
74+ decltype (dst_info),
75+ decltype (indices_info),
76+ /* TrivialOffCal */ false >(
77+ src_info, dst_info, indices_info, new_indexing_dim);
78+ });
79+ return ;
80+ }
81+ template <typename scalar_t >
82+ void index_select_impl (
83+ const Tensor& dst,
84+ const Tensor& src,
85+ int dim,
86+ const Tensor& indices) {
5387 at::assert_no_internal_overlap (dst);
5488 at::assert_no_overlap (dst, src);
5589 at::assert_no_overlap (dst, indices);
@@ -86,45 +120,47 @@ void indexSelect(
86120 src.scalar_type () == dst.scalar_type (),
87121 " index_select(): Source and result must have the same scalar type" );
88122
89- IPEX_DISPATCH_INDEX_TYPES (indices.scalar_type (), " indexSelect" , [&] {
90- TensorInfo<index_t , int64_t > indices_info =
91- tensorInfoIfScalar (getTensorInfo<index_t , int64_t >(indices));
92- indices_info.collapseDims ();
123+ auto new_size = src.sizes ().vec ();
93124
94- auto new_size = src.sizes ().vec ();
125+ if (src.dim () > 0 ) {
126+ new_size[dim] = indices.numel ();
127+ }
95128
96- if (src.dim () > 0 ) {
97- new_size[dim] = indices.numel ();
98- }
129+ at::native::resize_output (dst, new_size);
99130
100- at::native::resize_output (dst, new_size);
131+ ptrdiff_t dst_num_elem = dst.numel ();
132+ if (dst_num_elem == 0 ) {
133+ return ;
134+ }
101135
102- ptrdiff_t dst_num_elem = dst.numel ();
103- if (dst_num_elem == 0 ) {
104- return ;
136+ if (!canUse32BitIndexMath (dst)) {
137+ auto MaxInt32 = std::numeric_limits<int32_t >::max ();
138+ int32_t iter_number = (dst_num_elem + MaxInt32 - 1 ) / MaxInt32;
139+ int64_t slice_offset = indices.numel () / iter_number;
140+
141+ int64_t start_id = 0 ;
142+ int64_t end_id = 0 ;
143+ for (int32_t i = 0 ; i < iter_number; i++) {
144+ start_id = 0 + slice_offset * i;
145+ end_id = start_id + slice_offset;
146+ if (end_id <= indices.numel ()) {
147+ indexSelect<scalar_t >(
148+ dst.slice (dim, start_id, end_id),
149+ src,
150+ dim,
151+ indices.slice (0 , start_id, end_id));
152+ } else {
153+ indexSelect<scalar_t >(
154+ dst.slice (dim, start_id), src, dim, indices.slice (0 , start_id));
155+ }
105156 }
106-
107- TensorInfo<scalar_t , int64_t > dst_info =
108- tensorInfoIfScalar (getTensorInfo<scalar_t , int64_t >(dst));
109- TensorInfo<scalar_t , int64_t > src_info =
110- tensorInfoIfScalar (getTensorInfo<scalar_t , int64_t >(src.contiguous ()));
111- int new_indexing_dim = src_info.collapseDims (dim);
112-
113- if (dst.is_contiguous () && indices.is_contiguous ())
114- _index_select_kernel<
115- decltype (src_info),
116- decltype (dst_info),
117- decltype (indices_info),
118- /* TrivialOffCal */ true >(
119- src_info, dst_info, indices_info, new_indexing_dim);
120- else
121- _index_select_kernel<
122- decltype (src_info),
123- decltype (dst_info),
124- decltype (indices_info),
125- /* TrivialOffCal */ false >(
126- src_info, dst_info, indices_info, new_indexing_dim);
127- });
157+ if (end_id < indices.numel ()) {
158+ indexSelect<scalar_t >(
159+ dst.slice (dim, end_id), src, dim, indices.slice (0 , end_id));
160+ }
161+ } else {
162+ indexSelect<scalar_t >(dst, src, dim, indices);
163+ }
128164 return ;
129165}
130166
@@ -1439,8 +1475,8 @@ Tensor& index_select_out(
14391475 at::ScalarType::BFloat16,
14401476 at::ScalarType::Bool,
14411477 self.scalar_type (),
1442- " indexSelect " ,
1443- [=]() { impl::indexSelect <scalar_t >(out, self, dim, index); });
1478+ " index_select_impl " ,
1479+ [=]() { impl::index_select_impl <scalar_t >(out, self, dim, index); });
14441480 return out;
14451481}
14461482
0 commit comments