
Commit 5161b8e

MaskedPlace kernel optimized
Use local_accessor to improve the effective memory bandwidth of each work-group.
1 parent beb456f commit 5161b8e
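
The change swaps a flat sycl::range launch for an nd_range launch in which each work-group first copies the slice of the cumulative-sum array it will consult into a sycl::local_accessor, so the repeated neighbor reads in the mask test hit local memory instead of global memory. Below is a minimal standalone sketch of that staging pattern, not the dpctl kernel itself; the sizes, the flag_set predicate, and names such as d_cs are illustrative only.

```cpp
#include <sycl/sycl.hpp>

#include <cstdint>
#include <vector>

int main()
{
    constexpr std::size_t n = 1024;  // extent of the masked dimension
    constexpr std::size_t lws = 256; // work-group size
    sycl::queue q;

    // Running count with every other flag set: 1, 1, 2, 2, 3, 3, ...
    std::vector<int> cumsum(n);
    for (std::size_t i = 0; i < n; ++i)
        cumsum[i] = static_cast<int>(i / 2 + 1);

    int *d_cs = sycl::malloc_device<int>(n, q);
    q.copy(cumsum.data(), d_cs, n).wait();

    q.submit([&](sycl::handler &cgh) {
         // one extra slot so each item can also read its left neighbor
         sycl::local_accessor<int, 1> lacc(sycl::range<1>(lws + 1), cgh);
         cgh.parallel_for(
             sycl::nd_range<1>{sycl::range<1>{n}, sycl::range<1>{lws}},
             [=](sycl::nd_item<1> it) {
                 const std::size_t block = it.get_group(0) * lws;
                 const std::uint32_t l = it.get_local_id(0);
                 // cooperative copy: lacc[i] = cumsum[block + i - 1],
                 // with 0 substituted at the left edge of the array
                 for (std::uint32_t i = l; i < lws + 1; i += lws) {
                     const std::size_t off = block + i;
                     lacc[i] = (off == 0) ? 0 : d_cs[off - 1];
                 }
                 sycl::group_barrier(it.get_group());
                 // both operands of the test now come from local memory
                 const bool flag_set = (lacc[l + 1] == lacc[l] + 1);
                 (void)flag_set;
             });
     }).wait();

    sycl::free(d_cs, q);
    return 0;
}
```

The extra slot (lws + 1) lets each work-item read its left neighbor's running count without a second trip to global memory, which is the same trick the functor in the diff below uses.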

File tree

1 file changed: +87 −42 lines


dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 87 additions & 42 deletions
@@ -134,51 +134,65 @@ template <typename OrthogIndexerT,
           typename MaskedDstIndexerT,
           typename MaskedRhsIndexerT,
           typename dataT,
-          typename indT>
+          typename indT,
+          typename LocalAccessorT>
 struct MaskedPlaceStridedFunctor
 {
     MaskedPlaceStridedFunctor(char *dst_data_p,
                               const char *cumsum_data_p,
                               const char *rhs_data_p,
-                              size_t orthog_iter_size,
                               size_t masked_iter_size,
                               const OrthogIndexerT &orthog_dst_rhs_indexer_,
                               const MaskedDstIndexerT &masked_dst_indexer_,
-                              const MaskedRhsIndexerT &masked_rhs_indexer_)
+                              const MaskedRhsIndexerT &masked_rhs_indexer_,
+                              const LocalAccessorT &lacc_)
         : dst_cp(dst_data_p), cumsum_cp(cumsum_data_p), rhs_cp(rhs_data_p),
-          orthog_nelems(orthog_iter_size), masked_nelems(masked_iter_size),
+          masked_nelems(masked_iter_size),
           orthog_dst_rhs_indexer(orthog_dst_rhs_indexer_),
           masked_dst_indexer(masked_dst_indexer_),
-          masked_rhs_indexer(masked_rhs_indexer_)
+          masked_rhs_indexer(masked_rhs_indexer_), lacc(lacc_)
     {
+        static_assert(
+            std::is_same_v<indT, typename LocalAccessorT::value_type>);
     }

-    void operator()(sycl::id<1> idx) const
+    void operator()(sycl::nd_item<2> ndit) const
     {
         dataT *dst_data = reinterpret_cast<dataT *>(dst_cp);
         const indT *cumsum_data = reinterpret_cast<const indT *>(cumsum_cp);
         const dataT *rhs_data = reinterpret_cast<const dataT *>(rhs_cp);

-        size_t global_i = idx[0];
-        size_t orthog_i = global_i / masked_nelems;
-        size_t masked_i = global_i - masked_nelems * orthog_i;
-
-        indT current_running_count = cumsum_data[masked_i];
-        bool mask_set =
-            (masked_i == 0)
-                ? (current_running_count == 1)
-                : (current_running_count == cumsum_data[masked_i - 1] + 1);
-
-        // src[i, j] = rhs[cumsum[i] - 1, j] if cumsum[i] == ((i > 0) ?
-        // cumsum[i-1]
-        // + 1 : 1)
-        if (mask_set) {
-            auto orthog_offsets =
-                orthog_dst_rhs_indexer(static_cast<ssize_t>(orthog_i));
-
-            size_t total_dst_offset = masked_dst_indexer(masked_i) +
-                                      orthog_offsets.get_first_offset();
-            size_t total_rhs_offset =
+        const std::size_t orthog_i = ndit.get_global_id(0);
+        const std::size_t group_i = ndit.get_group(1);
+        const std::uint32_t l_i = ndit.get_local_id(1);
+        const std::uint32_t lws = ndit.get_local_range(1);
+
+        const size_t masked_block_start = group_i * lws;
+        const size_t masked_i = masked_block_start + l_i;
+
+        for (std::uint32_t i = l_i; i < lacc.size(); i += lws) {
+            const size_t offset = masked_block_start + i;
+            lacc[i] = (offset == 0) ? indT(0)
+                      : (offset - 1 < masked_nelems)
+                          ? cumsum_data[offset - 1]
+                          : cumsum_data[masked_nelems - 1] + 1;
+        }
+
+        sycl::group_barrier(ndit.get_group());
+
+        const indT current_running_count = lacc[l_i + 1];
+        const bool mask_set = (masked_i == 0)
+                                  ? (current_running_count == 1)
+                                  : (current_running_count == lacc[l_i] + 1);
+
+        // src[i, j] = rhs[cumsum[i] - 1, j]
+        //     if cumsum[i] == ((i > 0) ? cumsum[i-1] + 1 : 1)
+        if (mask_set && (masked_i < masked_nelems)) {
+            const auto &orthog_offsets = orthog_dst_rhs_indexer(orthog_i);
+
+            const size_t total_dst_offset = masked_dst_indexer(masked_i) +
+                                            orthog_offsets.get_first_offset();
+            const size_t total_rhs_offset =
                 masked_rhs_indexer(current_running_count - 1) +
                 orthog_offsets.get_second_offset();

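The predicate itself is unchanged by this commit; only where its operands are read from differs (local slots lacc[l_i] and lacc[l_i + 1] instead of two global loads per item). A serial reference of the same invariant, useful as a mental model or a host-side check; masked_place_ref is a hypothetical helper, not part of this header:

```cpp
#include <cstddef>
#include <vector>

// Serial model of the kernel's predicate: position i receives
// rhs[cumsum[i] - 1] iff the running count increased by exactly one at i.
template <typename dataT, typename indT>
void masked_place_ref(std::vector<dataT> &dst,
                      const std::vector<indT> &cumsum,
                      const std::vector<dataT> &rhs)
{
    for (std::size_t i = 0; i < dst.size(); ++i) {
        const indT prev = (i == 0) ? indT(0) : cumsum[i - 1];
        if (cumsum[i] == prev + 1) {
            dst[i] = rhs[static_cast<std::size_t>(cumsum[i]) - 1];
        }
    }
}
```
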
@@ -190,8 +204,7 @@ struct MaskedPlaceStridedFunctor
     char *dst_cp = nullptr;
     const char *cumsum_cp = nullptr;
     const char *rhs_cp = nullptr;
-    size_t orthog_nelems = 0;
-    size_t masked_nelems = 0;
+    const size_t masked_nelems = 0;
     // has nd, shape, dst_strides, rhs_strides for
     // dimensions that ARE NOT masked
     const OrthogIndexerT orthog_dst_rhs_indexer;
@@ -200,6 +213,7 @@ struct MaskedPlaceStridedFunctor
     const MaskedDstIndexerT masked_dst_indexer;
     // has 1, rhs_strides for dimensions that ARE masked
     const MaskedRhsIndexerT masked_rhs_indexer;
+    LocalAccessorT lacc;
 };

 // ======= Masked extraction ================================
@@ -537,18 +551,35 @@ sycl::event masked_place_all_slices_strided_impl(
     const StridedIndexer masked_dst_indexer(nd, 0, packed_dst_shape_strides);
     const Strided1DCyclicIndexer masked_rhs_indexer(0, rhs_size, rhs_stride);

+    using KernelName = class masked_place_all_slices_strided_impl_krn<
+        TwoZeroOffsets_Indexer, StridedIndexer, Strided1DCyclicIndexer, dataT,
+        indT>;
+
+    constexpr std::size_t nominal_lws = 256;
+    const std::size_t masked_extent = iteration_size;
+    const std::size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t n_groups = (masked_extent + lws - 1) / lws;
+
+    sycl::range<2> gRange{1, n_groups * lws};
+    sycl::range<2> lRange{1, lws};
+    sycl::nd_range<2> ndRange{gRange, lRange};
+
+    using LocalAccessorT = sycl::local_accessor<indT, 1>;
+
     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);

-        cgh.parallel_for<class masked_place_all_slices_strided_impl_krn<
-            TwoZeroOffsets_Indexer, StridedIndexer, Strided1DCyclicIndexer,
-            dataT, indT>>(
-            sycl::range<1>(static_cast<size_t>(iteration_size)),
+        const std::size_t lacc_size = std::min(masked_extent, lws) + 1;
+        LocalAccessorT lacc(lacc_size, cgh);
+
+        cgh.parallel_for<KernelName>(
+            ndRange,
             MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
-                                      Strided1DCyclicIndexer, dataT, indT>(
-                dst_p, cumsum_p, rhs_p, 1, iteration_size,
-                orthog_dst_rhs_indexer, masked_dst_indexer,
-                masked_rhs_indexer));
+                                      Strided1DCyclicIndexer, dataT, indT,
+                                      LocalAccessorT>(
+                dst_p, cumsum_p, rhs_p, iteration_size, orthog_dst_rhs_indexer,
+                masked_dst_indexer, masked_rhs_indexer, lacc));
     });

     return comp_ev;
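
The launch geometry caps the work-group size at a nominal 256, rounds the masked extent up to a whole number of groups, and relies on the functor's masked_i < masked_nelems test to retire the padding items. The same arithmetic isolated on the host; LaunchGeometry and decompose are illustrative names, and the sketch assumes masked_extent > 0:

```cpp
#include <algorithm>
#include <cstddef>

// Host-side mirror of the launch arithmetic above: group size capped at a
// nominal 256, extent rounded up to a whole number of groups, so the last
// group may contain padding items.
struct LaunchGeometry
{
    std::size_t lws;
    std::size_t n_groups;
};

inline LaunchGeometry decompose(std::size_t masked_extent)
{
    constexpr std::size_t nominal_lws = 256;
    const std::size_t lws = std::min(masked_extent, nominal_lws);
    const std::size_t n_groups = (masked_extent + lws - 1) / lws; // ceil-div
    return {lws, n_groups};
}

// Example: masked_extent = 1000 gives lws = 256, n_groups = 4, a launch of
// 1024 items; the trailing 24 fail masked_i < masked_nelems and do nothing.
```
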
@@ -615,18 +646,32 @@ sycl::event masked_place_some_slices_strided_impl(
         TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer,
         dataT, indT>;

-    sycl::range<1> gRange(static_cast<size_t>(orthog_nelems * masked_nelems));
+    constexpr std::size_t nominal_lws = 256;
+    const std::size_t orthog_extent = orthog_nelems;
+    const std::size_t masked_extent = masked_nelems;
+    const std::size_t lws = std::min(masked_extent, nominal_lws);
+
+    const std::size_t n_groups = (masked_extent + lws - 1) / lws;
+
+    sycl::range<2> gRange{orthog_extent, n_groups * lws};
+    sycl::range<2> lRange{1, lws};
+    sycl::nd_range<2> ndRange{gRange, lRange};
+
+    using LocalAccessorT = sycl::local_accessor<indT, 1>;

     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);

+        const std::size_t lacc_size = std::min(masked_extent, lws) + 1;
+        LocalAccessorT lacc(lacc_size, cgh);
+
         cgh.parallel_for<KernelName>(
-            gRange,
+            ndRange,
             MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
-                                      Strided1DCyclicIndexer, dataT, indT>(
-                dst_p, cumsum_p, rhs_p, orthog_nelems, masked_nelems,
-                orthog_dst_rhs_indexer, masked_dst_indexer,
-                masked_rhs_indexer));
+                                      Strided1DCyclicIndexer, dataT, indT,
+                                      LocalAccessorT>(
+                dst_p, cumsum_p, rhs_p, masked_nelems, orthog_dst_rhs_indexer,
+                masked_dst_indexer, masked_rhs_indexer, lacc));
     });

     return comp_ev;
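
In this "some slices" variant the first nd_range dimension carries the orthogonal index with a local size of 1, so every item of a work-group shares one orthogonal slice and one staged cumsum block. A hypothetical host-side sanity check that the rounded-up 2D launch covers each (orthog_i, masked_i) pair exactly once after the bounds test:

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>

int main()
{
    const std::size_t orthog = 3, masked = 1000, lws = 256;
    const std::size_t n_groups = (masked + lws - 1) / lws;

    std::set<std::pair<std::size_t, std::size_t>> visited;
    for (std::size_t o = 0; o < orthog; ++o)        // ndit.get_global_id(0)
        for (std::size_t g = 0; g < n_groups; ++g)  // ndit.get_group(1)
            for (std::size_t l = 0; l < lws; ++l) { // ndit.get_local_id(1)
                const std::size_t m = g * lws + l;  // masked_i
                if (m < masked)                     // functor's bounds test
                    assert(visited.insert({o, m}).second);
            }
    assert(visited.size() == orthog * masked);
    return 0;
}
```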
