// NOTE: this text is a rendered diff. Lines tagged "N-" belong to the removed
// version, lines tagged "N+" to the added version, and lines carrying two
// fused numbers are unchanged context.
//
// MaskedExtractStridedFunctor (added version): SYCL kernel functor invoked
// with a 2D nd_item. Dimension 0 selects the orthogonal (non-masked) index and
// dimension 1 walks the masked axis in work-group sized blocks. Each
// work-group first stages a window of the cumulative-sum array in `lacc`, then
// every work-item decides from two adjacent staged values whether its mask
// element is set and, if so, computes the source and destination offsets.
//
// Template parameters (as used in the visible code):
//   OrthogIndexerT    - callable on orthog_i; result exposes
//                       get_first_offset() / get_second_offset()
//   MaskedSrcIndexerT - callable on masked_i, yields a source offset
//   MaskedDstIndexerT - callable on (running count - 1), yields a dst offset
//   dataT             - element type src_cp / dst_cp are reinterpreted as
//   indT              - element type cumsum_cp is reinterpreted as
//   LocalAccessorT    - type of `lacc`; its value_type must be indT
@@ -48,51 +48,65 @@ template <typename OrthogIndexerT,
4848 typename MaskedSrcIndexerT,
4949 typename MaskedDstIndexerT,
5050 typename dataT,
51- typename indT>
51+ typename indT,
52+ typename LocalAccessorT>
5253struct MaskedExtractStridedFunctor
5354{
// Constructor: stores the raw byte pointers, the number of elements along
// the masked axis, the three indexers and (added version) the accessor used
// as per-work-group scratch. The removed version also took orthog_iter_size.
5455 MaskedExtractStridedFunctor (const char *src_data_p,
5556 const char *cumsum_data_p,
5657 char *dst_data_p,
57- size_t orthog_iter_size,
5858 size_t masked_iter_size,
5959 const OrthogIndexerT &orthog_src_dst_indexer_,
6060 const MaskedSrcIndexerT &masked_src_indexer_,
61- const MaskedDstIndexerT &masked_dst_indexer_)
61+ const MaskedDstIndexerT &masked_dst_indexer_,
62+ const LocalAccessorT &lacc_)
6263 : src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
63- orthog_nelems (orthog_iter_size), masked_nelems(masked_iter_size),
64+ masked_nelems (masked_iter_size),
6465 orthog_src_dst_indexer(orthog_src_dst_indexer_),
6566 masked_src_indexer(masked_src_indexer_),
66- masked_dst_indexer(masked_dst_indexer_)
67+ masked_dst_indexer(masked_dst_indexer_), lacc(lacc_)
6768 {
// Compile-time check: the scratch accessor must hold the same type that
// the cumulative-sum buffer is read as, since cumsum values are copied
// into it unchanged.
69+ static_assert (
70+ std::is_same_v<indT, typename LocalAccessorT::value_type>);
6871 }
6972
70- void operator ()(sycl::id< 1 > idx ) const
// Kernel body (added version): one call per work-item of a 2D nd_range.
73+ void operator ()(sycl::nd_item< 2 > ndit ) const
7174 {
7275 const dataT *src_data = reinterpret_cast <const dataT *>(src_cp);
7376 dataT *dst_data = reinterpret_cast <dataT *>(dst_cp);
7477 const indT *cumsum_data = reinterpret_cast <const indT *>(cumsum_cp);
7578
76- size_t global_i = idx[0 ];
77- size_t orthog_i = global_i / masked_nelems;
78- size_t masked_i = global_i - masked_nelems * orthog_i;
// Dimension 0 of the launch grid is the orthogonal index; dimension 1 is
// split into groups of `lws` work-items along the masked axis.
79+ const size_t orthog_i = ndit.get_global_id (0 );
80+ const size_t group_i = ndit.get_group (1 );
81+ const std::uint32_t l_i = ndit.get_local_id (1 );
82+ const std::uint32_t lws = ndit.get_local_range (1 );
7983
80- indT current_running_count = cumsum_data[masked_i];
81- bool mask_set =
82- (masked_i == 0 )
83- ? (current_running_count == 1 )
84- : (current_running_count == cumsum_data[masked_i - 1 ] + 1 );
// First masked index handled by this work-group, and this work-item's own
// masked index. masked_i may be >= masked_nelems in the last group; that
// case is filtered out by the bounds check before the offsets are used.
84+ const size_t masked_block_start = group_i * lws;
85+ const size_t masked_i = masked_block_start + l_i;
8586
86- // dst[cumsum[i], j] - 1 = src[i, j] if cumsum[i] == ((i > 0) ?
87- // cumsum[i-1]
88- // + 1 : 1)
89- if (mask_set) {
90- auto orthog_offsets =
91- orthog_src_dst_indexer (static_cast <ssize_t >(orthog_i));
// Cooperative fill of the scratch window: work-items stride over lacc by
// lws so that lacc[i] ends up holding cumsum[masked_block_start + i - 1].
// Two boundary values are synthesized instead of read:
//   - the slot "before element 0" is 0, and
//   - slots past the end hold cumsum[last] + 1.
// NOTE(review): the past-the-end branch reads cumsum_data[masked_nelems - 1],
// which assumes masked_nelems > 0 - confirm callers guarantee this.
87+ for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
88+ const size_t offset = masked_block_start + i;
89+ lacc[i] = (offset == 0 ) ? indT (0 )
90+ : (offset - 1 < masked_nelems)
91+ ? cumsum_data[offset - 1 ]
92+ : cumsum_data[masked_nelems - 1 ] + 1 ;
93+ }
9294
93- size_t total_src_offset = masked_src_indexer (masked_i) +
94- orthog_offsets.get_first_offset ();
95- size_t total_dst_offset =
// Every work-item of the group must finish writing lacc before any
// work-item reads its neighbours' entries below.
95+ sycl::group_barrier (ndit.get_group ());
96+
// lacc[l_i + 1] is the staged cumsum value for masked_i and lacc[l_i] the
// value for masked_i - 1 (0 for the very first element), so the mask is
// set exactly where the running count grows by one. Reading lacc[l_i + 1]
// relies on lacc holding at least lws + 1 entries; the launch code visible
// in this diff allocates min(lws, extent) + 1.
97+ const indT current_running_count = lacc[l_i + 1 ];
98+ const bool mask_set = (masked_i == 0 )
99+ ? (current_running_count == 1 )
100+ : (current_running_count == lacc[l_i] + 1 );
101+
102+ // dst[cumsum[i] - 1, j] = src[i, j]
103+ // if cumsum[i] == ((i > 0) ? cumsum[i-1] + 1 : 1)
// The extra bounds test discards padding work-items of the last group.
104+ if (mask_set && (masked_i < masked_nelems)) {
105+ const auto &orthog_offsets = orthog_src_dst_indexer (orthog_i);
106+
107+ const size_t total_src_offset = masked_src_indexer (masked_i) +
108+ orthog_offsets.get_first_offset ();
109+ const size_t total_dst_offset =
96110 masked_dst_indexer (current_running_count - 1 ) +
97111 orthog_offsets.get_second_offset ();
98112
// NOTE(review): the statement that uses total_src_offset / total_dst_offset
// and the closing braces of operator() are elided from this hunk; presumably
// it copies src_data[total_src_offset] into dst_data[total_dst_offset] -
// confirm against the full file.
@@ -104,8 +118,7 @@ struct MaskedExtractStridedFunctor
104118 const char *src_cp = nullptr ;
105119 const char *cumsum_cp = nullptr ;
106120 char *dst_cp = nullptr ;
107- size_t orthog_nelems = 0 ;
108- size_t masked_nelems = 0 ;
// Number of elements along the masked axis; the added version makes it const
// and drops orthog_nelems.
121+ const size_t masked_nelems = 0 ;
109122 // has nd, shape, src_strides, dst_strides for
110123 // dimensions that ARE NOT masked
111124 const OrthogIndexerT orthog_src_dst_indexer;
@@ -114,6 +127,7 @@ struct MaskedExtractStridedFunctor
114127 const MaskedSrcIndexerT masked_src_indexer;
115128 // has 1, dst_strides for dimensions that ARE masked
116129 const MaskedDstIndexerT masked_dst_indexer;
// Per-work-group scratch window of cumulative-sum values (see operator()).
130+ LocalAccessorT lacc;
117131};
118132
119133template <typename OrthogIndexerT,
@@ -190,8 +204,72 @@ struct MaskedPlaceStridedFunctor
190204
191205// ======= Masked extraction ================================
192206
193- template <typename OrthoIndexerT,
194- typename MaskedSrcIndexerT,
// Forward declaration only: this class template is never defined here; it is
// used as the SYCL kernel name (see the KernelName alias in
// masked_extract_all_slices_contig_impl).
207+ template <typename MaskedDstIndexerT, typename dataT, typename indT>
208+ class masked_extract_all_slices_contig_impl_krn ;
209+
// Function-pointer type whose parameter list matches
// masked_extract_all_slices_contig_impl<dataT, indT>. In positional order:
// execution queue, iteration_size, src_p, cumsum_p, dst_p, dst_size,
// dst_stride, and the events to wait on. Returns the submitted kernel's event.
210+ typedef sycl::event (*masked_extract_all_slices_contig_impl_fn_ptr_t )(
211+ sycl::queue &,
212+ ssize_t ,
213+ const char *,
214+ const char *,
215+ char *,
216+ ssize_t ,
217+ ssize_t ,
218+ const std::vector<sycl::event> &);
219+
219+
// Submits a MaskedExtractStridedFunctor kernel for the case where no source
// indexing is applied (NoOpIndexer for the masked source offset,
// TwoZeroOffsets_Indexer for the orthogonal offsets) and the destination is a
// 1D strided array.
//
// Parameters:
//   exec_q         - queue the kernel is submitted to
//   iteration_size - number of elements along the masked axis
//   src_p          - source data; the functor reinterprets it as dataT
//   cumsum_p       - cumulative sum of the mask; reinterpreted as indT
//   dst_p          - destination data; reinterpreted as dataT
//   dst_size       - size passed to the destination Strided1DIndexer
//   dst_stride     - step passed to the destination Strided1DIndexer
//   depends        - events the kernel must wait for
// Returns the event of the submitted kernel.
220+ template <typename dataT, typename indT>
221+ sycl::event masked_extract_all_slices_contig_impl (
222+ sycl::queue &exec_q,
223+ ssize_t iteration_size,
224+ const char *src_p,
225+ const char *cumsum_p,
226+ char *dst_p,
227+ ssize_t dst_size, // dst is 1D
228+ ssize_t dst_stride,
229+ const std::vector<sycl::event> &depends = {})
230+ {
231+ constexpr TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
232+
233+ constexpr NoOpIndexer masked_src_indexer{};
234+ const Strided1DIndexer masked_dst_indexer (/* size */ dst_size,
235+ /* step */ dst_stride);
236+
237+ using KernelName =
238+ class masked_extract_all_slices_contig_impl_krn <Strided1DIndexer, dataT,
239+ indT>;
240+
241+ using LocalAccessorT = sycl::local_accessor<indT, 1 >;
242+ using Impl =
243+ struct MaskedExtractStridedFunctor <TwoZeroOffsets_Indexer, NoOpIndexer,
244+ Strided1DIndexer, dataT, indT,
245+ LocalAccessorT>;
246+
// Work-group size is capped at 256 and never exceeds the masked extent; the
// number of groups is the extent divided by lws, rounded up.
// NOTE(review): if iteration_size is 0 then lws is 0 and the n_groups
// computation divides by zero - confirm callers never pass an empty mask.
247+ constexpr std::size_t nominal_lws = 256 ;
248+ const std::size_t masked_extent = iteration_size;
249+ const std::size_t lws = std::min (masked_extent, nominal_lws);
250+ const std::size_t n_groups = (iteration_size + lws - 1 ) / lws;
251+
// Dimension 0 has extent 1 (a single orthogonal slice); dimension 1 is padded
// up to a whole number of work-groups. The functor's bounds check discards the
// padding work-items.
252+ sycl::range<2 > gRange {1 , n_groups * lws};
253+ sycl::range<2 > lRange{1 , lws};
254+
255+ sycl::nd_range<2 > ndRange (gRange , lRange);
256+
257+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
258+ cgh.depends_on (depends);
259+
// One more entry than work-items in a group: each work-item reads its own
// staged cumsum value and its predecessor's. Since lws <= masked_extent by
// construction, min(lws, masked_extent) equals lws here.
260+ const std::size_t lacc_size = std::min (lws, masked_extent) + 1 ;
261+ LocalAccessorT lacc (lacc_size, cgh);
262+
263+ cgh.parallel_for <KernelName>(
264+ ndRange,
265+ Impl (src_p, cumsum_p, dst_p, masked_extent, orthog_src_dst_indexer,
266+ masked_src_indexer, masked_dst_indexer, lacc));
267+ });
268+
269+ return comp_ev;
270+ }
271+
272+ template <typename MaskedSrcIndexerT,
195273 typename MaskedDstIndexerT,
196274 typename dataT,
197275 typename indT>
@@ -223,11 +301,6 @@ sycl::event masked_extract_all_slices_strided_impl(
223301 ssize_t dst_stride,
224302 const std::vector<sycl::event> &depends = {})
225303{
226- // using MaskedExtractStridedFunctor;
227- // using Strided1DIndexer;
228- // using StridedIndexer;
229- // using TwoZeroOffsets_Indexer;
230-
231304 constexpr TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
232305
233306 /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
@@ -236,18 +309,35 @@ sycl::event masked_extract_all_slices_strided_impl(
236309 const Strided1DIndexer masked_dst_indexer (/* size */ dst_size,
237310 /* step */ dst_stride);
238311
312+ using KernelName = class masked_extract_all_slices_strided_impl_krn <
313+ StridedIndexer, Strided1DIndexer, dataT, indT>;
314+
315+ using LocalAccessorT = sycl::local_accessor<indT, 1 >;
316+ using Impl =
317+ struct MaskedExtractStridedFunctor <TwoZeroOffsets_Indexer,
318+ StridedIndexer, Strided1DIndexer,
319+ dataT, indT, LocalAccessorT>;
320+
321+ constexpr std::size_t nominal_lws = 256 ;
322+ const std::size_t masked_nelems = iteration_size;
323+ const std::size_t lws = std::min (masked_nelems, nominal_lws);
324+ const std::size_t n_groups = (masked_nelems + lws - 1 ) / lws;
325+
326+ sycl::range<2 > gRange {1 , n_groups * lws};
327+ sycl::range<2 > lRange{1 , lws};
328+
329+ sycl::nd_range<2 > ndRange (gRange , lRange);
330+
239331 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
240332 cgh.depends_on (depends);
241333
242- cgh.parallel_for <class masked_extract_all_slices_strided_impl_krn <
243- TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer, dataT,
244- indT>>(
245- sycl::range<1 >(static_cast <size_t >(iteration_size)),
246- MaskedExtractStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
247- Strided1DIndexer, dataT, indT>(
248- src_p, cumsum_p, dst_p, 1 , iteration_size,
249- orthog_src_dst_indexer, masked_src_indexer,
250- masked_dst_indexer));
334+ const std::size_t lacc_size = std::min (lws, masked_nelems) + 1 ;
335+ LocalAccessorT lacc (lacc_size, cgh);
336+
337+ cgh.parallel_for <KernelName>(
338+ ndRange,
339+ Impl (src_p, cumsum_p, dst_p, iteration_size, orthog_src_dst_indexer,
340+ masked_src_indexer, masked_dst_indexer, lacc));
251341 });
252342
253343 return comp_ev;
@@ -299,11 +389,6 @@ sycl::event masked_extract_some_slices_strided_impl(
299389 ssize_t masked_dst_stride,
300390 const std::vector<sycl::event> &depends = {})
301391{
302- // using MaskedExtractStridedFunctor;
303- // using Strided1DIndexer;
304- // using StridedIndexer;
305- // using TwoOffsets_StridedIndexer;
306-
307392 const TwoOffsets_StridedIndexer orthog_src_dst_indexer{
308393 orthog_nd, ortho_src_offset, ortho_dst_offset,
309394 packed_ortho_src_dst_shape_strides};
@@ -313,24 +398,63 @@ sycl::event masked_extract_some_slices_strided_impl(
313398 const Strided1DIndexer masked_dst_indexer{/* size */ masked_dst_size,
314399 /* step */ masked_dst_stride};
315400
401+ using KernelName = class masked_extract_some_slices_strided_impl_krn <
402+ TwoOffsets_StridedIndexer, StridedIndexer, Strided1DIndexer, dataT,
403+ indT>;
404+
405+ using LocalAccessorT = sycl::local_accessor<indT, 1 >;
406+ using Impl =
407+ struct MaskedExtractStridedFunctor <TwoOffsets_StridedIndexer,
408+ StridedIndexer, Strided1DIndexer,
409+ dataT, indT, LocalAccessorT>;
410+
411+ const size_t nominal_lws = 256 ;
412+ const std::size_t masked_extent = masked_nelems;
413+ const size_t lws = std::min (masked_extent, nominal_lws);
414+ const size_t n_groups = ((masked_extent + lws - 1 ) / lws);
415+ const size_t orthog_extent = static_cast <size_t >(orthog_nelems);
416+
417+ sycl::range<2 > gRange {orthog_extent, n_groups * lws};
418+ sycl::range<2 > lRange{1 , lws};
419+
420+ sycl::nd_range<2 > ndRange (gRange , lRange);
421+
316422 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
317423 cgh.depends_on (depends);
318424
319- cgh.parallel_for <class masked_extract_some_slices_strided_impl_krn <
320- TwoOffsets_StridedIndexer, StridedIndexer, Strided1DIndexer, dataT,
321- indT>>(
322- sycl::range<1 >(static_cast <size_t >(orthog_nelems * masked_nelems)),
323- MaskedExtractStridedFunctor<TwoOffsets_StridedIndexer,
324- StridedIndexer, Strided1DIndexer, dataT,
325- indT>(
326- src_p, cumsum_p, dst_p, orthog_nelems, masked_nelems,
327- orthog_src_dst_indexer, masked_src_indexer,
328- masked_dst_indexer));
425+ const std::size_t lacc_size =
426+ std::min<std::size_t >(lws, masked_extent) + 1 ;
427+ LocalAccessorT lacc (lacc_size, cgh);
428+
429+ cgh.parallel_for <KernelName>(
430+ ndRange,
431+ Impl (src_p, cumsum_p, dst_p, masked_nelems, orthog_src_dst_indexer,
432+ masked_src_indexer, masked_dst_indexer, lacc));
329433 });
330434
331435 return comp_ev;
332436}
333437
// Factory whose get() returns masked_extract_all_slices_contig_impl
// instantiated with data type T and std::int32_t as the cumulative-sum
// (index) type, converted to fnT.
// NOTE(review): fnT is presumably
// masked_extract_all_slices_contig_impl_fn_ptr_t - confirm at the use site.
438+ template <typename fnT, typename T>
439+ struct MaskExtractAllSlicesContigFactoryForInt32
440+ {
441+ fnT get ()
442+ {
443+ fnT fn = masked_extract_all_slices_contig_impl<T, std::int32_t >;
444+ return fn;
445+ }
446+ };
447+
// Factory whose get() returns masked_extract_all_slices_contig_impl
// instantiated with data type T and std::int64_t as the cumulative-sum
// (index) type, converted to fnT.
// NOTE(review): fnT is presumably
// masked_extract_all_slices_contig_impl_fn_ptr_t - confirm at the use site.
448+ template <typename fnT, typename T>
449+ struct MaskExtractAllSlicesContigFactoryForInt64
450+ {
451+ fnT get ()
452+ {
453+ fnT fn = masked_extract_all_slices_contig_impl<T, std::int64_t >;
454+ return fn;
455+ }
456+ };
457+
334458template <typename fnT, typename T>
335459struct MaskExtractAllSlicesStridedFactoryForInt32
336460{
@@ -487,13 +611,17 @@ sycl::event masked_place_some_slices_strided_impl(
487611 const Strided1DCyclicIndexer masked_rhs_indexer{0 , masked_rhs_size,
488612 masked_rhs_stride};
489613
614+ using KernelName = class masked_place_some_slices_strided_impl_krn <
615+ TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer,
616+ dataT, indT>;
617+
618+ sycl::range<1 > gRange (static_cast <size_t >(orthog_nelems * masked_nelems));
619+
490620 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
491621 cgh.depends_on (depends);
492622
493- cgh.parallel_for <class masked_place_some_slices_strided_impl_krn <
494- TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer,
495- dataT, indT>>(
496- sycl::range<1 >(static_cast <size_t >(orthog_nelems * masked_nelems)),
623+ cgh.parallel_for <KernelName>(
624+ gRange ,
497625 MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
498626 Strided1DCyclicIndexer, dataT, indT>(
499627 dst_p, cumsum_p, rhs_p, orthog_nelems, masked_nelems,
0 commit comments