Skip to content

Commit 2608aea

Browse files
oleksandr-pavlykndgrigorian
authored andcommitted
Reimplemented map_back_impl to process few elements per work-item
Factored out map_back_impl projects indexing from flat index to a row-wise index. Removed dead code excluded by preprocessor conditional.
1 parent 4ef615e commit 2608aea

File tree

4 files changed

+70
-67
lines changed

4 files changed

+70
-67
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -816,24 +816,9 @@ sycl::event stable_argsort_axis1_contig_impl(
816816

817817
using IotaKernelName = populate_index_data_krn<argTy, IndexTy, ValueComp>;
818818

819-
#if 1
820819
sycl::event populate_indexed_data_ev = iota_impl<IotaKernelName, IndexTy>(
821820
exec_q, res_tp, total_nelems, depends);
822821

823-
#else
824-
sycl::event populate_indexed_data_ev =
825-
exec_q.submit([&](sycl::handler &cgh) {
826-
cgh.depends_on(depends);
827-
828-
const sycl::range<1> range{total_nelems};
829-
830-
cgh.parallel_for<IotaKernelName>(range, [=](sycl::id<1> id) {
831-
size_t i = id[0];
832-
res_tp[i] = static_cast<IndexTy>(i);
833-
});
834-
});
835-
#endif
836-
837822
// Sort segments of the array
838823
sycl::event base_sort_ev =
839824
merge_sort_detail::sort_over_work_group_contig_impl(
@@ -847,21 +832,11 @@ sycl::event stable_argsort_axis1_contig_impl(
847832
exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
848833
{base_sort_ev});
849834

850-
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
851-
cgh.depends_on(merges_ev);
852-
853-
auto temp_acc =
854-
merge_sort_detail::GetReadOnlyAccess<decltype(res_tp)>{}(res_tp,
855-
cgh);
835+
using MapBackKernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
836+
using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
856837

857-
using KernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
858-
859-
const sycl::range<1> range{total_nelems};
860-
861-
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
862-
res_tp[id] = (temp_acc[id] % sort_nelems);
863-
});
864-
});
838+
sycl::event write_out_ev = map_back_impl<MapBackKernelName, IndexTy>(
839+
exec_q, total_nelems, res_tp, res_tp, sort_nelems, {merges_ev});
865840

866841
return write_out_ev;
867842
}

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,41 +1766,21 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
17661766

17671767
using IotaKernelName = radix_argsort_iota_krn<argTy, IndexTy>;
17681768

1769-
#if 1
17701769
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
17711770

17721771
sycl::event iota_ev = iota_impl<IotaKernelName, IndexTy>(
17731772
exec_q, workspace, total_nelems, depends);
1774-
#else
1775-
1776-
sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
1777-
cgh.depends_on(depends);
1778-
1779-
cgh.parallel_for<IotaKernelName>(
1780-
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
1781-
size_t i = id[0];
1782-
IndexTy sort_id = static_cast<IndexTy>(i);
1783-
workspace[i] = sort_id;
1784-
});
1785-
});
1786-
#endif
17871773

17881774
sycl::event radix_sort_ev =
17891775
radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
17901776
exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op,
17911777
sort_ascending, {iota_ev});
17921778

1793-
sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) {
1794-
cgh.depends_on(radix_sort_ev);
1795-
1796-
using KernelName = radix_argsort_index_write_out_krn<argTy, IndexTy>;
1779+
using MapBackKernelName = radix_argsort_index_write_out_krn<argTy, IndexTy>;
1780+
using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
17971781

1798-
cgh.parallel_for<KernelName>(
1799-
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
1800-
IndexTy linear_index = res_tp[id];
1801-
res_tp[id] = (linear_index % sort_nelems);
1802-
});
1803-
});
1782+
sycl::event map_back_ev = map_back_impl<MapBackKernelName, IndexTy>(
1783+
exec_q, total_nelems, res_tp, res_tp, sort_nelems, {radix_sort_ev});
18041784

18051785
sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
18061786
exec_q, {map_back_ev}, workspace_owner);

dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,27 @@ sycl::event iota_impl(sycl::queue &exec_q,
9797
return e;
9898
}
9999

100+
template <class KernelName, typename IndexTy>
101+
sycl::event map_back_impl(sycl::queue &exec_q,
102+
std::size_t nelems,
103+
const IndexTy *flat_index_data,
104+
IndexTy *reduced_index_data,
105+
std::size_t row_size,
106+
const std::vector<sycl::event> &dependent_events)
107+
{
108+
sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) {
109+
cgh.depends_on(dependent_events);
110+
111+
cgh.parallel_for<KernelName>(
112+
sycl::range<1>(nelems), [=](sycl::id<1> id) {
113+
const IndexTy linear_index = flat_index_data[id];
114+
reduced_index_data[id] = (linear_index % row_size);
115+
});
116+
});
117+
118+
return map_back_ev;
119+
}
120+
100121
} // end of namespace sort_utils_detail
101122
} // end of namespace kernels
102123
} // end of namespace tensor

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -83,23 +83,50 @@ sycl::event write_out_impl(sycl::queue &exec_q,
8383
IndexTy *inds_tp,
8484
const std::vector<sycl::event> &depends)
8585
{
86-
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
87-
cgh.depends_on(depends);
88-
89-
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
90-
const std::size_t gid = id[0];
86+
constexpr std::uint32_t lws = 64;
87+
constexpr std::uint32_t n_wi = 4;
88+
const std::size_t nelems = iter_nelems * k;
89+
const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws);
9190

92-
const std::size_t iter_gid = gid / k;
93-
const std::size_t axis_gid = gid - (iter_gid * k);
91+
sycl::range<1> lRange{lws};
92+
sycl::range<1> gRange{n_groups * lws};
93+
sycl::nd_range<1> ndRange{gRange, lRange};
9494

95-
const std::size_t src_idx = iter_gid * iter_index_stride + axis_gid;
96-
const std::size_t dst_idx = gid;
97-
98-
const IndexTy res_ind = index_data[src_idx];
99-
const argTy v = arg_tp[res_ind];
95+
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
96+
cgh.depends_on(depends);
10097

101-
vals_tp[dst_idx] = v;
102-
inds_tp[dst_idx] = (res_ind % axis_nelems);
98+
cgh.parallel_for<KernelName>(ndRange, [=](sycl::nd_item<1> it) {
99+
const std::size_t gid = it.get_global_linear_id();
100+
const auto &sg = it.get_sub_group();
101+
const std::uint32_t lane_id = sg.get_local_id()[0];
102+
const std::uint32_t sg_size = sg.get_max_local_range()[0];
103+
104+
const std::size_t start_id =
105+
(gid - lane_id) * sg_size * n_wi + lane_id;
106+
107+
#pragma unroll
108+
for (std::uint32_t i = 0; i < n_wi; ++i) {
109+
const std::size_t data_id = start_id + i * sg_size;
110+
111+
if (data_id < nelems) {
112+
const std::size_t iter_id = data_id / k;
113+
114+
/*
115+
const std::size_t axis_gid = data_id - (iter_gid * k);
116+
const std::size_t src_idx = iter_gid * iter_index_stride +
117+
axis_gid;
118+
*/
119+
const std::size_t src_idx =
120+
data_id + iter_id * (iter_index_stride - k);
121+
122+
const IndexTy res_ind = index_data[src_idx];
123+
const argTy v = arg_tp[res_ind];
124+
125+
const std::size_t dst_idx = data_id;
126+
vals_tp[dst_idx] = v;
127+
inds_tp[dst_idx] = (res_ind % axis_nelems);
128+
}
129+
}
103130
});
104131
});
105132

0 commit comments

Comments
 (0)