@@ -744,31 +744,58 @@ sycl::event non_zero_indexes_impl(sycl::queue &exec_q,
     const indT1 *cumsum_data = reinterpret_cast<const indT1 *>(cumsum_cp);
     indT2 *indexes_data = reinterpret_cast<indT2 *>(indexes_cp);
 
+    constexpr std::size_t nominal_lws = 256u;
+    const std::size_t masked_extent = iter_size;
+    const std::size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t n_groups = (masked_extent + lws - 1) / lws;
+    sycl::range<1> gRange{n_groups * lws};
+    sycl::range<1> lRange{lws};
+
+    sycl::nd_range<1> ndRange{gRange, lRange};
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
-        cgh.parallel_for<class non_zero_indexes_krn<indT1, indT2>>(
-            sycl::range<1>(iter_size), [=](sycl::id<1> idx) {
-                auto i = idx[0];
 
-                auto cs_curr_val = cumsum_data[i] - 1;
-                auto cs_prev_val = (i > 0) ? cumsum_data[i - 1] : indT1(0);
-                bool cond = (cs_curr_val == cs_prev_val);
+        const std::size_t lacc_size = std::min(lws, masked_extent) + 1;
+        sycl::local_accessor<indT1, 1> lacc(lacc_size, cgh);
+
+        using KernelName = class non_zero_indexes_krn<indT1, indT2>;
 
+        cgh.parallel_for<KernelName>(ndRange, [=](sycl::nd_item<1> ndit) {
+            const std::size_t group_i = ndit.get_group(0);
+            const std::uint32_t l_i = ndit.get_local_id(0);
+            const std::uint32_t lws = ndit.get_local_range(0);
+
+            const std::size_t masked_block_start = group_i * lws;
+
+            for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
+                const size_t offset = masked_block_start + i;
+                lacc[i] = (offset == 0) ? indT1(0)
+                          : (offset - 1 < masked_extent)
+                              ? cumsum_data[offset - 1]
+                              : cumsum_data[masked_extent - 1] + 1;
+            }
+
+            sycl::group_barrier(ndit.get_group());
+
+            const std::size_t i = masked_block_start + l_i;
+            const auto cs_val = lacc[l_i];
+            const bool cond = (lacc[l_i + 1] == cs_val + 1);
+
+            if (cond && (i < masked_extent)) {
                 ssize_t i_ = static_cast<ssize_t>(i);
                 for (int dim = nd; --dim > 0;) {
-                    auto sd = mask_shape[dim];
-                    ssize_t q = i_ / sd;
-                    ssize_t r = (i_ - q * sd);
-                    if (cond) {
-                        indexes_data[cs_curr_val + dim * nz_elems] =
-                            static_cast<indT2>(r);
-                    }
+                    const auto sd = mask_shape[dim];
+                    const ssize_t q = i_ / sd;
+                    const ssize_t r = (i_ - q * sd);
+                    indexes_data[cs_val + dim * nz_elems] =
+                        static_cast<indT2>(r);
                     i_ = q;
                 }
-                if (cond) {
-                    indexes_data[cs_curr_val] = static_cast<indT2>(i_);
-                }
-            });
+                indexes_data[cs_val] = static_cast<indT2>(i_);
+            }
+        });
     });
 
     return comp_ev;
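The nd_range version keeps the same output layout and index arithmetic as the original one-work-item-per-element kernel: the exclusive cumulative-sum value at a flat position is the output slot, and the flat index is unravelled into per-dimension coordinates written at a stride of nz_elems. Below is a minimal host-side C++ sketch, not part of the PR, that mirrors that arithmetic; names such as mask, shape, and indexes are illustrative only and may be useful for sanity-checking the kernel's results.

// Host-side mimic of the nonzero-index extraction done by the kernel above.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    // 2x3 boolean mask and its inclusive cumulative sum over the flat layout
    const std::vector<std::int64_t> shape{2, 3};
    const std::vector<int> mask{0, 1, 0, 1, 1, 0};
    std::vector<int> cumsum(mask.size());
    int running = 0;
    for (std::size_t i = 0; i < mask.size(); ++i) {
        running += mask[i];
        cumsum[i] = running;
    }

    const int nz_elems = cumsum.back();
    const int nd = static_cast<int>(shape.size());
    // Same layout as indexes_data in the kernel: the dim-th coordinate of the
    // k-th nonzero element is stored at indexes[k + dim * nz_elems]
    std::vector<std::int64_t> indexes(static_cast<std::size_t>(nz_elems) * nd);

    for (std::size_t i = 0; i < mask.size(); ++i) {
        // exclusive-scan value is the output slot; cumsum rises by 1 at nonzeros
        const int cs_prev = (i > 0) ? cumsum[i - 1] : 0;
        if (cumsum[i] != cs_prev + 1)
            continue;
        std::int64_t i_ = static_cast<std::int64_t>(i);
        for (int dim = nd; --dim > 0;) {
            const auto sd = shape[dim];
            const auto q = i_ / sd;
            indexes[cs_prev + dim * nz_elems] = i_ - q * sd; // remainder = coordinate
            i_ = q;
        }
        indexes[cs_prev] = i_; // leading dimension
    }

    for (int k = 0; k < nz_elems; ++k)
        std::cout << "(" << indexes[k] << ", " << indexes[k + nz_elems] << ")\n";
    // Expected output: (0, 1) (1, 0) (1, 1)
    return 0;
}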