@@ -326,7 +326,7 @@ sycl::event inclusive_scan_iter_1d(sycl::queue &exec_q,
326326 std::vector<sycl::event> &host_tasks,
327327 const std::vector<sycl::event> &depends = {})
328328{
329- ScanOpT scan_op{};
329+ constexpr ScanOpT scan_op{};
330330 constexpr outputT identity = su_ns::Identity<ScanOpT, outputT>::value;
331331
332332 constexpr size_t _iter_nelems = 1 ;
@@ -436,8 +436,12 @@ sycl::event inclusive_scan_iter_1d(sycl::queue &exec_q,
436436 sycl::nd_range<1 > ndRange{gRange , lRange};
437437
438438 cgh.parallel_for <UpdateKernelName>(
439- ndRange, [chunk_size, src, src_size, local_scans, scan_op,
440- identity](sycl::nd_item<1 > ndit) {
439+ ndRange, [chunk_size, src, src_size,
440+ local_scans](sycl::nd_item<1 > ndit) {
441+ constexpr ScanOpT scan_op{};
442+ constexpr outputT identity =
443+ su_ns::Identity<ScanOpT, outputT>::value;
444+
441445 const std::uint32_t lws = ndit.get_local_range (0 );
442446 const size_t block_offset =
443447 ndit.get_group (0 ) * n_wi * lws;
@@ -447,11 +451,10 @@ sycl::event inclusive_scan_iter_1d(sycl::queue &exec_q,
447451 block_offset + ndit.get_local_id (0 ) + i * lws;
448452 if (src_id < src_size) {
449453 const size_t scan_id = (src_id / chunk_size);
450- src[src_id] =
451- (scan_id > 0 )
452- ? scan_op (src[src_id],
453- local_scans[scan_id - 1 ])
454- : scan_op (src[src_id], identity);
454+ const outputT modifier =
455+ (scan_id > 0 ) ? local_scans[scan_id - 1 ]
456+ : identity;
457+ src[src_id] = scan_op (src[src_id], modifier);
455458 }
456459 }
457460 });
@@ -561,7 +564,7 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
561564 std::vector<sycl::event> &host_tasks,
562565 const std::vector<sycl::event> &depends = {})
563566{
564- ScanOpT scan_op = ScanOpT () ;
567+ constexpr ScanOpT scan_op{} ;
565568 constexpr outputT identity = su_ns::Identity<ScanOpT, outputT>::value;
566569
567570 using IterIndexerT =
@@ -708,43 +711,44 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
708711 cgh.depends_on (dependent_event);
709712 cgh.use_kernel_bundle (kb);
710713
711- sycl::range<1 > gRange {iter_nelems * update_nelems};
712- sycl::range<1 > lRange{sg_size};
714+ sycl::range<2 > gRange {iter_nelems, update_nelems};
715+ sycl::range<2 > lRange{1 , sg_size};
713716
714- sycl::nd_range<1 > ndRange{gRange , lRange};
717+ sycl::nd_range<2 > ndRange{gRange , lRange};
715718
716719 cgh.parallel_for <UpdateKernelName>(
717- ndRange,
718- [chunk_size, update_nelems, src_size, local_stride, src,
719- local_scans, scan_op, identity](sycl::nd_item<1 > ndit) {
720- const size_t gr_id = ndit.get_group (0 );
720+ ndRange, [chunk_size, src_size, local_stride, src,
721+ local_scans](sycl::nd_item<2 > ndit) {
722+ constexpr ScanOpT scan_op{};
723+ constexpr outputT identity =
724+ su_ns::Identity<ScanOpT, outputT>::value;
721725
722- const size_t iter_gid = gr_id / update_nelems;
723- const size_t axis_gr_id =
724- gr_id - (iter_gid * update_nelems);
726+ const size_t iter_gid = ndit.get_group (0 );
727+ const size_t axis_gr_id = ndit.get_group (1 );
725728
726729 const std::uint32_t lws = ndit.get_local_range (1 ); // local range is {1, sg_size}: dimension 0 always has extent 1, the work-group width is in dimension 1
727730
728731 const size_t src_axis_id0 =
729732 axis_gr_id * updates_per_wi * lws;
730733 const size_t src_iter_id = iter_gid * src_size;
734+ const size_t scan_id0 = iter_gid * local_stride;
731735#pragma unroll
732736 for (nwiT i = 0 ; i < updates_per_wi; ++i) {
733737 const size_t src_axis_id =
734738 src_axis_id0 + ndit.get_local_id (1 ) + i * lws; // dimension-0 local id is always 0; the work-item's position within the group is in dimension 1
735- const size_t src_id = src_axis_id + src_iter_id;
736739
737740 if (src_axis_id < src_size) {
738741 const size_t scan_axis_id =
739742 src_axis_id / chunk_size;
740- const size_t scan_id =
741- scan_axis_id + iter_gid * local_stride;
743+ const size_t scan_id = scan_axis_id + scan_id0;
742744
743- src[src_id] =
745+ const outputT modifier =
744746 (scan_axis_id > 0 )
745- ? scan_op (src[src_id],
746- local_scans[scan_id - 1 ])
747- : scan_op (src[src_id], identity);
747+ ? local_scans[scan_id - 1 ]
748+ : identity;
749+
750+ const size_t src_id = src_axis_id + src_iter_id;
751+ src[src_id] = scan_op (src[src_id], modifier);
748752 }
749753 }
750754 });
@@ -759,35 +763,55 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
759763 outputT *local_scans = stack_elem.get_local_scans_ptr ();
760764 const size_t local_stride = stack_elem.get_local_stride ();
761765
766+ using UpdateKernelName =
767+ class inclusive_scan_final_chunk_update_krn <
768+ inputT, outputT, n_wi, OutIterIndexerT, OutIndexerT,
769+ TransformerT, NoOpTransformerT, ScanOpT, include_initial>;
770+
771+ const auto &kernel_id = sycl::get_kernel_id<UpdateKernelName>();
772+
773+ auto const &ctx = exec_q.get_context ();
774+ auto const &dev = exec_q.get_device ();
775+ auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
776+ ctx, {dev}, {kernel_id});
777+
778+ auto krn = kb.get_kernel (kernel_id);
779+
780+ const std::uint32_t sg_size = krn.template get_info <
781+ sycl::info::kernel_device_specific::max_sub_group_size>(dev);
782+
762783 constexpr nwiT updates_per_wi = n_wi;
763784 const size_t update_nelems =
764- ceiling_quotient<size_t >(src_size, updates_per_wi);
785+ ceiling_quotient<size_t >(src_size, sg_size * updates_per_wi) *
786+ sg_size;
787+
788+ sycl::range<2 > gRange {iter_nelems, update_nelems};
789+ sycl::range<2 > lRange{1 , sg_size};
790+
791+ sycl::nd_range<2 > ndRange{gRange , lRange};
765792
766793 dependent_event = exec_q.submit ([&](sycl::handler &cgh) {
767794 cgh.depends_on (dependent_event);
768795
769- using UpdateKernelName =
770- class inclusive_scan_final_chunk_update_krn <
771- inputT, outputT, n_wi, OutIterIndexerT, OutIndexerT,
772- TransformerT, NoOpTransformerT, ScanOpT,
773- include_initial>;
774-
775796 cgh.parallel_for <UpdateKernelName>(
776- {iter_nelems * update_nelems},
777- [chunk_size, update_nelems, src_size, local_stride, src,
778- local_scans, scan_op, identity, out_iter_indexer,
779- out_indexer](auto wiid) {
780- const size_t gid = wiid[0 ];
797+ ndRange,
798+ [chunk_size, src_size, local_stride, src, local_scans,
799+ out_iter_indexer, out_indexer](sycl::nd_item<2 > ndit) {
800+ constexpr ScanOpT scan_op{};
801+ constexpr outputT identity =
802+ su_ns::Identity<ScanOpT, outputT>::value;
781803
782- const size_t iter_gid = gid / update_nelems;
783- const size_t axis_gid =
784- gid - (iter_gid * update_nelems);
804+ const std::uint32_t lws = ndit.get_local_range (1 );
785805
786- const size_t src_axis_id0 = axis_gid * updates_per_wi;
806+ const size_t iter_gid = ndit.get_group (0 );
807+
808+ const size_t src_axis_id0 =
809+ ndit.get_group (1 ) * updates_per_wi * lws +
810+ ndit.get_local_id (1 );
787811 const size_t src_iter_id = out_iter_indexer (iter_gid);
788812#pragma unroll
789813 for (nwiT i = 0 ; i < updates_per_wi; ++i) {
790- const size_t src_axis_id = src_axis_id0 + i;
814+ const size_t src_axis_id = src_axis_id0 + i * lws ;
791815 const size_t src_id =
792816 out_indexer (src_axis_id) + src_iter_id;
793817
@@ -797,11 +821,12 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
797821 const size_t scan_id =
798822 scan_axis_id + iter_gid * local_stride;
799823
800- src[src_id] =
824+ const outputT modifier =
801825 (scan_axis_id > 0 )
802- ? scan_op (src[src_id],
803- local_scans[scan_id - 1 ])
804- : scan_op (src[src_id], identity);
826+ ? local_scans[scan_id - 1 ]
827+ : identity;
828+
829+ src[src_id] = scan_op (src[src_id], modifier);
805830 }
806831 }
807832 });
0 commit comments