Skip to content

Commit b15e302

Browse files
oleksandr-pavlyk and ndgrigorian
authored and committed
Factor out write-out kernel into separate detail function
Replaced three duplicates of the same kernel with calls to this function.
1 parent 3a4d8ab commit b15e302

File tree

1 file changed

+52
-58
lines changed
  • dpctl/tensor/libtensor/include/kernels/sorting

1 file changed

+52
-58
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 52 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,41 @@ void scale_topk_params(const std::uint64_t nelems_per_slm,
7171
throw std::runtime_error("Could not construct top k kernel parameters");
7272
}
7373

74+
template <class KernelName, typename argTy, typename IndexTy>
75+
sycl::event write_out_impl(sycl::queue &exec_q,
76+
std::size_t iter_nelems,
77+
std::size_t k,
78+
const argTy *arg_tp,
79+
const IndexTy *index_data,
80+
std::size_t iter_index_stride,
81+
std::size_t axis_nelems,
82+
argTy *vals_tp,
83+
IndexTy *inds_tp,
84+
const std::vector<sycl::event> &depends)
85+
{
86+
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
87+
cgh.depends_on(depends);
88+
89+
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
90+
const std::size_t gid = id[0];
91+
92+
const std::size_t iter_gid = gid / k;
93+
const std::size_t axis_gid = gid - (iter_gid * k);
94+
95+
const std::size_t src_idx = iter_gid * iter_index_stride + axis_gid;
96+
const std::size_t dst_idx = gid;
97+
98+
const IndexTy res_ind = index_data[src_idx];
99+
const argTy v = arg_tp[res_ind];
100+
101+
vals_tp[dst_idx] = v;
102+
inds_tp[dst_idx] = (res_ind % axis_nelems);
103+
});
104+
});
105+
106+
return write_out_ev;
107+
}
108+
74109
} // namespace topk_detail
75110

76111
template <typename T1, typename T2, typename T3>
@@ -132,26 +167,13 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
132167
exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size,
133168
{base_sort_ev});
134169

135-
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
136-
cgh.depends_on(merges_ev);
137-
138-
using KernelName = topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
139-
140-
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
141-
std::size_t gid = id[0];
142-
143-
std::size_t iter_gid = gid / k;
144-
std::size_t axis_gid = gid - (iter_gid * k);
170+
using WriteOutKernelName =
171+
topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
145172

146-
std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
147-
std::size_t dst_idx = iter_gid * k + axis_gid;
148-
149-
const IndexTy res_ind = index_data[src_idx];
150-
const argTy v = arg_tp[res_ind];
151-
vals_tp[dst_idx] = v;
152-
inds_tp[dst_idx] = res_ind % axis_nelems;
153-
});
154-
});
173+
sycl::event write_out_ev =
174+
topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
175+
exec_q, iter_nelems, k, arg_tp, index_data, axis_nelems,
176+
axis_nelems, vals_tp, inds_tp, {merges_ev});
155177

156178
sycl::event cleanup_host_task_event =
157179
dpctl::tensor::alloc_utils::async_smart_free(exec_q, {write_out_ev},
@@ -399,27 +421,13 @@ sycl::event topk_merge_impl(
399421
k_rounded, {base_sort_ev});
400422

401423
// Write out top k of the merge-sorted memory
402-
sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
403-
cgh.depends_on(merges_ev);
404-
405-
using KernelName =
406-
topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
424+
using WriteOutKernelName =
425+
topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
407426

408-
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
409-
const std::size_t gid = id[0];
410-
411-
const std::size_t iter_gid = gid / k;
412-
const std::size_t axis_gid = gid - (iter_gid * k);
413-
414-
const std::size_t src_idx = iter_gid * alloc_len + axis_gid;
415-
const std::size_t dst_idx = gid;
416-
417-
const IndexTy res_ind = index_data[src_idx];
418-
const argTy v = arg_tp[res_ind];
419-
vals_tp[dst_idx] = v;
420-
inds_tp[dst_idx] = (res_ind % axis_nelems);
421-
});
422-
});
427+
sycl::event write_topk_ev =
428+
topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
429+
exec_q, iter_nelems, k, arg_tp, index_data, alloc_len,
430+
axis_nelems, vals_tp, inds_tp, {merges_ev});
423431

424432
sycl::event cleanup_host_task_event =
425433
dpctl::tensor::alloc_utils::async_smart_free(
@@ -502,26 +510,12 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
502510
ascending, {iota_ev});
503511

504512
// Write out top k of the temporary
505-
sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
506-
cgh.depends_on(radix_sort_ev);
507-
508-
using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;
513+
using WriteOutKernelName = topk_radix_map_back_krn<argTy, IndexTy>;
509514

510-
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
511-
const std::size_t gid = id[0];
512-
513-
const std::size_t iter_gid = gid / k;
514-
const std::size_t axis_gid = gid - (iter_gid * k);
515-
516-
const std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
517-
const std::size_t dst_idx = gid;
518-
519-
const IndexTy res_ind = tmp_tp[src_idx];
520-
const argTy v = arg_tp[res_ind];
521-
vals_tp[dst_idx] = v;
522-
inds_tp[dst_idx] = (res_ind % axis_nelems);
523-
});
524-
});
515+
sycl::event write_topk_ev =
516+
topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
517+
exec_q, iter_nelems, k, arg_tp, tmp_tp, axis_nelems, axis_nelems,
518+
vals_tp, inds_tp, {radix_sort_ev});
525519

526520
sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
527521
exec_q, {write_topk_ev}, workspace_owner);

0 commit comments

Comments
 (0)