@@ -134,51 +134,65 @@ template <typename OrthogIndexerT,
134134 typename MaskedDstIndexerT,
135135 typename MaskedRhsIndexerT,
136136 typename dataT,
137- typename indT>
137+ typename indT,
138+ typename LocalAccessorT>
138139struct MaskedPlaceStridedFunctor
139140{
140141 MaskedPlaceStridedFunctor (char *dst_data_p,
141142 const char *cumsum_data_p,
142143 const char *rhs_data_p,
143- size_t orthog_iter_size,
144144 size_t masked_iter_size,
145145 const OrthogIndexerT &orthog_dst_rhs_indexer_,
146146 const MaskedDstIndexerT &masked_dst_indexer_,
147- const MaskedRhsIndexerT &masked_rhs_indexer_)
147+ const MaskedRhsIndexerT &masked_rhs_indexer_,
148+ const LocalAccessorT &lacc_)
148149 : dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
149- orthog_nelems (orthog_iter_size), masked_nelems(masked_iter_size),
150+ masked_nelems (masked_iter_size),
150151 orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
151152 masked_dst_indexer(masked_dst_indexer_),
152- masked_rhs_indexer(masked_rhs_indexer_)
153+ masked_rhs_indexer(masked_rhs_indexer_), lacc(lacc_)
153154 {
155+ static_assert (
156+ std::is_same_v<indT, typename LocalAccessorT::value_type>);
154157 }
155158
156- void operator ()(sycl::id< 1 > idx ) const
159+ void operator ()(sycl::nd_item< 2 > ndit ) const
157160 {
158161 dataT *dst_data = reinterpret_cast <dataT *>(dst_cp);
159162 const indT *cumsum_data = reinterpret_cast <const indT *>(cumsum_cp);
160163 const dataT *rhs_data = reinterpret_cast <const dataT *>(rhs_cp);
161164
162- size_t global_i = idx[0 ];
163- size_t orthog_i = global_i / masked_nelems;
164- size_t masked_i = global_i - masked_nelems * orthog_i;
165-
166- indT current_running_count = cumsum_data[masked_i];
167- bool mask_set =
168- (masked_i == 0 )
169- ? (current_running_count == 1 )
170- : (current_running_count == cumsum_data[masked_i - 1 ] + 1 );
171-
172- // src[i, j] = rhs[cumsum[i] - 1, j] if cumsum[i] == ((i > 0) ?
173- // cumsum[i-1]
174- // + 1 : 1)
175- if (mask_set) {
176- auto orthog_offsets =
177- orthog_dst_rhs_indexer (static_cast <ssize_t >(orthog_i));
178-
179- size_t total_dst_offset = masked_dst_indexer (masked_i) +
180- orthog_offsets.get_first_offset ();
181- size_t total_rhs_offset =
165+ const std::size_t orthog_i = ndit.get_global_id (0 );
166+ const std::size_t group_i = ndit.get_group (1 );
167+ const std::uint32_t l_i = ndit.get_local_id (1 );
168+ const std::uint32_t lws = ndit.get_local_range (1 );
169+
170+ const size_t masked_block_start = group_i * lws;
171+ const size_t masked_i = masked_block_start + l_i;
172+
173+ for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
174+ const size_t offset = masked_block_start + i;
175+ lacc[i] = (offset == 0 ) ? indT (0 )
176+ : (offset - 1 < masked_nelems)
177+ ? cumsum_data[offset - 1 ]
178+ : cumsum_data[masked_nelems - 1 ] + 1 ;
179+ }
180+
181+ sycl::group_barrier (ndit.get_group ());
182+
183+ const indT current_running_count = lacc[l_i + 1 ];
184+ const bool mask_set = (masked_i == 0 )
185+ ? (current_running_count == 1 )
186+ : (current_running_count == lacc[l_i] + 1 );
187+
188+ // src[i, j] = rhs[cumsum[i] - 1, j]
189+ // if cumsum[i] == ((i > 0) ? cumsum[i-1] + 1 : 1)
190+ if (mask_set && (masked_i < masked_nelems)) {
191+ const auto &orthog_offsets = orthog_dst_rhs_indexer (orthog_i);
192+
193+ const size_t total_dst_offset = masked_dst_indexer (masked_i) +
194+ orthog_offsets.get_first_offset ();
195+ const size_t total_rhs_offset =
182196 masked_rhs_indexer (current_running_count - 1 ) +
183197 orthog_offsets.get_second_offset ();
184198
@@ -190,8 +204,7 @@ struct MaskedPlaceStridedFunctor
190204 char *dst_cp = nullptr ;
191205 const char *cumsum_cp = nullptr ;
192206 const char *rhs_cp = nullptr ;
193- size_t orthog_nelems = 0 ;
194- size_t masked_nelems = 0 ;
207+ const size_t masked_nelems = 0 ;
195208 // has nd, shape, dst_strides, rhs_strides for
196209 // dimensions that ARE NOT masked
197210 const OrthogIndexerT orthog_dst_rhs_indexer;
@@ -200,6 +213,7 @@ struct MaskedPlaceStridedFunctor
200213 const MaskedDstIndexerT masked_dst_indexer;
201214 // has 1, rhs_strides for dimensions that ARE masked
202215 const MaskedRhsIndexerT masked_rhs_indexer;
216+ LocalAccessorT lacc;
203217};
204218
205219// ======= Masked extraction ================================
@@ -537,18 +551,35 @@ sycl::event masked_place_all_slices_strided_impl(
537551 const StridedIndexer masked_dst_indexer (nd, 0 , packed_dst_shape_strides);
538552 const Strided1DCyclicIndexer masked_rhs_indexer (0 , rhs_size, rhs_stride);
539553
554+ using KernelName = class masked_place_all_slices_strided_impl_krn <
555+ TwoZeroOffsets_Indexer, StridedIndexer, Strided1DCyclicIndexer, dataT,
556+ indT>;
557+
558+ constexpr std::size_t nominal_lws = 256 ;
559+ const std::size_t masked_extent = iteration_size;
560+ const std::size_t lws = std::min (masked_extent, nominal_lws);
561+
562+ const std::size_t n_groups = (masked_extent + lws - 1 ) / lws;
563+
564+ sycl::range<2 > gRange {1 , n_groups * lws};
565+ sycl::range<2 > lRange{1 , lws};
566+ sycl::nd_range<2 > ndRange{gRange , lRange};
567+
568+ using LocalAccessorT = sycl::local_accessor<indT, 1 >;
569+
540570 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
541571 cgh.depends_on (depends);
542572
543- cgh.parallel_for <class masked_place_all_slices_strided_impl_krn <
544- TwoZeroOffsets_Indexer, StridedIndexer, Strided1DCyclicIndexer,
545- dataT, indT>>(
546- sycl::range<1 >(static_cast <size_t >(iteration_size)),
573+ const std::size_t lacc_size = std::min (masked_extent, lws) + 1 ;
574+ LocalAccessorT lacc (lacc_size, cgh);
575+
576+ cgh.parallel_for <KernelName>(
577+ ndRange,
547578 MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
548- Strided1DCyclicIndexer, dataT, indT>(
549- dst_p, cumsum_p, rhs_p, 1 , iteration_size,
550- orthog_dst_rhs_indexer, masked_dst_indexer ,
551- masked_rhs_indexer));
579+ Strided1DCyclicIndexer, dataT, indT,
580+ LocalAccessorT>(
581+ dst_p, cumsum_p, rhs_p, iteration_size, orthog_dst_rhs_indexer ,
582+ masked_dst_indexer, masked_rhs_indexer, lacc ));
552583 });
553584
554585 return comp_ev;
@@ -615,18 +646,32 @@ sycl::event masked_place_some_slices_strided_impl(
615646 TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer,
616647 dataT, indT>;
617648
618- sycl::range<1 > gRange (static_cast <size_t >(orthog_nelems * masked_nelems));
649+ constexpr std::size_t nominal_lws = 256 ;
650+ const std::size_t orthog_extent = orthog_nelems;
651+ const std::size_t masked_extent = masked_nelems;
652+ const std::size_t lws = std::min (masked_extent, nominal_lws);
653+
654+ const std::size_t n_groups = (masked_extent + lws - 1 ) / lws;
655+
656+ sycl::range<2 > gRange {orthog_extent, n_groups * lws};
657+ sycl::range<2 > lRange{1 , lws};
658+ sycl::nd_range<2 > ndRange{gRange , lRange};
659+
660+ using LocalAccessorT = sycl::local_accessor<indT, 1 >;
619661
620662 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
621663 cgh.depends_on (depends);
622664
665+ const std::size_t lacc_size = std::min (masked_extent, lws) + 1 ;
666+ LocalAccessorT lacc (lacc_size, cgh);
667+
623668 cgh.parallel_for <KernelName>(
624- gRange ,
669+ ndRange ,
625670 MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
626- Strided1DCyclicIndexer, dataT, indT>(
627- dst_p, cumsum_p, rhs_p, orthog_nelems, masked_nelems,
628- orthog_dst_rhs_indexer, masked_dst_indexer ,
629- masked_rhs_indexer));
671+ Strided1DCyclicIndexer, dataT, indT,
672+ LocalAccessorT>(
673+ dst_p, cumsum_p, rhs_p, masked_nelems, orthog_dst_rhs_indexer ,
674+ masked_dst_indexer, masked_rhs_indexer, lacc ));
630675 });
631676
632677 return comp_ev;
0 commit comments