@@ -134,51 +134,65 @@ template <typename OrthogIndexerT,
          typename MaskedDstIndexerT,
          typename MaskedRhsIndexerT,
          typename dataT,
-         typename indT>
+         typename indT,
+         typename LocalAccessorT>
 struct MaskedPlaceStridedFunctor
 {
     MaskedPlaceStridedFunctor(char *dst_data_p,
                               const char *cumsum_data_p,
                               const char *rhs_data_p,
-                              size_t orthog_iter_size,
                               size_t masked_iter_size,
                               const OrthogIndexerT &orthog_dst_rhs_indexer_,
                               const MaskedDstIndexerT &masked_dst_indexer_,
-                              const MaskedRhsIndexerT &masked_rhs_indexer_)
+                              const MaskedRhsIndexerT &masked_rhs_indexer_,
+                              const LocalAccessorT &lacc_)
         : dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
-          orthog_nelems(orthog_iter_size), masked_nelems(masked_iter_size),
+          masked_nelems(masked_iter_size),
          orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
          masked_dst_indexer(masked_dst_indexer_),
-          masked_rhs_indexer(masked_rhs_indexer_)
+          masked_rhs_indexer(masked_rhs_indexer_), lacc(lacc_)
     {
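+        // The local accessor must use the same element type as the
+        // cumulative-sum array it will cache.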
+        static_assert(
+            std::is_same_v<indT, typename LocalAccessorT::value_type>);
     }

-    void operator()(sycl::id<1> idx) const
+    void operator()(sycl::nd_item<2> ndit) const
     {
         dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
         const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
         const dataT *rhs_data = reinterpret_cast<const dataT *>(rhs_cp);

-        size_t global_i = idx[0];
-        size_t orthog_i = global_i / masked_nelems;
-        size_t masked_i = global_i - masked_nelems * orthog_i;
-
-        indT current_running_count = cumsum_data[masked_i];
-        bool mask_set =
-            (masked_i == 0)
-                ? (current_running_count == 1)
-                : (current_running_count == cumsum_data[masked_i - 1] + 1);
-
-        // src[i, j] = rhs[cumsum[i] - 1, j] if cumsum[i] == ((i > 0) ?
-        // cumsum[i-1]
-        // + 1 : 1)
-        if (mask_set) {
-            auto orthog_offsets =
-                orthog_dst_rhs_indexer(static_cast<ssize_t>(orthog_i));
-
-            size_t total_dst_offset = masked_dst_indexer(masked_i) +
-                                      orthog_offsets.get_first_offset();
-            size_t total_rhs_offset =
+        const std::size_t orthog_i = ndit.get_global_id(0);
+        const std::size_t group_i = ndit.get_group(1);
+        const std::uint32_t l_i = ndit.get_local_id(1);
+        const std::uint32_t lws = ndit.get_local_range(1);
+
+        const size_t masked_block_start = group_i * lws;
+        const size_t masked_i = masked_block_start + l_i;
+
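+        // Cooperatively cache this work-group's window of the cumulative
+        // sum in local memory, shifted by one: lacc[i] holds
+        // cumsum[masked_block_start + i - 1], with 0 substituted before the
+        // first element and cumsum[masked_nelems - 1] + 1 past the end.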
+        for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
+            const size_t offset = masked_block_start + i;
+            lacc[i] = (offset == 0) ? indT(0)
+                      : (offset - 1 < masked_nelems)
+                          ? cumsum_data[offset - 1]
+                          : cumsum_data[masked_nelems - 1] + 1;
+        }
+
+        sycl::group_barrier(ndit.get_group());
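+
+        // After the barrier, lacc[l_i + 1] holds cumsum[masked_i] and
+        // lacc[l_i] holds cumsum[masked_i - 1] (zero when masked_i == 0)
+        // for in-range masked_i.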
+        const indT current_running_count = lacc[l_i + 1];
+        const bool mask_set = (masked_i == 0)
+                                  ? (current_running_count == 1)
+                                  : (current_running_count == lacc[l_i] + 1);
+
+        // src[i, j] = rhs[cumsum[i] - 1, j]
+        //     if cumsum[i] == ((i > 0) ? cumsum[i-1] + 1 : 1)
+        if (mask_set && (masked_i < masked_nelems)) {
+            const auto &orthog_offsets = orthog_dst_rhs_indexer(orthog_i);
+
+            const size_t total_dst_offset = masked_dst_indexer(masked_i) +
+                                            orthog_offsets.get_first_offset();
+            const size_t total_rhs_offset =
                 masked_rhs_indexer(current_running_count - 1) +
                 orthog_offsets.get_second_offset();
@@ -190,8 +204,7 @@ struct MaskedPlaceStridedFunctor
     char *dst_cp = nullptr;
     const char *cumsum_cp = nullptr;
     const char *rhs_cp = nullptr;
-    size_t orthog_nelems = 0;
-    size_t masked_nelems = 0;
+    const size_t masked_nelems = 0;
     // has nd, shape, dst_strides, rhs_strides for
     // dimensions that ARE NOT masked
     const OrthogIndexerT orthog_dst_rhs_indexer;
@@ -200,6 +213,7 @@ struct MaskedPlaceStridedFunctor
     const MaskedDstIndexerT masked_dst_indexer;
     // has 1, rhs_strides for dimensions that ARE masked
     const MaskedRhsIndexerT masked_rhs_indexer;
+    LocalAccessorT lacc;
 };

 // ======= Masked extraction ================================
@@ -537,18 +551,35 @@ sycl::event masked_place_all_slices_strided_impl(
     const StridedIndexer masked_dst_indexer(nd, 0, packed_dst_shape_strides);
     const Strided1DCyclicIndexer masked_rhs_indexer(0, rhs_size, rhs_stride);

+    using KernelName = class masked_place_all_slices_strided_impl_krn<
+        TwoZeroOffsets_Indexer, StridedIndexer, Strided1DCyclicIndexer, dataT,
+        indT>;
+
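+    // Work-group size: the nominal value clipped to the masked extent; the
+    // global range rounds the masked extent up to a whole number of groups.
+    // Dimension 0 carries the single orthogonal iteration.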
+    constexpr std::size_t nominal_lws = 256;
+    const std::size_t masked_extent = iteration_size;
+    const std::size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t n_groups = (masked_extent + lws - 1) / lws;
+
+    sycl::range<2> gRange{1, n_groups * lws};
+    sycl::range<2> lRange{1, lws};
+    sycl::nd_range<2> ndRange{gRange, lRange};
+
+    using LocalAccessorT = sycl::local_accessor<indT, 1>;
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);

-        cgh.parallel_for<class masked_place_all_slices_strided_impl_krn<
-            TwoZeroOffsets_Indexer, StridedIndexer, Strided1DCyclicIndexer,
-            dataT, indT>>(
-            sycl::range<1>(static_cast<size_t>(iteration_size)),
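+        // One extra local slot caches the cumsum element that precedes the
+        // work-group's window.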
+        const std::size_t lacc_size = std::min(masked_extent, lws) + 1;
+        LocalAccessorT lacc(lacc_size, cgh);
+
+        cgh.parallel_for<KernelName>(
+            ndRange,
             MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
-                                      Strided1DCyclicIndexer, dataT, indT>(
-                dst_p, cumsum_p, rhs_p, 1, iteration_size,
-                orthog_dst_rhs_indexer, masked_dst_indexer,
-                masked_rhs_indexer));
+                                      Strided1DCyclicIndexer, dataT, indT,
+                                      LocalAccessorT>(
+                dst_p, cumsum_p, rhs_p, iteration_size, orthog_dst_rhs_indexer,
+                masked_dst_indexer, masked_rhs_indexer, lacc));
     });

     return comp_ev;
@@ -615,18 +646,32 @@ sycl::event masked_place_some_slices_strided_impl(
         TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer,
         dataT, indT>;

-    sycl::range<1> gRange(static_cast<size_t>(orthog_nelems * masked_nelems));
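+    // 2D launch: dimension 0 spans the orthogonal iterations, dimension 1
+    // the masked extent rounded up to a whole number of work-groups.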
+    constexpr std::size_t nominal_lws = 256;
+    const std::size_t orthog_extent = orthog_nelems;
+    const std::size_t masked_extent = masked_nelems;
+    const std::size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t n_groups = (masked_extent + lws - 1) / lws;
+
+    sycl::range<2> gRange{orthog_extent, n_groups * lws};
+    sycl::range<2> lRange{1, lws};
+    sycl::nd_range<2> ndRange{gRange, lRange};
+
+    using LocalAccessorT = sycl::local_accessor<indT, 1>;

     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);

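+        // As above, one extra slot caches the cumsum element preceding the
+        // window.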
+        const std::size_t lacc_size = std::min(masked_extent, lws) + 1;
+        LocalAccessorT lacc(lacc_size, cgh);
+
         cgh.parallel_for<KernelName>(
-            gRange,
+            ndRange,
             MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
-                                      Strided1DCyclicIndexer, dataT, indT>(
-                dst_p, cumsum_p, rhs_p, orthog_nelems, masked_nelems,
-                orthog_dst_rhs_indexer, masked_dst_indexer,
-                masked_rhs_indexer));
+                                      Strided1DCyclicIndexer, dataT, indT,
+                                      LocalAccessorT>(
+                dst_p, cumsum_p, rhs_p, masked_nelems, orthog_dst_rhs_indexer,
+                masked_dst_indexer, masked_rhs_indexer, lacc));
     });

     return comp_ev;