
Commit f7bb487

Authored by xiaolil1, tye1, and gujinghui
Rebase with supporting int32 index to align with latest PyTorch index operator definition (#3749) (#3808)
* Enable int32 index tensor for index operator

---------

Co-authored-by: Ye Ting <[email protected]>
Co-authored-by: Jinghui <[email protected]>
1 parent c19d733 commit f7bb487
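
For context, a minimal Python sketch of the behavior this change enables: indexing with an int32 index tensor on the xpu device, matching what a recent PyTorch already accepts on CPU. The shapes are illustrative, and the xpu lines assume an XPU build of this extension is available.

import torch
# import intel_extension_for_pytorch  # assumed: registers the xpu device

probs = torch.ones(8, 16)
idx32 = torch.tensor([0, 3, 5], dtype=torch.int32)

out_cpu = probs[idx32]                        # int32 indices accepted on CPU in recent PyTorch
# out_xpu = probs.to("xpu")[idx32.to("xpu")]  # with this commit, the xpu path accepts them too
# assert torch.equal(out_cpu, out_xpu.cpu())
print(out_cpu.shape)  # torch.Size([3, 16])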

File tree

2 files changed: +251 -78 lines


csrc/gpu/aten/operators/Indexing.h

Lines changed: 242 additions & 78 deletions
@@ -346,7 +346,7 @@ static inline void launch_index_kernel(IdxConfig& cfg) {
   auto& queue = dpcppGetCurrentQueue();
   auto cgf = DPCPP_Q_CGF(__cgh) {
     IndexKernel<IdxConfig, TrivialOffCal, known_problem_inner> idx_ker(cfg);
-    __cgh.parallel_for(
+    __cgh.parallel_for<decltype(idx_ker)>(
         sycl::nd_range<2>(cfg.global_size(), cfg.group_size()), idx_ker);
   };
   DPCPP_Q_SUBMIT(queue, cgf);
@@ -510,6 +510,137 @@ static inline void _index_select_kernel(
   }
 }
 
+template <typename func_t, typename index_buf_type>
+struct DpcppSmallIndexKernelImplFunctor {
+  void operator()(sycl::nd_item<1> item_id) const {
+    auto local_id = item_id.get_local_id(0);
+    auto group_id = item_id.get_group(0);
+
+    // construct a indices_size table on SLM
+    for (int64_t local_index = local_id; local_index < indices_size;
+         local_index += wgroup_size) {
+      int64_t offset = 0;
+      for (size_t i = 0; i < num_indices; i++) {
+        // handle int32 index tensor according to the indice_size_bytes.
+        // we didn't use template parametor to avoid too many kernels' creation
+        // with numbers of input datatypes.
+        if (indice_size_bytes == 4) {
+          int32_t index =
+              *(int32_t*)(index_ptrs[i] + local_index * indice_size_bytes);
+          SYCL_KERNEL_ASSERT(
+              index >= -sizes[i] && index < sizes[i] && "index out of bounds");
+          if (index < 0) {
+            index += sizes[i];
+          }
+          offset += index * strides[i];
+        } else {
+          int64_t index =
+              *(int64_t*)(index_ptrs[i] + local_index * indice_size_bytes);
+          SYCL_KERNEL_ASSERT(
+              index >= -sizes[i] && index < sizes[i] && "index out of bounds");
+          if (index < 0) {
+            index += sizes[i];
+          }
+          offset += index * strides[i];
+        }
+      }
+      local_offset[local_index] = offset;
+    }
+
+    // calculate the number of workloads on each group
+    auto group_linear_id = group_id * group_numel;
+    auto group_numel_range = group_numel;
+    if (group_num_tail && group_id >= group_num) {
+      group_linear_id =
+          group_num * group_numel + (group_id - group_num) * group_numel_tail;
+      group_numel_range = group_numel_tail;
+    }
+    auto out_ptr = out_data;
+    auto in_ptr = in_data;
+    item_id.barrier(sycl::access::fence_space::local_space);
+
+    // compute the in/out/indices offsets and perform memory copy
+    for (int64_t local_index = local_id; local_index < group_numel_range;
+         local_index += wgroup_size) {
+      auto linear_id = group_linear_id + local_index;
+      auto out_offset = linear_id * element_size_bytes;
+      auto src_linear_id = linear_id / indices_size;
+      int64_t in_offset = 0;
+      for (int i = num_non_indices - 1; i > 0; --i) {
+        in_offset += (src_linear_id % src_sizes[i]) * src_strides[i];
+        src_linear_id /= src_sizes[i];
+      }
+      in_offset += src_linear_id * src_strides0;
+
+      auto offset = local_offset[local_index % indices_size];
+      f(out_ptr + out_offset, in_ptr + in_offset, offset);
+    }
+  }
+  DpcppSmallIndexKernelImplFunctor(
+      const func_t f_,
+      int64_t indices_size_,
+      int64_t group_num_tail_,
+      int64_t group_num_,
+      int64_t group_numel_,
+      int64_t group_numel_tail_,
+      int64_t wgroup_size_,
+      size_t num_non_indices_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> src_sizes_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> src_strides_,
+      int64_t src_strides0_,
+      size_t num_indices_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> sizes_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> strides_,
+      int64_t element_size_bytes_,
+      int64_t indice_size_bytes_,
+      char* out_data_,
+      char* in_data_,
+      at::detail::Array<index_buf_type, MAX_TENSORINFO_DIMS> index_ptrs_,
+      dpcpp_local_acc_t<int64_t, 1> local_offset_)
+      : f(f_),
+        indices_size(indices_size_),
+        group_num_tail(group_num_tail_),
+        group_num(group_num_),
+        group_numel(group_numel_),
+        group_numel_tail(group_numel_tail_),
+        wgroup_size(wgroup_size_),
+        num_non_indices(num_non_indices_),
+        src_sizes(src_sizes_),
+        src_strides(src_strides_),
+        src_strides0(src_strides0_),
+        num_indices(num_indices_),
+        sizes(sizes_),
+        strides(strides_),
+        element_size_bytes(element_size_bytes_),
+        indice_size_bytes(indice_size_bytes_),
+        out_data(out_data_),
+        in_data(in_data_),
+        index_ptrs(index_ptrs_),
+        local_offset(local_offset_) {}
+
+ private:
+  const func_t f;
+  int64_t indices_size;
+  int64_t group_num_tail;
+  int64_t group_num;
+  int64_t group_numel;
+  int64_t group_numel_tail;
+  int64_t wgroup_size;
+  size_t num_non_indices;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> src_sizes;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> src_strides;
+  int64_t src_strides0;
+  size_t num_indices;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> sizes;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> strides;
+  int64_t element_size_bytes;
+  int64_t indice_size_bytes;
+  char* out_data;
+  char* in_data;
+  at::detail::Array<index_buf_type, MAX_TENSORINFO_DIMS> index_ptrs;
+  dpcpp_local_acc_t<int64_t, 1> local_offset;
+};
+
 // DPCPP suggest: it’s possible (and even desirable) to oversubscribe tasks to
 // device;
 constexpr int OVER_SUBSCRIBE_DSS_FACTOR = 16;
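
As an editorial aside (not part of the diff): the gather arithmetic this functor performs — a base offset from the non-indexed leading dims plus index * stride for the indexed dims, with negative indices wrapped by the dim size — can be mirrored in a few lines of plain CPU Python. The shapes and index values below are illustrative.

import torch

x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)        # contiguous input
idx = torch.tensor([3, 0, -1], dtype=torch.int32)   # indices into the last dim

ref = x[:, :, idx]                                  # advanced-indexing reference

# Manual gather mirroring the kernel's offset math:
#   offset = sum(non-indexed coords * their strides) + wrapped_index * indexed-dim stride
flat = x.reshape(-1)
s0, s1, s2 = x.stride()
out = torch.empty(2, 3, idx.numel(), dtype=x.dtype)
for b in range(2):
    for r in range(3):
        for k, j in enumerate(idx.tolist()):
            j = j + 4 if j < 0 else j               # same wrap as `index += sizes[i]`
            out[b, r, k] = flat[b * s0 + r * s1 + j * s2]

assert torch.equal(ref, out)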
@@ -572,64 +703,103 @@ void dpcpp_small_index_kernel_impl(
     }
 
     dpcpp_local_acc_t<int64_t, 1> local_offset(indices_size, __cgh);
-    auto kfn = DPCPP_Q_KFN(sycl::nd_item<1> item_id) {
-      auto local_id = item_id.get_local_id(0);
-      auto group_id = item_id.get_group(0);
-
-      // construct a indices_size table on SLM
-      for (int64_t local_index = local_id; local_index < indices_size;
-           local_index += wgroup_size) {
-        int64_t offset = 0;
-        for (size_t i = 0; i < num_indices; i++) {
-          int64_t index =
-              *(int64_t*)(index_ptrs[i] + local_index * indice_size_bytes);
-          SYCL_KERNEL_ASSERT(
-              index >= -sizes[i] && index < sizes[i] && "index out of bounds");
-          if (index < 0) {
-            index += sizes[i];
-          }
-          offset += index * strides[i];
-        }
-        local_offset[local_index] = offset;
-      }
-
-      // calculate the number of workloads on each group
-      auto group_linear_id = group_id * group_numel;
-      auto group_numel_range = group_numel;
-      if (group_num_tail && group_id >= group_num) {
-        group_linear_id =
-            group_num * group_numel + (group_id - group_num) * group_numel_tail;
-        group_numel_range = group_numel_tail;
-      }
-      auto out_ptr = out_data;
-      auto in_ptr = in_data;
-      item_id.barrier(sycl::access::fence_space::local_space);
-
-      // compute the in/out/indices offsets and perform memory copy
-      for (int64_t local_index = local_id; local_index < group_numel_range;
-           local_index += wgroup_size) {
-        auto linear_id = group_linear_id + local_index;
-        auto out_offset = linear_id * element_size_bytes;
-        auto src_linear_id = linear_id / indices_size;
-        int64_t in_offset = 0;
-        for (int i = num_non_indices - 1; i > 0; --i) {
-          in_offset += (src_linear_id % src_sizes[i]) * src_strides[i];
-          src_linear_id /= src_sizes[i];
-        }
-        in_offset += src_linear_id * src_strides0;
-
-        auto offset = local_offset[local_index % indices_size];
-        f(out_ptr + out_offset, in_ptr + in_offset, offset);
-      }
-    };
-    __cgh.parallel_for(
+    DpcppSmallIndexKernelImplFunctor<func_t, index_buf_type> kfn(
+        f,
+        indices_size,
+        group_num_tail,
+        group_num,
+        group_numel,
+        group_numel_tail,
+        wgroup_size,
+        num_non_indices,
+        src_sizes,
+        src_strides,
+        src_strides0,
+        num_indices,
+        sizes,
+        strides,
+        element_size_bytes,
+        indice_size_bytes,
+        out_data,
+        in_data,
+        index_ptrs,
+        local_offset);
+    __cgh.parallel_for<decltype(kfn)>(
         sycl::nd_range<1>(
             sycl::range<1>(global_size), sycl::range<1>(wgroup_size)),
         kfn);
   };
   DPCPP_Q_SUBMIT(dpcpp_queue, cgf);
 }
 
+template <
+    typename func_t,
+    typename index_buf_type,
+    typename OffsetCalculatorType>
+struct DpcppIndexKernelImplFunctor {
+  void operator()(sycl::item<1> item_id) const {
+    auto linear_idx = item_id.get_linear_id();
+    auto offsets = offset_calc.get(linear_idx);
+    auto out_ptr = out_data + offsets[0];
+    auto in_ptr = in_data + offsets[1];
+    int64_t offset = 0;
+    //#pragma unroll
+    for (size_t i = 0; i < num_indices; i++) {
+      // handle int32 index tensor according to the indice_size_bytes.
+      // we didn't use template parametor to avoid too many kernels' creation
+      // with numbers of input datatypes.
+      if (indice_size_bytes == 4) {
+        int32_t index = *(int32_t*)(index_ptrs[i] + offsets[2]);
+        SYCL_KERNEL_ASSERT(
+            index >= -sizes[i] && index < sizes[i] && "index out of bounds");
+        if (index < 0) {
+          index += sizes[i];
+        }
+        offset += index * strides[i];
+      } else {
+        int64_t index = *(int64_t*)(index_ptrs[i] + offsets[2]);
+        SYCL_KERNEL_ASSERT(
+            index >= -sizes[i] && index < sizes[i] && "index out of bounds");
+        if (index < 0) {
+          index += sizes[i];
+        }
+        offset += index * strides[i];
+      }
+    }
+    f(out_ptr, in_ptr, offset);
+  }
+  DpcppIndexKernelImplFunctor(
+      const func_t f_,
+      OffsetCalculatorType offset_calc_,
+      int64_t indice_size_bytes_,
+      char* out_data_,
+      char* in_data_,
+      size_t num_indices_,
+      at::detail::Array<index_buf_type, MAX_TENSORINFO_DIMS> index_ptrs_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> sizes_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> strides_)
+      : f(f_),
+        offset_calc(offset_calc_),
+        indice_size_bytes(indice_size_bytes_),
+        out_data(out_data_),
+        in_data(in_data_),
+        num_indices(num_indices_),
+        index_ptrs(index_ptrs_),
+        sizes(sizes_),
+        strides(strides_) {}
+
+ private:
+  const func_t f;
+  OffsetCalculatorType offset_calc;
+  int64_t indice_size_bytes;
+  char* out_data;
+  char* in_data;
+  size_t num_indices;
+  at::detail::Array<index_buf_type, MAX_TENSORINFO_DIMS> index_ptrs;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> sizes;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> strides;
+};
+
 template <typename func_t>
 void dpcpp_index_kernel_impl(
     TensorIterator& iter,
@@ -645,6 +815,8 @@ void dpcpp_index_kernel_impl(
     strides[i] = index_stride[i];
   }
 
+  int64_t indice_size_bytes = iter.tensor(2).element_size();
+
   auto& dpcpp_queue = dpcppGetCurrentQueue();
 
   auto cgf = DPCPP_Q_CGF(__cgh) {
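
The runtime branch keys off the index tensor's element size rather than a template parameter; a quick Python sketch of the distinction it relies on (editorial, not part of the diff):

import torch

print(torch.tensor([0], dtype=torch.int32).element_size())  # 4 -> int32 path
print(torch.tensor([0], dtype=torch.int64).element_size())  # 8 -> int64 path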
@@ -657,25 +829,17 @@ void dpcpp_index_kernel_impl(
     }
 
     auto offset_calc = make_offset_calculator<3>(iter);
-    auto kfn = DPCPP_Q_KFN(sycl::item<1> item_id) {
-      auto linear_idx = item_id.get_linear_id();
-      auto offsets = offset_calc.get(linear_idx);
-      auto out_ptr = out_data + offsets[0];
-      auto in_ptr = in_data + offsets[1];
-      int64_t offset = 0;
-      //#pragma unroll
-      for (size_t i = 0; i < num_indices; i++) {
-        int64_t index = *(int64_t*)(index_ptrs[i] + offsets[2]);
-        SYCL_KERNEL_ASSERT(
-            index >= -sizes[i] && index < sizes[i] && "index out of bounds");
-        if (index < 0) {
-          index += sizes[i];
-        }
-        offset += index * strides[i];
-      }
-      f(out_ptr, in_ptr, offset);
-    };
-    __cgh.parallel_for(sycl::range</*dim=*/1>(numel), kfn);
+    DpcppIndexKernelImplFunctor<func_t, index_buf_type, decltype(offset_calc)>
+        kfn(f,
+            offset_calc,
+            indice_size_bytes,
+            out_data,
+            in_data,
+            num_indices,
+            index_ptrs,
+            sizes,
+            strides);
+    __cgh.parallel_for<decltype(kfn)>(sycl::range</*dim=*/1>(numel), kfn);
   };
   DPCPP_Q_SUBMIT(dpcpp_queue, cgf);
 }
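
With both width branches in place, int32 and int64 index tensors are expected to select the same elements. A minimal CPU-only check, assuming a recent PyTorch that accepts int32 indices (editorial sketch, not part of the diff):

import torch

x = torch.randn(6, 5)
idx64 = torch.tensor([4, -1, 2])        # default int64 indices
idx32 = idx64.to(torch.int32)

assert torch.equal(x[idx64], x[idx32])  # same rows, same negative-index wrapping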
@@ -700,19 +864,19 @@ void dpcpp_index_kernel(
       num_indices == static_cast<size_t>(iter.ntensors()) - 2);
   TORCH_INTERNAL_ASSERT(num_indices <= MAX_TENSORINFO_DIMS);
 
-  // the dpcpp_small_index_kernel_impl is applied for last several successive
-  // dims indexing of an input tensor Taking 3-dims tensor input
+  // the dpcpp_small_index_kernel_impl is applied for last several
+  // successive dims indexing of an input tensor Taking 3-dims tensor input
   // (input.shape=[x,y,z]) for example: input[:,:,idx] or input[:,idx1,idx2]
   // when input tensor satisfies the following conditions, the
-  // small_index_kernel path will be selected: 1.there are common indices such
-  // as input[:,:,idx] and input[:,idx1,idx2] instead of
+  // small_index_kernel path will be selected: 1.there are common indices
+  // such as input[:,:,idx] and input[:,idx1,idx2] instead of
   // input[idx0,idx1,idx2], input[idx0,idx1,:], input[idx0,:,idx2],
   // input[idx0,:,:], input[:,idx1,:]
   // 2.the common indices numel should larger than 2 times of the
   // dpcppMaxComputeUnitSize (then we can get memory access benifit) 3.the
-  // workloads in each group should larger than the maximum number of workitem
-  // (ensure all the workitem activate) 4.the indices_table size should
-  // satisfied the SLM limit condition
+  // workloads in each group should larger than the maximum number of
+  // workitem (ensure all the workitem activate) 4.the indices_table size
+  // should satisfied the SLM limit condition
 
   // check whether the current case satisfying the condition 1
   // for 3-dims input:
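
A small Python illustration of condition 1 from the comment above, with hypothetical shapes: patterns whose leading dims are not indexed are candidates for the small-index path, while leading-dim indexing falls back to the general path. Whether the small path is actually taken still depends on conditions 2-4 (editorial sketch, not part of the diff).

import torch

x = torch.randn(4, 5, 6)
i = torch.tensor([0, 2], dtype=torch.int32)

a = x[:, :, i]   # common leading dims -> small-index-kernel candidate, shape (4, 5, 2)
b = x[:, i, i]   # common leading dim  -> candidate, shape (4, 2)
c = x[i]         # leading dim indexed -> general path, shape (2, 5, 6)
d = x[i, :, i]   # no common leading dims -> general path, shape (2, 5)
print(a.shape, b.shape, c.shape, d.shape)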

tests/gpu/regression/test_indexing.py

Lines changed: 9 additions & 0 deletions
@@ -67,3 +67,12 @@ def test_index_small_shape(self, dtype=torch.float):
         result_cpu = vertices_cpu[vert_filter_cpu_rand]
         result_xpu = vertices_xpu[vert_filter_xpu_rand]
         self.assertEqual(result_cpu, result_xpu.cpu())
+
+    def test_index_int32(self):
+        probs = torch.ones((256, 50272), dtype=torch.float32)
+        indice = torch.range(0, 255, dtype=torch.int32)
+
+        out_cpu = probs[indice]
+        out_xpu = probs.xpu()[indice.xpu()]
+
+        self.assertEqual(out_xpu.to("cpu"), out_cpu)
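
For reference, the same check written with torch.arange, since torch.range is deprecated in current PyTorch; this is an editorial sketch, not part of the committed test.

import torch

probs = torch.ones((256, 50272), dtype=torch.float32)
indice = torch.arange(0, 256, dtype=torch.int32)  # values 0..255, as torch.range(0, 255) yields

out_cpu = probs[indice]
# out_xpu = probs.xpu()[indice.xpu()]             # requires an XPU device with this extension
# assert torch.equal(out_cpu, out_xpu.cpu())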
