@@ -48,51 +48,65 @@ template <typename OrthogIndexerT,
          typename MaskedSrcIndexerT,
          typename MaskedDstIndexerT,
          typename dataT,
-          typename indT>
+          typename indT,
+          typename LocalAccessorT>
struct MaskedExtractStridedFunctor
{
    MaskedExtractStridedFunctor(const char *src_data_p,
                                const char *cumsum_data_p,
                                char *dst_data_p,
-                                size_t orthog_iter_size,
                                size_t masked_iter_size,
                                const OrthogIndexerT &orthog_src_dst_indexer_,
                                const MaskedSrcIndexerT &masked_src_indexer_,
-                                const MaskedDstIndexerT &masked_dst_indexer_)
+                                const MaskedDstIndexerT &masked_dst_indexer_,
+                                const LocalAccessorT &lacc_)
        : src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
-          orthog_nelems(orthog_iter_size), masked_nelems(masked_iter_size),
+          masked_nelems(masked_iter_size),
          orthog_src_dst_indexer(orthog_src_dst_indexer_),
          masked_src_indexer(masked_src_indexer_),
-          masked_dst_indexer(masked_dst_indexer_)
+          masked_dst_indexer(masked_dst_indexer_), lacc(lacc_)
    {
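+        // compile-time guard: the local accessor must hold the same
+        // integral type as the cumulative-sum array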
+        static_assert(
+            std::is_same_v<indT, typename LocalAccessorT::value_type>);
    }

-    void operator()(sycl::id<1> idx) const
+    void operator()(sycl::nd_item<2> ndit) const
    {
        const dataT *src_data = reinterpret_cast<const dataT *>(src_cp);
        dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
        const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);

-        size_t global_i = idx[0];
-        size_t orthog_i = global_i / masked_nelems;
-        size_t masked_i = global_i - masked_nelems * orthog_i;
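+        // dim 0 of the nd-range walks the orthogonal (unmasked) dimensions;
+        // dim 1 tiles the masked extent in work-groups of size lws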
+        const size_t orthog_i = ndit.get_global_id(0);
+        const size_t group_i = ndit.get_group(1);
+        const std::uint32_t l_i = ndit.get_local_id(1);
+        const std::uint32_t lws = ndit.get_local_range(1);

-        indT current_running_count = cumsum_data[masked_i];
-        bool mask_set =
-            (masked_i == 0)
-                ? (current_running_count == 1)
-                : (current_running_count == cumsum_data[masked_i - 1] + 1);
+        const size_t masked_block_start = group_i * lws;
+        const size_t masked_i = masked_block_start + l_i;

-        // dst[cumsum[i], j] - 1 = src[i, j] if cumsum[i] == ((i > 0) ?
-        // cumsum[i-1]
-        //  + 1 : 1)
-        if (mask_set) {
-            auto orthog_offsets =
-                orthog_src_dst_indexer(static_cast<ssize_t>(orthog_i));
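+        // cooperatively cache this block's cumsum values in local memory;
+        // lacc[0] is a leading sentinel: the running count of the element
+        // just before the block, or 0 for the first block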
+        for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
+            const size_t offset = masked_block_start + i;
+            lacc[i] = (offset == 0) ? indT(0)
+                      : (offset - 1 < masked_nelems)
+                          ? cumsum_data[offset - 1]
+                          : cumsum_data[masked_nelems - 1] + 1;
+        }

-            size_t total_src_offset = masked_src_indexer(masked_i) +
-                                      orthog_offsets.get_first_offset();
-            size_t total_dst_offset =
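+        // wait for all work-items to finish populating the local cache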
+        sycl::group_barrier(ndit.get_group());
+
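+        // lacc[l_i + 1] caches cumsum_data[masked_i]; lacc[l_i] caches the
+        // running count of the preceding element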
+        const indT current_running_count = lacc[l_i + 1];
+        const bool mask_set = (masked_i == 0)
+                                  ? (current_running_count == 1)
+                                  : (current_running_count == lacc[l_i] + 1);
+
+        // dst[cumsum[i] - 1, j] = src[i, j]
+        //     if cumsum[i] == ((i > 0) ? cumsum[i-1] + 1 : 1)
+        if (mask_set && (masked_i < masked_nelems)) {
+            const auto &orthog_offsets = orthog_src_dst_indexer(orthog_i);
+
+            const size_t total_src_offset = masked_src_indexer(masked_i) +
+                                            orthog_offsets.get_first_offset();
+            const size_t total_dst_offset =
                masked_dst_indexer(current_running_count - 1) +
                orthog_offsets.get_second_offset();
@@ -104,8 +118,7 @@ struct MaskedExtractStridedFunctor
    const char *src_cp = nullptr;
    const char *cumsum_cp = nullptr;
    char *dst_cp = nullptr;
-    size_t orthog_nelems = 0;
-    size_t masked_nelems = 0;
+    const size_t masked_nelems = 0;
    // has nd, shape, src_strides, dst_strides for
    // dimensions that ARE NOT masked
    const OrthogIndexerT orthog_src_dst_indexer;
@@ -114,6 +127,7 @@ struct MaskedExtractStridedFunctor
    const MaskedSrcIndexerT masked_src_indexer;
    // has 1, dst_strides for dimensions that ARE masked
    const MaskedDstIndexerT masked_dst_indexer;
+    LocalAccessorT lacc;
};

template <typename OrthogIndexerT,
@@ -190,8 +204,72 @@ struct MaskedPlaceStridedFunctor

// ======= Masked extraction ================================

-template <typename OrthoIndexerT,
-          typename MaskedSrcIndexerT,
+template <typename MaskedDstIndexerT, typename dataT, typename indT>
+class masked_extract_all_slices_contig_impl_krn;
+
+typedef sycl::event (*masked_extract_all_slices_contig_impl_fn_ptr_t)(
+    sycl::queue &,
+    ssize_t,
+    const char *,
+    const char *,
+    char *,
+    ssize_t,
+    ssize_t,
+    const std::vector<sycl::event> &);
+
+template <typename dataT, typename indT>
+sycl::event masked_extract_all_slices_contig_impl(
+    sycl::queue &exec_q,
+    ssize_t iteration_size,
+    const char *src_p,
+    const char *cumsum_p,
+    char *dst_p,
+    ssize_t dst_size, // dst is 1D
+    ssize_t dst_stride,
+    const std::vector<sycl::event> &depends = {})
+{
+    constexpr TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
+
+    constexpr NoOpIndexer masked_src_indexer{};
+    const Strided1DIndexer masked_dst_indexer(/* size */ dst_size,
+                                              /* step */ dst_stride);
+
+    using KernelName =
+        class masked_extract_all_slices_contig_impl_krn<Strided1DIndexer,
+                                                        dataT, indT>;
+
+    using LocalAccessorT = sycl::local_accessor<indT, 1>;
+    using Impl =
+        struct MaskedExtractStridedFunctor<TwoZeroOffsets_Indexer, NoOpIndexer,
+                                           Strided1DIndexer, dataT, indT,
+                                           LocalAccessorT>;
+
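+    // launch one work-group of up to 256 work-items per tile of the masked
+    // extent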
+    constexpr std::size_t nominal_lws = 256;
+    const std::size_t masked_extent = iteration_size;
+    const std::size_t lws = std::min(masked_extent, nominal_lws);
+    const std::size_t n_groups = (iteration_size + lws - 1) / lws;
+
+    sycl::range<2> gRange{1, n_groups * lws};
+    sycl::range<2> lRange{1, lws};
+
+    sycl::nd_range<2> ndRange(gRange, lRange);
+
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
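+        // +1 slot holds the sentinel element preceding this work-group's
+        // block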
+        const std::size_t lacc_size = std::min(lws, masked_extent) + 1;
+        LocalAccessorT lacc(lacc_size, cgh);
+
+        cgh.parallel_for<KernelName>(
+            ndRange,
+            Impl(src_p, cumsum_p, dst_p, masked_extent, orthog_src_dst_indexer,
+                 masked_src_indexer, masked_dst_indexer, lacc));
+    });
+
+    return comp_ev;
+}
+
+template <typename MaskedSrcIndexerT,
          typename MaskedDstIndexerT,
          typename dataT,
          typename indT>
@@ -223,11 +301,6 @@ sycl::event masked_extract_all_slices_strided_impl(
    ssize_t dst_stride,
    const std::vector<sycl::event> &depends = {})
{
-    // using MaskedExtractStridedFunctor;
-    // using Strided1DIndexer;
-    // using StridedIndexer;
-    // using TwoZeroOffsets_Indexer;
-
    constexpr TwoZeroOffsets_Indexer orthog_src_dst_indexer{};

    /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
@@ -236,18 +309,35 @@ sycl::event masked_extract_all_slices_strided_impl(
    const Strided1DIndexer masked_dst_indexer(/* size */ dst_size,
                                              /* step */ dst_stride);

+    using KernelName = class masked_extract_all_slices_strided_impl_krn<
+        StridedIndexer, Strided1DIndexer, dataT, indT>;
+
+    using LocalAccessorT = sycl::local_accessor<indT, 1>;
+    using Impl =
+        struct MaskedExtractStridedFunctor<TwoZeroOffsets_Indexer,
+                                           StridedIndexer, Strided1DIndexer,
+                                           dataT, indT, LocalAccessorT>;
+
+    constexpr std::size_t nominal_lws = 256;
+    const std::size_t masked_nelems = iteration_size;
+    const std::size_t lws = std::min(masked_nelems, nominal_lws);
+    const std::size_t n_groups = (masked_nelems + lws - 1) / lws;
+
+    sycl::range<2> gRange{1, n_groups * lws};
+    sycl::range<2> lRange{1, lws};
+
+    sycl::nd_range<2> ndRange(gRange, lRange);
+
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

-        cgh.parallel_for<class masked_extract_all_slices_strided_impl_krn<
-            TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer, dataT,
-            indT>>(
-            sycl::range<1>(static_cast<size_t>(iteration_size)),
-            MaskedExtractStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
-                                        Strided1DIndexer, dataT, indT>(
-                src_p, cumsum_p, dst_p, 1, iteration_size,
-                orthog_src_dst_indexer, masked_src_indexer,
-                masked_dst_indexer));
+        const std::size_t lacc_size = std::min(lws, masked_nelems) + 1;
+        LocalAccessorT lacc(lacc_size, cgh);
+
+        cgh.parallel_for<KernelName>(
+            ndRange,
+            Impl(src_p, cumsum_p, dst_p, iteration_size, orthog_src_dst_indexer,
+                 masked_src_indexer, masked_dst_indexer, lacc));
    });

    return comp_ev;
@@ -299,11 +389,6 @@ sycl::event masked_extract_some_slices_strided_impl(
    ssize_t masked_dst_stride,
    const std::vector<sycl::event> &depends = {})
{
-    // using MaskedExtractStridedFunctor;
-    // using Strided1DIndexer;
-    // using StridedIndexer;
-    // using TwoOffsets_StridedIndexer;
-
    const TwoOffsets_StridedIndexer orthog_src_dst_indexer{
        orthog_nd, ortho_src_offset, ortho_dst_offset,
        packed_ortho_src_dst_shape_strides};
@@ -313,24 +398,63 @@ sycl::event masked_extract_some_slices_strided_impl(
    const Strided1DIndexer masked_dst_indexer{/* size */ masked_dst_size,
                                              /* step */ masked_dst_stride};

+    using KernelName = class masked_extract_some_slices_strided_impl_krn<
+        TwoOffsets_StridedIndexer, StridedIndexer, Strided1DIndexer, dataT,
+        indT>;
+
+    using LocalAccessorT = sycl::local_accessor<indT, 1>;
+    using Impl =
+        struct MaskedExtractStridedFunctor<TwoOffsets_StridedIndexer,
+                                           StridedIndexer, Strided1DIndexer,
+                                           dataT, indT, LocalAccessorT>;
+
+    const size_t nominal_lws = 256;
+    const std::size_t masked_extent = masked_nelems;
+    const size_t lws = std::min(masked_extent, nominal_lws);
+    const size_t n_groups = ((masked_extent + lws - 1) / lws);
+    const size_t orthog_extent = static_cast<size_t>(orthog_nelems);
+
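+    // dim 0 enumerates the orthogonal slices; dim 1 tiles the masked extent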
+    sycl::range<2> gRange{orthog_extent, n_groups * lws};
+    sycl::range<2> lRange{1, lws};
+
+    sycl::nd_range<2> ndRange(gRange, lRange);
+
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

-        cgh.parallel_for<class masked_extract_some_slices_strided_impl_krn<
-            TwoOffsets_StridedIndexer, StridedIndexer, Strided1DIndexer, dataT,
-            indT>>(
-            sycl::range<1>(static_cast<size_t>(orthog_nelems * masked_nelems)),
-            MaskedExtractStridedFunctor<TwoOffsets_StridedIndexer,
-                                        StridedIndexer, Strided1DIndexer, dataT,
-                                        indT>(
-                src_p, cumsum_p, dst_p, orthog_nelems, masked_nelems,
-                orthog_src_dst_indexer, masked_src_indexer,
-                masked_dst_indexer));
+        const std::size_t lacc_size =
+            std::min<std::size_t>(lws, masked_extent) + 1;
+        LocalAccessorT lacc(lacc_size, cgh);
+
+        cgh.parallel_for<KernelName>(
+            ndRange,
+            Impl(src_p, cumsum_p, dst_p, masked_nelems, orthog_src_dst_indexer,
+                 masked_src_indexer, masked_dst_indexer, lacc));
    });

    return comp_ev;
}

+template <typename fnT, typename T>
+struct MaskExtractAllSlicesContigFactoryForInt32
+{
+    fnT get()
+    {
+        fnT fn = masked_extract_all_slices_contig_impl<T, std::int32_t>;
+        return fn;
+    }
+};
+
+template <typename fnT, typename T>
+struct MaskExtractAllSlicesContigFactoryForInt64
+{
+    fnT get()
+    {
+        fnT fn = masked_extract_all_slices_contig_impl<T, std::int64_t>;
+        return fn;
+    }
+};
+
template <typename fnT, typename T>
struct MaskExtractAllSlicesStridedFactoryForInt32
{
@@ -487,13 +611,17 @@ sycl::event masked_place_some_slices_strided_impl(
    const Strided1DCyclicIndexer masked_rhs_indexer{0, masked_rhs_size,
                                                    masked_rhs_stride};

+    using KernelName = class masked_place_some_slices_strided_impl_krn<
+        TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer,
+        dataT, indT>;
+
+    sycl::range<1> gRange(static_cast<size_t>(orthog_nelems * masked_nelems));
+
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

-        cgh.parallel_for<class masked_place_some_slices_strided_impl_krn<
-            TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer,
-            dataT, indT>>(
-            sycl::range<1>(static_cast<size_t>(orthog_nelems * masked_nelems)),
+        cgh.parallel_for<KernelName>(
+            gRange,
            MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
                                      Strided1DCyclicIndexer, dataT, indT>(
                dst_p, cumsum_p, rhs_p, orthog_nelems, masked_nelems,