Skip to content

Commit beb456f

Browse files
Improvement in boolean index extract
1. Use shared local memory to optimize access to neighboring elements of the cumulative sums. 2. Introduce a contiguous ("contig") variant of the masked_extract code. 3. Remove the unused orthog_nelems functor argument, and add a local_accessor argument instead. Example: ``` import dpctl.tensor as dpt x = dpt.ones(20241024, dtype='f4') m = dpt.ones(x.size, dtype='b1') %time x[m] ``` Runtime decreased from 41 ms to 37 ms on an Iris Xe (WSL) box.
1 parent cb4a049 commit beb456f

File tree

2 files changed

+269
-98
lines changed

2 files changed

+269
-98
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 188 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -48,51 +48,65 @@ template <typename OrthogIndexerT,
4848
typename MaskedSrcIndexerT,
4949
typename MaskedDstIndexerT,
5050
typename dataT,
51-
typename indT>
51+
typename indT,
52+
typename LocalAccessorT>
5253
struct MaskedExtractStridedFunctor
5354
{
5455
MaskedExtractStridedFunctor(const char *src_data_p,
5556
const char *cumsum_data_p,
5657
char *dst_data_p,
57-
size_t orthog_iter_size,
5858
size_t masked_iter_size,
5959
const OrthogIndexerT &orthog_src_dst_indexer_,
6060
const MaskedSrcIndexerT &masked_src_indexer_,
61-
const MaskedDstIndexerT &masked_dst_indexer_)
61+
const MaskedDstIndexerT &masked_dst_indexer_,
62+
const LocalAccessorT &lacc_)
6263
: src_cp(src_data_p), cumsum_cp(cumsum_data_p), dst_cp(dst_data_p),
63-
orthog_nelems(orthog_iter_size), masked_nelems(masked_iter_size),
64+
masked_nelems(masked_iter_size),
6465
orthog_src_dst_indexer(orthog_src_dst_indexer_),
6566
masked_src_indexer(masked_src_indexer_),
66-
masked_dst_indexer(masked_dst_indexer_)
67+
masked_dst_indexer(masked_dst_indexer_), lacc(lacc_)
6768
{
69+
static_assert(
70+
std::is_same_v<indT, typename LocalAccessorT::value_type>);
6871
}
6972

70-
void operator()(sycl::id<1> idx) const
73+
void operator()(sycl::nd_item<2> ndit) const
7174
{
7275
const dataT *src_data = reinterpret_cast<const dataT *>(src_cp);
7376
dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
7477
const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
7578

76-
size_t global_i = idx[0];
77-
size_t orthog_i = global_i / masked_nelems;
78-
size_t masked_i = global_i - masked_nelems * orthog_i;
79+
const size_t orthog_i = ndit.get_global_id(0);
80+
const size_t group_i = ndit.get_group(1);
81+
const std::uint32_t l_i = ndit.get_local_id(1);
82+
const std::uint32_t lws = ndit.get_local_range(1);
7983

80-
indT current_running_count = cumsum_data[masked_i];
81-
bool mask_set =
82-
(masked_i == 0)
83-
? (current_running_count == 1)
84-
: (current_running_count == cumsum_data[masked_i - 1] + 1);
84+
const size_t masked_block_start = group_i * lws;
85+
const size_t masked_i = masked_block_start + l_i;
8586

86-
// dst[cumsum[i], j] - 1 = src[i, j] if cumsum[i] == ((i > 0) ?
87-
// cumsum[i-1]
88-
// + 1 : 1)
89-
if (mask_set) {
90-
auto orthog_offsets =
91-
orthog_src_dst_indexer(static_cast<ssize_t>(orthog_i));
87+
for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
88+
const size_t offset = masked_block_start + i;
89+
lacc[i] = (offset == 0) ? indT(0)
90+
: (offset - 1 < masked_nelems)
91+
? cumsum_data[offset - 1]
92+
: cumsum_data[masked_nelems - 1] + 1;
93+
}
9294

93-
size_t total_src_offset = masked_src_indexer(masked_i) +
94-
orthog_offsets.get_first_offset();
95-
size_t total_dst_offset =
95+
sycl::group_barrier(ndit.get_group());
96+
97+
const indT current_running_count = lacc[l_i + 1];
98+
const bool mask_set = (masked_i == 0)
99+
? (current_running_count == 1)
100+
: (current_running_count == lacc[l_i] + 1);
101+
102+
// dst[cumsum[i] - 1, j] = src[i, j]
103+
// if cumsum[i] == ((i > 0) ? cumsum[i-1] + 1 : 1)
104+
if (mask_set && (masked_i < masked_nelems)) {
105+
const auto &orthog_offsets = orthog_src_dst_indexer(orthog_i);
106+
107+
const size_t total_src_offset = masked_src_indexer(masked_i) +
108+
orthog_offsets.get_first_offset();
109+
const size_t total_dst_offset =
96110
masked_dst_indexer(current_running_count - 1) +
97111
orthog_offsets.get_second_offset();
98112

@@ -104,8 +118,7 @@ struct MaskedExtractStridedFunctor
104118
const char *src_cp = nullptr;
105119
const char *cumsum_cp = nullptr;
106120
char *dst_cp = nullptr;
107-
size_t orthog_nelems = 0;
108-
size_t masked_nelems = 0;
121+
const size_t masked_nelems = 0;
109122
// has nd, shape, src_strides, dst_strides for
110123
// dimensions that ARE NOT masked
111124
const OrthogIndexerT orthog_src_dst_indexer;
@@ -114,6 +127,7 @@ struct MaskedExtractStridedFunctor
114127
const MaskedSrcIndexerT masked_src_indexer;
115128
// has 1, dst_strides for dimensions that ARE masked
116129
const MaskedDstIndexerT masked_dst_indexer;
130+
LocalAccessorT lacc;
117131
};
118132

119133
template <typename OrthogIndexerT,
@@ -190,8 +204,72 @@ struct MaskedPlaceStridedFunctor
190204

191205
// ======= Masked extraction ================================
192206

193-
template <typename OrthoIndexerT,
194-
typename MaskedSrcIndexerT,
207+
template <typename MaskedDstIndexerT, typename dataT, typename indT>
208+
class masked_extract_all_slices_contig_impl_krn;
209+
210+
typedef sycl::event (*masked_extract_all_slices_contig_impl_fn_ptr_t)(
211+
sycl::queue &,
212+
ssize_t,
213+
const char *,
214+
const char *,
215+
char *,
216+
ssize_t,
217+
ssize_t,
218+
const std::vector<sycl::event> &);
219+
220+
template <typename dataT, typename indT>
221+
sycl::event masked_extract_all_slices_contig_impl(
222+
sycl::queue &exec_q,
223+
ssize_t iteration_size,
224+
const char *src_p,
225+
const char *cumsum_p,
226+
char *dst_p,
227+
ssize_t dst_size, // dst is 1D
228+
ssize_t dst_stride,
229+
const std::vector<sycl::event> &depends = {})
230+
{
231+
constexpr TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
232+
233+
constexpr NoOpIndexer masked_src_indexer{};
234+
const Strided1DIndexer masked_dst_indexer(/* size */ dst_size,
235+
/* step */ dst_stride);
236+
237+
using KernelName =
238+
class masked_extract_all_slices_contig_impl_krn<Strided1DIndexer, dataT,
239+
indT>;
240+
241+
using LocalAccessorT = sycl::local_accessor<indT, 1>;
242+
using Impl =
243+
struct MaskedExtractStridedFunctor<TwoZeroOffsets_Indexer, NoOpIndexer,
244+
Strided1DIndexer, dataT, indT,
245+
LocalAccessorT>;
246+
247+
constexpr std::size_t nominal_lws = 256;
248+
const std::size_t masked_extent = iteration_size;
249+
const std::size_t lws = std::min(masked_extent, nominal_lws);
250+
const std::size_t n_groups = (iteration_size + lws - 1) / lws;
251+
252+
sycl::range<2> gRange{1, n_groups * lws};
253+
sycl::range<2> lRange{1, lws};
254+
255+
sycl::nd_range<2> ndRange(gRange, lRange);
256+
257+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
258+
cgh.depends_on(depends);
259+
260+
const std::size_t lacc_size = std::min(lws, masked_extent) + 1;
261+
LocalAccessorT lacc(lacc_size, cgh);
262+
263+
cgh.parallel_for<KernelName>(
264+
ndRange,
265+
Impl(src_p, cumsum_p, dst_p, masked_extent, orthog_src_dst_indexer,
266+
masked_src_indexer, masked_dst_indexer, lacc));
267+
});
268+
269+
return comp_ev;
270+
}
271+
272+
template <typename MaskedSrcIndexerT,
195273
typename MaskedDstIndexerT,
196274
typename dataT,
197275
typename indT>
@@ -223,11 +301,6 @@ sycl::event masked_extract_all_slices_strided_impl(
223301
ssize_t dst_stride,
224302
const std::vector<sycl::event> &depends = {})
225303
{
226-
// using MaskedExtractStridedFunctor;
227-
// using Strided1DIndexer;
228-
// using StridedIndexer;
229-
// using TwoZeroOffsets_Indexer;
230-
231304
constexpr TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
232305

233306
/* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
@@ -236,18 +309,35 @@ sycl::event masked_extract_all_slices_strided_impl(
236309
const Strided1DIndexer masked_dst_indexer(/* size */ dst_size,
237310
/* step */ dst_stride);
238311

312+
using KernelName = class masked_extract_all_slices_strided_impl_krn<
313+
StridedIndexer, Strided1DIndexer, dataT, indT>;
314+
315+
using LocalAccessorT = sycl::local_accessor<indT, 1>;
316+
using Impl =
317+
struct MaskedExtractStridedFunctor<TwoZeroOffsets_Indexer,
318+
StridedIndexer, Strided1DIndexer,
319+
dataT, indT, LocalAccessorT>;
320+
321+
constexpr std::size_t nominal_lws = 256;
322+
const std::size_t masked_nelems = iteration_size;
323+
const std::size_t lws = std::min(masked_nelems, nominal_lws);
324+
const std::size_t n_groups = (masked_nelems + lws - 1) / lws;
325+
326+
sycl::range<2> gRange{1, n_groups * lws};
327+
sycl::range<2> lRange{1, lws};
328+
329+
sycl::nd_range<2> ndRange(gRange, lRange);
330+
239331
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
240332
cgh.depends_on(depends);
241333

242-
cgh.parallel_for<class masked_extract_all_slices_strided_impl_krn<
243-
TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer, dataT,
244-
indT>>(
245-
sycl::range<1>(static_cast<size_t>(iteration_size)),
246-
MaskedExtractStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
247-
Strided1DIndexer, dataT, indT>(
248-
src_p, cumsum_p, dst_p, 1, iteration_size,
249-
orthog_src_dst_indexer, masked_src_indexer,
250-
masked_dst_indexer));
334+
const std::size_t lacc_size = std::min(lws, masked_nelems) + 1;
335+
LocalAccessorT lacc(lacc_size, cgh);
336+
337+
cgh.parallel_for<KernelName>(
338+
ndRange,
339+
Impl(src_p, cumsum_p, dst_p, iteration_size, orthog_src_dst_indexer,
340+
masked_src_indexer, masked_dst_indexer, lacc));
251341
});
252342

253343
return comp_ev;
@@ -299,11 +389,6 @@ sycl::event masked_extract_some_slices_strided_impl(
299389
ssize_t masked_dst_stride,
300390
const std::vector<sycl::event> &depends = {})
301391
{
302-
// using MaskedExtractStridedFunctor;
303-
// using Strided1DIndexer;
304-
// using StridedIndexer;
305-
// using TwoOffsets_StridedIndexer;
306-
307392
const TwoOffsets_StridedIndexer orthog_src_dst_indexer{
308393
orthog_nd, ortho_src_offset, ortho_dst_offset,
309394
packed_ortho_src_dst_shape_strides};
@@ -313,24 +398,63 @@ sycl::event masked_extract_some_slices_strided_impl(
313398
const Strided1DIndexer masked_dst_indexer{/* size */ masked_dst_size,
314399
/* step */ masked_dst_stride};
315400

401+
using KernelName = class masked_extract_some_slices_strided_impl_krn<
402+
TwoOffsets_StridedIndexer, StridedIndexer, Strided1DIndexer, dataT,
403+
indT>;
404+
405+
using LocalAccessorT = sycl::local_accessor<indT, 1>;
406+
using Impl =
407+
struct MaskedExtractStridedFunctor<TwoOffsets_StridedIndexer,
408+
StridedIndexer, Strided1DIndexer,
409+
dataT, indT, LocalAccessorT>;
410+
411+
const size_t nominal_lws = 256;
412+
const std::size_t masked_extent = masked_nelems;
413+
const size_t lws = std::min(masked_extent, nominal_lws);
414+
const size_t n_groups = ((masked_extent + lws - 1) / lws);
415+
const size_t orthog_extent = static_cast<size_t>(orthog_nelems);
416+
417+
sycl::range<2> gRange{orthog_extent, n_groups * lws};
418+
sycl::range<2> lRange{1, lws};
419+
420+
sycl::nd_range<2> ndRange(gRange, lRange);
421+
316422
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
317423
cgh.depends_on(depends);
318424

319-
cgh.parallel_for<class masked_extract_some_slices_strided_impl_krn<
320-
TwoOffsets_StridedIndexer, StridedIndexer, Strided1DIndexer, dataT,
321-
indT>>(
322-
sycl::range<1>(static_cast<size_t>(orthog_nelems * masked_nelems)),
323-
MaskedExtractStridedFunctor<TwoOffsets_StridedIndexer,
324-
StridedIndexer, Strided1DIndexer, dataT,
325-
indT>(
326-
src_p, cumsum_p, dst_p, orthog_nelems, masked_nelems,
327-
orthog_src_dst_indexer, masked_src_indexer,
328-
masked_dst_indexer));
425+
const std::size_t lacc_size =
426+
std::min<std::size_t>(lws, masked_extent) + 1;
427+
LocalAccessorT lacc(lacc_size, cgh);
428+
429+
cgh.parallel_for<KernelName>(
430+
ndRange,
431+
Impl(src_p, cumsum_p, dst_p, masked_nelems, orthog_src_dst_indexer,
432+
masked_src_indexer, masked_dst_indexer, lacc));
329433
});
330434

331435
return comp_ev;
332436
}
333437

438+
template <typename fnT, typename T>
439+
struct MaskExtractAllSlicesContigFactoryForInt32
440+
{
441+
fnT get()
442+
{
443+
fnT fn = masked_extract_all_slices_contig_impl<T, std::int32_t>;
444+
return fn;
445+
}
446+
};
447+
448+
template <typename fnT, typename T>
449+
struct MaskExtractAllSlicesContigFactoryForInt64
450+
{
451+
fnT get()
452+
{
453+
fnT fn = masked_extract_all_slices_contig_impl<T, std::int64_t>;
454+
return fn;
455+
}
456+
};
457+
334458
template <typename fnT, typename T>
335459
struct MaskExtractAllSlicesStridedFactoryForInt32
336460
{
@@ -487,13 +611,17 @@ sycl::event masked_place_some_slices_strided_impl(
487611
const Strided1DCyclicIndexer masked_rhs_indexer{0, masked_rhs_size,
488612
masked_rhs_stride};
489613

614+
using KernelName = class masked_place_some_slices_strided_impl_krn<
615+
TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer,
616+
dataT, indT>;
617+
618+
sycl::range<1> gRange(static_cast<size_t>(orthog_nelems * masked_nelems));
619+
490620
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
491621
cgh.depends_on(depends);
492622

493-
cgh.parallel_for<class masked_place_some_slices_strided_impl_krn<
494-
TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer,
495-
dataT, indT>>(
496-
sycl::range<1>(static_cast<size_t>(orthog_nelems * masked_nelems)),
623+
cgh.parallel_for<KernelName>(
624+
gRange,
497625
MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
498626
Strided1DCyclicIndexer, dataT, indT>(
499627
dst_p, cumsum_p, rhs_p, orthog_nelems, masked_nelems,

0 commit comments

Comments
 (0)