@@ -346,7 +346,7 @@ static inline void launch_index_kernel(IdxConfig& cfg) {
   auto& queue = dpcppGetCurrentQueue();
   auto cgf = DPCPP_Q_CGF(__cgh) {
     IndexKernel<IdxConfig, TrivialOffCal, known_problem_inner> idx_ker(cfg);
-    __cgh.parallel_for(
+    __cgh.parallel_for<decltype(idx_ker)>(
         sycl::nd_range<2>(cfg.global_size(), cfg.group_size()), idx_ker);
   };
   DPCPP_Q_SUBMIT(queue, cgf);
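
The hunk above only adds an explicit kernel-name template argument to `parallel_for`, using the functor type itself as the name. Below is a minimal, self-contained sketch of the same pattern in plain SYCL (illustrative names such as `FillKernel`; no IPEX `DPCPP_Q_*` macros), not the IPEX code itself:

```cpp
#include <sycl/sycl.hpp>

struct FillKernel {
  int* data;
  int value;
  void operator()(sycl::nd_item<2> it) const {
    // Flatten the 2-D nd_range index into a linear offset.
    size_t idx =
        it.get_global_id(0) * it.get_global_range(1) + it.get_global_id(1);
    data[idx] = value;
  }
};

int main() {
  sycl::queue q;
  constexpr size_t rows = 4, cols = 8;
  // Assumes the device supports USM shared allocations.
  int* buf = sycl::malloc_shared<int>(rows * cols, q);
  q.submit([&](sycl::handler& cgh) {
    FillKernel ker{buf, 42};
    // Passing decltype(ker) as the kernel-name template argument names the
    // kernel after the functor type instead of an unnamed lambda type.
    cgh.parallel_for<decltype(ker)>(
        sycl::nd_range<2>(sycl::range<2>(rows, cols), sycl::range<2>(1, cols)),
        ker);
  });
  q.wait();
  sycl::free(buf, q);
}
```

Naming the kernel by the functor type keeps the kernel identifier stable across builds, which matters once the kernel body is moved out of a lambda, as the rest of this patch does.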
@@ -510,6 +510,137 @@ static inline void _index_select_kernel(
   }
 }
 
+template <typename func_t, typename index_buf_type>
+struct DpcppSmallIndexKernelImplFunctor {
+  void operator()(sycl::nd_item<1> item_id) const {
+    auto local_id = item_id.get_local_id(0);
+    auto group_id = item_id.get_group(0);
+
+    // construct an indices_size table on SLM
+    for (int64_t local_index = local_id; local_index < indices_size;
+         local_index += wgroup_size) {
+      int64_t offset = 0;
+      for (size_t i = 0; i < num_indices; i++) {
+        // handle int32 index tensors according to indice_size_bytes.
+        // we do not use a template parameter here, to avoid creating too many
+        // kernels across the input data types.
+        if (indice_size_bytes == 4) {
+          int32_t index =
+              *(int32_t*)(index_ptrs[i] + local_index * indice_size_bytes);
+          SYCL_KERNEL_ASSERT(
+              index >= -sizes[i] && index < sizes[i] && "index out of bounds");
+          if (index < 0) {
+            index += sizes[i];
+          }
+          offset += index * strides[i];
+        } else {
+          int64_t index =
+              *(int64_t*)(index_ptrs[i] + local_index * indice_size_bytes);
+          SYCL_KERNEL_ASSERT(
+              index >= -sizes[i] && index < sizes[i] && "index out of bounds");
+          if (index < 0) {
+            index += sizes[i];
+          }
+          offset += index * strides[i];
+        }
+      }
+      local_offset[local_index] = offset;
+    }
+
+    // calculate the number of workloads on each group
+    auto group_linear_id = group_id * group_numel;
+    auto group_numel_range = group_numel;
+    if (group_num_tail && group_id >= group_num) {
+      group_linear_id =
+          group_num * group_numel + (group_id - group_num) * group_numel_tail;
+      group_numel_range = group_numel_tail;
+    }
+    auto out_ptr = out_data;
+    auto in_ptr = in_data;
+    item_id.barrier(sycl::access::fence_space::local_space);
+
+    // compute the in/out/indices offsets and perform memory copy
+    for (int64_t local_index = local_id; local_index < group_numel_range;
+         local_index += wgroup_size) {
+      auto linear_id = group_linear_id + local_index;
+      auto out_offset = linear_id * element_size_bytes;
+      auto src_linear_id = linear_id / indices_size;
+      int64_t in_offset = 0;
+      for (int i = num_non_indices - 1; i > 0; --i) {
+        in_offset += (src_linear_id % src_sizes[i]) * src_strides[i];
+        src_linear_id /= src_sizes[i];
+      }
+      in_offset += src_linear_id * src_strides0;
+
+      auto offset = local_offset[local_index % indices_size];
+      f(out_ptr + out_offset, in_ptr + in_offset, offset);
+    }
+  }
+  DpcppSmallIndexKernelImplFunctor(
+      const func_t f_,
+      int64_t indices_size_,
+      int64_t group_num_tail_,
+      int64_t group_num_,
+      int64_t group_numel_,
+      int64_t group_numel_tail_,
+      int64_t wgroup_size_,
+      size_t num_non_indices_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> src_sizes_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> src_strides_,
+      int64_t src_strides0_,
+      size_t num_indices_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> sizes_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> strides_,
+      int64_t element_size_bytes_,
+      int64_t indice_size_bytes_,
+      char* out_data_,
+      char* in_data_,
+      at::detail::Array<index_buf_type, MAX_TENSORINFO_DIMS> index_ptrs_,
+      dpcpp_local_acc_t<int64_t, 1> local_offset_)
+      : f(f_),
+        indices_size(indices_size_),
+        group_num_tail(group_num_tail_),
+        group_num(group_num_),
+        group_numel(group_numel_),
+        group_numel_tail(group_numel_tail_),
+        wgroup_size(wgroup_size_),
+        num_non_indices(num_non_indices_),
+        src_sizes(src_sizes_),
+        src_strides(src_strides_),
+        src_strides0(src_strides0_),
+        num_indices(num_indices_),
+        sizes(sizes_),
+        strides(strides_),
+        element_size_bytes(element_size_bytes_),
+        indice_size_bytes(indice_size_bytes_),
+        out_data(out_data_),
+        in_data(in_data_),
+        index_ptrs(index_ptrs_),
+        local_offset(local_offset_) {}
+
+ private:
+  const func_t f;
+  int64_t indices_size;
+  int64_t group_num_tail;
+  int64_t group_num;
+  int64_t group_numel;
+  int64_t group_numel_tail;
+  int64_t wgroup_size;
+  size_t num_non_indices;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> src_sizes;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> src_strides;
+  int64_t src_strides0;
+  size_t num_indices;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> sizes;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> strides;
+  int64_t element_size_bytes;
+  int64_t indice_size_bytes;
+  char* out_data;
+  char* in_data;
+  at::detail::Array<index_buf_type, MAX_TENSORINFO_DIMS> index_ptrs;
+  dpcpp_local_acc_t<int64_t, 1> local_offset;
+};
+
 // DPCPP suggest: it's possible (and even desirable) to oversubscribe tasks to
 // device;
 constexpr int OVER_SUBSCRIBE_DSS_FACTOR = 16;
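
The kernel body that used to be written inline with `DPCPP_Q_KFN` now lives in the named functor above, which carries the SLM accessor (`dpcpp_local_acc_t`) as a member, fills a small table in shared local memory, and synchronizes with a work-group barrier before using it. A stripped-down sketch of that pattern in plain SYCL (made-up names `TableLookupKernel` / `launch`; not the IPEX kernel):

```cpp
#include <sycl/sycl.hpp>

struct TableLookupKernel {
  const int* table_src;  // global-memory source for the SLM table
  int* out;
  size_t table_size;
  sycl::local_accessor<int, 1> slm_table;

  void operator()(sycl::nd_item<1> it) const {
    size_t lid = it.get_local_id(0);
    size_t wg_size = it.get_local_range(0);
    // Stage the table into SLM cooperatively, striding by the group size.
    for (size_t i = lid; i < table_size; i += wg_size)
      slm_table[i] = table_src[i];
    it.barrier(sycl::access::fence_space::local_space);
    // After the barrier every work-item can read the staged table cheaply.
    size_t gid = it.get_global_id(0);
    out[gid] = slm_table[gid % table_size];
  }
};

// Caller guarantees: global is a multiple of local, out has >= global elements.
void launch(sycl::queue& q, const int* src, int* out,
            size_t table_size, size_t global, size_t local) {
  q.submit([&](sycl::handler& cgh) {
    sycl::local_accessor<int, 1> slm(sycl::range<1>(table_size), cgh);
    TableLookupKernel ker{src, out, table_size, slm};
    cgh.parallel_for<decltype(ker)>(
        sycl::nd_range<1>(sycl::range<1>(global), sycl::range<1>(local)), ker);
  });
}
```

The local accessor is created in the command group and passed into the functor by value, exactly as the patch does with `local_offset`; the functor itself stays device-copyable.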
@@ -572,64 +703,103 @@ void dpcpp_small_index_kernel_impl(
     }
 
     dpcpp_local_acc_t<int64_t, 1> local_offset(indices_size, __cgh);
-    auto kfn = DPCPP_Q_KFN(sycl::nd_item<1> item_id) {
-      auto local_id = item_id.get_local_id(0);
-      auto group_id = item_id.get_group(0);
-
-      // construct a indices_size table on SLM
-      for (int64_t local_index = local_id; local_index < indices_size;
-           local_index += wgroup_size) {
-        int64_t offset = 0;
-        for (size_t i = 0; i < num_indices; i++) {
-          int64_t index =
-              *(int64_t*)(index_ptrs[i] + local_index * indice_size_bytes);
-          SYCL_KERNEL_ASSERT(
-              index >= -sizes[i] && index < sizes[i] && "index out of bounds");
-          if (index < 0) {
-            index += sizes[i];
-          }
-          offset += index * strides[i];
-        }
-        local_offset[local_index] = offset;
-      }
-
-      // calculate the number of workloads on each group
-      auto group_linear_id = group_id * group_numel;
-      auto group_numel_range = group_numel;
-      if (group_num_tail && group_id >= group_num) {
-        group_linear_id =
-            group_num * group_numel + (group_id - group_num) * group_numel_tail;
-        group_numel_range = group_numel_tail;
-      }
-      auto out_ptr = out_data;
-      auto in_ptr = in_data;
-      item_id.barrier(sycl::access::fence_space::local_space);
-
-      // compute the in/out/indices offsets and perform memory copy
-      for (int64_t local_index = local_id; local_index < group_numel_range;
-           local_index += wgroup_size) {
-        auto linear_id = group_linear_id + local_index;
-        auto out_offset = linear_id * element_size_bytes;
-        auto src_linear_id = linear_id / indices_size;
-        int64_t in_offset = 0;
-        for (int i = num_non_indices - 1; i > 0; --i) {
-          in_offset += (src_linear_id % src_sizes[i]) * src_strides[i];
-          src_linear_id /= src_sizes[i];
-        }
-        in_offset += src_linear_id * src_strides0;
-
-        auto offset = local_offset[local_index % indices_size];
-        f(out_ptr + out_offset, in_ptr + in_offset, offset);
-      }
-    };
-    __cgh.parallel_for(
+    DpcppSmallIndexKernelImplFunctor<func_t, index_buf_type> kfn(
+        f,
+        indices_size,
+        group_num_tail,
+        group_num,
+        group_numel,
+        group_numel_tail,
+        wgroup_size,
+        num_non_indices,
+        src_sizes,
+        src_strides,
+        src_strides0,
+        num_indices,
+        sizes,
+        strides,
+        element_size_bytes,
+        indice_size_bytes,
+        out_data,
+        in_data,
+        index_ptrs,
+        local_offset);
+    __cgh.parallel_for<decltype(kfn)>(
         sycl::nd_range<1>(
             sycl::range<1>(global_size), sycl::range<1>(wgroup_size)),
         kfn);
   };
   DPCPP_Q_SUBMIT(dpcpp_queue, cgf);
 }
 
+template <
+    typename func_t,
+    typename index_buf_type,
+    typename OffsetCalculatorType>
+struct DpcppIndexKernelImplFunctor {
+  void operator()(sycl::item<1> item_id) const {
+    auto linear_idx = item_id.get_linear_id();
+    auto offsets = offset_calc.get(linear_idx);
+    auto out_ptr = out_data + offsets[0];
+    auto in_ptr = in_data + offsets[1];
+    int64_t offset = 0;
+    // #pragma unroll
+    for (size_t i = 0; i < num_indices; i++) {
+      // handle int32 index tensors according to indice_size_bytes.
+      // we do not use a template parameter here, to avoid creating too many
+      // kernels across the input data types.
+      if (indice_size_bytes == 4) {
+        int32_t index = *(int32_t*)(index_ptrs[i] + offsets[2]);
+        SYCL_KERNEL_ASSERT(
+            index >= -sizes[i] && index < sizes[i] && "index out of bounds");
+        if (index < 0) {
+          index += sizes[i];
+        }
+        offset += index * strides[i];
+      } else {
+        int64_t index = *(int64_t*)(index_ptrs[i] + offsets[2]);
+        SYCL_KERNEL_ASSERT(
+            index >= -sizes[i] && index < sizes[i] && "index out of bounds");
+        if (index < 0) {
+          index += sizes[i];
+        }
+        offset += index * strides[i];
+      }
+    }
+    f(out_ptr, in_ptr, offset);
+  }
+  DpcppIndexKernelImplFunctor(
+      const func_t f_,
+      OffsetCalculatorType offset_calc_,
+      int64_t indice_size_bytes_,
+      char* out_data_,
+      char* in_data_,
+      size_t num_indices_,
+      at::detail::Array<index_buf_type, MAX_TENSORINFO_DIMS> index_ptrs_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> sizes_,
+      at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> strides_)
+      : f(f_),
+        offset_calc(offset_calc_),
+        indice_size_bytes(indice_size_bytes_),
+        out_data(out_data_),
+        in_data(in_data_),
+        num_indices(num_indices_),
+        index_ptrs(index_ptrs_),
+        sizes(sizes_),
+        strides(strides_) {}
+
+ private:
+  const func_t f;
+  OffsetCalculatorType offset_calc;
+  int64_t indice_size_bytes;
+  char* out_data;
+  char* in_data;
+  size_t num_indices;
+  at::detail::Array<index_buf_type, MAX_TENSORINFO_DIMS> index_ptrs;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> sizes;
+  at::detail::Array<int64_t, MAX_TENSORINFO_DIMS> strides;
+};
+
 template <typename func_t>
 void dpcpp_index_kernel_impl(
     TensorIterator& iter,
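
Both functors branch on `indice_size_bytes` at run time, so a single compiled kernel handles int32 as well as int64 index tensors instead of adding yet another template parameter. A host-side sketch of that load-and-normalize step (hypothetical `load_index` helper; the device code uses `SYCL_KERNEL_ASSERT` instead of `assert`):

```cpp
#include <cassert>
#include <cstdint>

// Read one index of either width from a raw byte buffer, bounds-check it,
// and wrap negative (Python-style) indices into [0, dim_size).
inline int64_t load_index(const char* base, int64_t elem,
                          int64_t index_size_bytes, int64_t dim_size) {
  int64_t idx;
  if (index_size_bytes == 4) {
    idx = *reinterpret_cast<const int32_t*>(base + elem * index_size_bytes);
  } else {
    idx = *reinterpret_cast<const int64_t*>(base + elem * index_size_bytes);
  }
  assert(idx >= -dim_size && idx < dim_size && "index out of bounds");
  if (idx < 0) {
    idx += dim_size;
  }
  return idx;
}
```

The trade-off is one extra branch per index read in exchange for roughly half as many kernel instantiations across index dtypes.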
@@ -645,6 +815,8 @@ void dpcpp_index_kernel_impl(
     strides[i] = index_stride[i];
   }
 
+  int64_t indice_size_bytes = iter.tensor(2).element_size();
+
   auto& dpcpp_queue = dpcppGetCurrentQueue();
 
   auto cgf = DPCPP_Q_CGF(__cgh) {
@@ -657,25 +829,17 @@ void dpcpp_index_kernel_impl(
     }
 
     auto offset_calc = make_offset_calculator<3>(iter);
-    auto kfn = DPCPP_Q_KFN(sycl::item<1> item_id) {
-      auto linear_idx = item_id.get_linear_id();
-      auto offsets = offset_calc.get(linear_idx);
-      auto out_ptr = out_data + offsets[0];
-      auto in_ptr = in_data + offsets[1];
-      int64_t offset = 0;
-      // #pragma unroll
-      for (size_t i = 0; i < num_indices; i++) {
-        int64_t index = *(int64_t*)(index_ptrs[i] + offsets[2]);
-        SYCL_KERNEL_ASSERT(
-            index >= -sizes[i] && index < sizes[i] && "index out of bounds");
-        if (index < 0) {
-          index += sizes[i];
-        }
-        offset += index * strides[i];
-      }
-      f(out_ptr, in_ptr, offset);
-    };
-    __cgh.parallel_for(sycl::range</*dim=*/1>(numel), kfn);
+    DpcppIndexKernelImplFunctor<func_t, index_buf_type, decltype(offset_calc)>
+        kfn(f,
+            offset_calc,
+            indice_size_bytes,
+            out_data,
+            in_data,
+            num_indices,
+            index_ptrs,
+            sizes,
+            strides);
+    __cgh.parallel_for<decltype(kfn)>(sycl::range</*dim=*/1>(numel), kfn);
   };
   DPCPP_Q_SUBMIT(dpcpp_queue, cgf);
 }
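
The functor's third template parameter is `decltype(offset_calc)`, because the concrete type of the offset calculator is only known at the call site where `make_offset_calculator` is invoked. A self-contained sketch of that pattern with a toy calculator (`ToyOffsetCalc` and `CopyKernel` are illustrative, not the real ATen `OffsetCalculator` API, and it uses element rather than byte offsets):

```cpp
#include <sycl/sycl.hpp>
#include <array>
#include <cstddef>

// Stand-in for the host-built offset calculator: maps a linear index to
// per-operand offsets.
struct ToyOffsetCalc {
  size_t out_stride, in_stride;
  std::array<size_t, 2> get(size_t linear_idx) const {
    return {linear_idx * out_stride, linear_idx * in_stride};
  }
};

// The functor stores the calculator by value; its type is threaded through
// as a template parameter, mirroring decltype(offset_calc) above.
template <typename OffsetCalcT>
struct CopyKernel {
  OffsetCalcT calc;
  float* out;
  const float* in;
  void operator()(sycl::item<1> it) const {
    auto offs = calc.get(it.get_linear_id());
    out[offs[0]] = in[offs[1]];
  }
};

void launch_copy(sycl::queue& q, float* out, const float* in, size_t n) {
  auto calc = ToyOffsetCalc{1, 1};
  q.submit([&](sycl::handler& cgh) {
    CopyKernel<decltype(calc)> kfn{calc, out, in};
    cgh.parallel_for<decltype(kfn)>(sycl::range<1>(n), kfn);
  });
}
```

Note the plain `sycl::range` launch here and in `dpcpp_index_kernel_impl`: the functor receives a `sycl::item`, with no explicit work-group size, which is enough when no SLM or barrier is needed; the small-index path above needs an `nd_range` launch precisely because it uses both.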
@@ -700,19 +864,19 @@ void dpcpp_index_kernel(
       num_indices == static_cast<size_t>(iter.ntensors()) - 2);
   TORCH_INTERNAL_ASSERT(num_indices <= MAX_TENSORINFO_DIMS);
 
-  // the dpcpp_small_index_kernel_impl is applied for last several successive
-  // dims indexing of an input tensor Taking 3-dims tensor input
+  // dpcpp_small_index_kernel_impl is applied when the last several
+  // successive dims of an input tensor are indexed. Taking a 3-dim input
   // (input.shape=[x,y,z]) for example: input[:,:,idx] or input[:,idx1,idx2]
   // when the input tensor satisfies the following conditions, the
-  // small_index_kernel path will be selected: 1.there are common indices such
-  // as input[:,:,idx] and input[:,idx1,idx2] instead of
+  // small_index_kernel path will be selected: 1. there are common
+  // (non-indexed) leading dims, as in input[:,:,idx] and input[:,idx1,idx2], instead of
   // input[idx0,idx1,idx2], input[idx0,idx1,:], input[idx0,:,idx2],
   // input[idx0,:,:], input[:,idx1,:]
   // 2. the common indices numel should be larger than 2 times the
   // dpcppMaxComputeUnitSize (so we get a memory access benefit) 3. the
-  // workloads in each group should larger than the maximum number of workitem
-  // (ensure all the workitem activate) 4.the indices_table size should
-  // satisfied the SLM limit condition
+  // workload in each group should be larger than the maximum number of
+  // work-items (to ensure all work-items stay active) 4. the indices_table
+  // size should satisfy the SLM limit condition
 
   // check whether the current case satisfies condition 1
   // for 3-dims input:
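
As a rough illustration of condition 1 only (not the actual selection code, which inspects the TensorIterator), the shape of the check is: the indexed dims must be the trailing, consecutive dims, leaving at least one leading "common" dim un-indexed. A hedged host-side sketch, with an illustrative `trailing_dims_indexed` helper:

```cpp
#include <cstddef>
#include <vector>

// indexed[d] == true means dim d of the input carries an index tensor.
bool trailing_dims_indexed(const std::vector<bool>& indexed) {
  size_t ndim = indexed.size();
  size_t first_indexed = ndim;
  for (size_t d = 0; d < ndim; ++d) {
    if (indexed[d]) { first_indexed = d; break; }
  }
  // Dim 0 indexed (no common dims) or nothing indexed at all: disqualified.
  if (first_indexed == 0 || first_indexed == ndim)
    return false;
  // Any gap after the first indexed dim (e.g. input[:,idx1,:]) disqualifies.
  for (size_t d = first_indexed; d < ndim; ++d)
    if (!indexed[d]) return false;
  return true;
}

// For a 3-dim input:
//   input[:, :, idx]     -> {false, false, true}  -> true
//   input[:, idx1, idx2] -> {false, true,  true}  -> true
//   input[idx0, idx1, :] -> {true,  true,  false} -> false
```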