@@ -52,15 +52,15 @@ template <typename OrthogIndexerT,
           typename LocalAccessorT>
 struct MaskedExtractStridedFunctor
 {
-    MaskedExtractStridedFunctor(const char *src_data_p,
-                                const char *cumsum_data_p,
-                                char *dst_data_p,
+    MaskedExtractStridedFunctor(const dataT *src_data_p,
+                                const indT *cumsum_data_p,
+                                dataT *dst_data_p,
                                 size_t masked_iter_size,
                                 const OrthogIndexerT &orthog_src_dst_indexer_,
                                 const MaskedSrcIndexerT &masked_src_indexer_,
                                 const MaskedDstIndexerT &masked_dst_indexer_,
                                 const LocalAccessorT &lacc_)
-        : src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
+        : src(src_data_p), cumsum(cumsum_data_p), dst(dst_data_p),
           masked_nelems(masked_iter_size),
           orthog_src_dst_indexer(orthog_src_dst_indexer_),
           masked_src_indexer(masked_src_indexer_),
@@ -72,24 +72,19 @@ struct MaskedExtractStridedFunctor
 
     void operator()(sycl::nd_item<2> ndit) const
     {
-        const dataT *src_data = reinterpret_cast<const dataT *>(src_cp);
-        dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
-        const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
-
-        const size_t orthog_i = ndit.get_global_id(0);
-        const size_t group_i = ndit.get_group(1);
+        const std::size_t orthog_i = ndit.get_global_id(0);
         const std::uint32_t l_i = ndit.get_local_id(1);
         const std::uint32_t lws = ndit.get_local_range(1);
 
-        const size_t masked_block_start = group_i * lws;
-        const size_t masked_i = masked_block_start + l_i;
+        const std::size_t masked_i = ndit.get_global_id(1);
+        const std::size_t masked_block_start = masked_i - l_i;
 
+        const std::size_t max_offset = masked_nelems + 1;
         for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
             const size_t offset = masked_block_start + i;
-            lacc[i] = (offset == 0) ? indT(0)
-                      : (offset - 1 < masked_nelems)
-                          ? cumsum_data[offset - 1]
-                          : cumsum_data[masked_nelems - 1] + 1;
+            lacc[i] = (offset == 0)          ? indT(0)
+                      : (offset < max_offset) ? cumsum[offset - 1]
+                                              : cumsum[masked_nelems - 1] + 1;
         }
 
         sycl::group_barrier(ndit.get_group());
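Note on the hunk above: each work-group stages lws + 1 boundary values of the mask's cumulative sum into the local accessor lacc, and an element is later treated as selected when its running count exceeds that of the previous position. The rewrite reads masked_i straight from get_global_id(1) (so masked_block_start is simply masked_i - l_i) and hoists masked_nelems + 1 into max_offset, replacing the old offset - 1 < masked_nelems test. A minimal host-side C++ sketch of the boundary-value computation follows; the cumsum contents, lws and block start are made-up illustrative inputs, not taken from the patch.

    // Host-side sketch with illustrative inputs: mirrors how one work-group
    // fills lacc with cumulative-sum boundary values (see the loop above).
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main()
    {
        using indT = std::int64_t;
        // cumulative sum of an example mask {1, 0, 1, 1, 0}
        const std::vector<indT> cumsum{1, 1, 2, 3, 3};
        const std::size_t masked_nelems = cumsum.size();
        const std::size_t lws = 4;                // local range along dim 1
        const std::size_t masked_block_start = 4; // second work-group
        const std::size_t max_offset = masked_nelems + 1;

        std::vector<indT> lacc(lws + 1); // lws + 1 boundary values per group
        for (std::size_t i = 0; i < lacc.size(); ++i) {
            const std::size_t offset = masked_block_start + i;
            lacc[i] = (offset == 0)          ? indT(0)
                      : (offset < max_offset) ? cumsum[offset - 1]
                                              : cumsum[masked_nelems - 1] + 1;
        }
        for (indT v : lacc) {
            std::cout << v << ' '; // prints: 3 3 4 4 4
        }
        std::cout << '\n';
        return 0;
    }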
@@ -110,14 +105,14 @@ struct MaskedExtractStridedFunctor
                 masked_dst_indexer(current_running_count - 1) +
                 orthog_offsets.get_second_offset();
 
-            dst_data[total_dst_offset] = src_data[total_src_offset];
+            dst[total_dst_offset] = src[total_src_offset];
         }
     }
 
 private:
-    const char *src_cp = nullptr;
-    const char *cumsum_cp = nullptr;
-    char *dst_cp = nullptr;
+    const dataT *src = nullptr;
+    const indT *cumsum = nullptr;
+    dataT *dst = nullptr;
     const size_t masked_nelems = 0;
     // has nd, shape, src_strides, dst_strides for
     // dimensions that ARE NOT masked
@@ -138,15 +133,15 @@ template <typename OrthogIndexerT,
           typename LocalAccessorT>
 struct MaskedPlaceStridedFunctor
 {
-    MaskedPlaceStridedFunctor(char *dst_data_p,
-                              const char *cumsum_data_p,
-                              const char *rhs_data_p,
+    MaskedPlaceStridedFunctor(dataT *dst_data_p,
+                              const indT *cumsum_data_p,
+                              const dataT *rhs_data_p,
                               size_t masked_iter_size,
                               const OrthogIndexerT &orthog_dst_rhs_indexer_,
                               const MaskedDstIndexerT &masked_dst_indexer_,
                               const MaskedRhsIndexerT &masked_rhs_indexer_,
                               const LocalAccessorT &lacc_)
-        : dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
+        : dst(dst_data_p), cumsum(cumsum_data_p), rhs(rhs_data_p),
           masked_nelems(masked_iter_size),
           orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
           masked_dst_indexer(masked_dst_indexer_),
@@ -158,24 +153,19 @@ struct MaskedPlaceStridedFunctor
 
     void operator()(sycl::nd_item<2> ndit) const
     {
-        dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
-        const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
-        const dataT *rhs_data = reinterpret_cast<const dataT *>(rhs_cp);
-
         const std::size_t orthog_i = ndit.get_global_id(0);
-        const std::size_t group_i = ndit.get_group(1);
         const std::uint32_t l_i = ndit.get_local_id(1);
         const std::uint32_t lws = ndit.get_local_range(1);
 
-        const size_t masked_block_start = group_i * lws;
-        const size_t masked_i = masked_block_start + l_i;
+        const size_t masked_i = ndit.get_global_id(1);
+        const size_t masked_block_start = masked_i - l_i;
 
+        const std::size_t max_offset = masked_nelems + 1;
         for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
             const size_t offset = masked_block_start + i;
-            lacc[i] = (offset == 0) ? indT(0)
-                      : (offset - 1 < masked_nelems)
-                          ? cumsum_data[offset - 1]
-                          : cumsum_data[masked_nelems - 1] + 1;
+            lacc[i] = (offset == 0)          ? indT(0)
+                      : (offset < max_offset) ? cumsum[offset - 1]
+                                              : cumsum[masked_nelems - 1] + 1;
         }
 
         sycl::group_barrier(ndit.get_group());
@@ -196,14 +186,14 @@ struct MaskedPlaceStridedFunctor
                 masked_rhs_indexer(current_running_count - 1) +
                 orthog_offsets.get_second_offset();
 
-            dst_data[total_dst_offset] = rhs_data[total_rhs_offset];
+            dst[total_dst_offset] = rhs[total_rhs_offset];
         }
     }
 
 private:
-    char *dst_cp = nullptr;
-    const char *cumsum_cp = nullptr;
-    const char *rhs_cp = nullptr;
+    dataT *dst = nullptr;
+    const indT *cumsum = nullptr;
+    const dataT *rhs = nullptr;
     const size_t masked_nelems = 0;
     // has nd, shape, dst_strides, rhs_strides for
     // dimensions that ARE NOT masked
@@ -218,6 +208,30 @@ struct MaskedPlaceStridedFunctor
 
 // ======= Masked extraction ================================
 
+namespace
+{
+
+template <std::size_t I, std::size_t... IR>
+std::size_t _get_lws_impl(std::size_t n)
+{
+    if constexpr (sizeof...(IR) == 0) {
+        return I;
+    }
+    else {
+        return (n < I) ? _get_lws_impl<IR...>(n) : I;
+    }
+}
+
+std::size_t get_lws(std::size_t n)
+{
+    constexpr std::size_t lws0 = 256u;
+    constexpr std::size_t lws1 = 128u;
+    constexpr std::size_t lws2 = 64u;
+    return _get_lws_impl<lws0, lws1, lws2>(n);
+}
+
+} // end of anonymous namespace
+
 template <typename MaskedDstIndexerT, typename dataT, typename indT>
 class masked_extract_all_slices_contig_impl_krn;
 
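The helper introduced above chooses the work-group size along the masked dimension from the candidate list 256, 128, 64: the first candidate that does not exceed the extent n is returned, and the smallest candidate is the fallback when n is smaller than all of them. The recursion over the non-type parameter pack is resolved with if constexpr, so it boils down to at most two comparisons at run time. Below is a standalone sketch that mirrors this selection logic; the name select_lws and the sample extents are illustrative only.

    // Standalone sketch mirroring the get_lws selection above (illustration only).
    #include <cstddef>
    #include <iostream>

    template <std::size_t I, std::size_t... IR>
    std::size_t select_lws(std::size_t n)
    {
        if constexpr (sizeof...(IR) == 0) {
            return I; // last candidate is the fallback for small extents
        }
        else {
            // keep I when the extent fills at least one group of size I
            return (n < I) ? select_lws<IR...>(n) : I;
        }
    }

    int main()
    {
        std::cout << select_lws<256u, 128u, 64u>(1000) << '\n'; // 256
        std::cout << select_lws<256u, 128u, 64u>(200) << '\n';  // 128
        std::cout << select_lws<256u, 128u, 64u>(100) << '\n';  // 64
        std::cout << select_lws<256u, 128u, 64u>(10) << '\n';   // 64
        return 0;
    }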
@@ -258,26 +272,31 @@ sycl::event masked_extract_all_slices_contig_impl(
                                              Strided1DIndexer, dataT, indT,
                                              LocalAccessorT>;
 
-    constexpr std::size_t nominal_lws = 256;
     const std::size_t masked_extent = iteration_size;
-    const std::size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t lws = get_lws(masked_extent);
+
     const std::size_t n_groups = (iteration_size + lws - 1) / lws;
 
     sycl::range<2> gRange{1, n_groups * lws};
     sycl::range<2> lRange{1, lws};
 
     sycl::nd_range<2> ndRange(gRange, lRange);
 
+    const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
         const std::size_t lacc_size = std::min(lws, masked_extent) + 1;
         LocalAccessorT lacc(lacc_size, cgh);
 
         cgh.parallel_for<KernelName>(
-            ndRange,
-            Impl(src_p, cumsum_p, dst_p, masked_extent, orthog_src_dst_indexer,
-                 masked_src_indexer, masked_dst_indexer, lacc));
+            ndRange, Impl(src_tp, cumsum_tp, dst_tp, masked_extent,
+                          orthog_src_dst_indexer, masked_src_indexer,
+                          masked_dst_indexer, lacc));
     });
 
     return comp_ev;
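In this submission path (and in the strided and some-slices variants below), the reinterpret_cast from the untyped char * USM pointers now happens once on the host, and the functor is constructed with typed dataT * / indT * pointers, so operator() no longer casts on the device. A minimal self-contained sketch of that pattern follows; the functor name copy_functor, the element type and the queue setup are assumptions for illustration, not code from this file.

    // Sketch: cast untyped USM bytes to typed pointers at the submission site,
    // then hand the typed pointers to a device-copyable kernel functor.
    #include <sycl/sycl.hpp>
    #include <cstddef>
    #include <cstdint>

    template <typename T> struct copy_functor
    {
        const T *src = nullptr;
        T *dst = nullptr;
        void operator()(sycl::id<1> i) const { dst[i] = src[i]; }
    };

    int main()
    {
        sycl::queue q;
        constexpr std::size_t n = 16;
        const std::size_t n_bytes = n * sizeof(std::int64_t);
        // allocations passed around as untyped char * elsewhere
        char *src_raw = sycl::malloc_device<char>(n_bytes, q);
        char *dst_raw = sycl::malloc_device<char>(n_bytes, q);
        q.memset(src_raw, 0, n_bytes).wait();

        // cast once on the host; the device functor stores typed pointers
        const std::int64_t *src_tp = reinterpret_cast<const std::int64_t *>(src_raw);
        std::int64_t *dst_tp = reinterpret_cast<std::int64_t *>(dst_raw);

        q.parallel_for(sycl::range<1>{n},
                       copy_functor<std::int64_t>{src_tp, dst_tp})
            .wait();

        sycl::free(src_raw, q);
        sycl::free(dst_raw, q);
        return 0;
    }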
@@ -332,26 +351,31 @@ sycl::event masked_extract_all_slices_strided_impl(
                                              StridedIndexer, Strided1DIndexer,
                                              dataT, indT, LocalAccessorT>;
 
-    constexpr std::size_t nominal_lws = 256;
     const std::size_t masked_nelems = iteration_size;
-    const std::size_t lws = std::min(masked_nelems, nominal_lws);
+
+    const std::size_t lws = get_lws(masked_nelems);
+
     const std::size_t n_groups = (masked_nelems + lws - 1) / lws;
 
     sycl::range<2> gRange{1, n_groups * lws};
     sycl::range<2> lRange{1, lws};
 
     sycl::nd_range<2> ndRange(gRange, lRange);
 
+    const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
         const std::size_t lacc_size = std::min(lws, masked_nelems) + 1;
         LocalAccessorT lacc(lacc_size, cgh);
 
         cgh.parallel_for<KernelName>(
-            ndRange,
-            Impl(src_p, cumsum_p, dst_p, iteration_size, orthog_src_dst_indexer,
-                 masked_src_indexer, masked_dst_indexer, lacc));
+            ndRange, Impl(src_tp, cumsum_tp, dst_tp, iteration_size,
+                          orthog_src_dst_indexer, masked_src_indexer,
+                          masked_dst_indexer, lacc));
     });
 
     return comp_ev;
@@ -422,9 +446,10 @@ sycl::event masked_extract_some_slices_strided_impl(
                                              StridedIndexer, Strided1DIndexer,
                                              dataT, indT, LocalAccessorT>;
 
-    const size_t nominal_lws = 256;
     const std::size_t masked_extent = masked_nelems;
-    const size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t lws = get_lws(masked_extent);
+
     const size_t n_groups = ((masked_extent + lws - 1) / lws);
     const size_t orthog_extent = static_cast<size_t>(orthog_nelems);
 
@@ -433,6 +458,10 @@ sycl::event masked_extract_some_slices_strided_impl(
 
     sycl::nd_range<2> ndRange(gRange, lRange);
 
+    const dataT *src_tp = reinterpret_cast<const dataT *>(src_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
@@ -441,9 +470,9 @@ sycl::event masked_extract_some_slices_strided_impl(
         LocalAccessorT lacc(lacc_size, cgh);
 
         cgh.parallel_for<KernelName>(
-            ndRange,
-            Impl(src_p, cumsum_p, dst_p, masked_nelems, orthog_src_dst_indexer,
-                 masked_src_indexer, masked_dst_indexer, lacc));
+            ndRange, Impl(src_tp, cumsum_tp, dst_tp, masked_nelems,
+                          orthog_src_dst_indexer, masked_src_indexer,
+                          masked_dst_indexer, lacc));
     });
 
     return comp_ev;
@@ -567,6 +596,10 @@ sycl::event masked_place_all_slices_strided_impl(
 
     using LocalAccessorT = sycl::local_accessor<indT, 1>;
 
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+    const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
@@ -578,8 +611,9 @@ sycl::event masked_place_all_slices_strided_impl(
             MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
                                       Strided1DCyclicIndexer, dataT, indT,
                                       LocalAccessorT>(
-                dst_p, cumsum_p, rhs_p, iteration_size, orthog_dst_rhs_indexer,
-                masked_dst_indexer, masked_rhs_indexer, lacc));
+                dst_tp, cumsum_tp, rhs_tp, iteration_size,
+                orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
+                lacc));
     });
 
     return comp_ev;
@@ -659,6 +693,10 @@ sycl::event masked_place_some_slices_strided_impl(
 
     using LocalAccessorT = sycl::local_accessor<indT, 1>;
 
+    dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
+    const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
+    const indT *cumsum_tp = reinterpret_cast<const indT *>(cumsum_p);
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
@@ -670,8 +708,9 @@ sycl::event masked_place_some_slices_strided_impl(
             MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
                                       Strided1DCyclicIndexer, dataT, indT,
                                       LocalAccessorT>(
-                dst_p, cumsum_p, rhs_p, masked_nelems, orthog_dst_rhs_indexer,
-                masked_dst_indexer, masked_rhs_indexer, lacc));
+                dst_tp, cumsum_tp, rhs_tp, masked_nelems,
+                orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
+                lacc));
     });
 
     return comp_ev;