
#include "kernels/dpctl_tensor_types.hpp"
#include "merge_sort.hpp"
+#include "radix_sort.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include <sycl/ext/oneapi/sub_group_mask.hpp>

@@ -70,31 +71,25 @@ void scale_topk_params(const std::uint64_t nelems_per_slm,
} // namespace topk_detail

template <typename T1, typename T2, typename T3>
-class populate_index_data_full_sort_krn;
+class topk_populate_index_data_krn;

template <typename T1, typename T2, typename T3>
-class topk_map_to_rows_full_sort_krn;
-
-template <typename T1, typename T2, typename T3> class populate_index_data_krn;
-
-template <typename T1, typename T2, typename T3> class topk_map_to_rows_krn;
+class topk_full_merge_map_back_krn;

template <typename argTy, typename IndexTy, typename CompT>
-sycl::event topk_full_sort_impl(
-    sycl::queue &exec_q,
-    std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows in a
-                             // matrix when sorting over rows)
-    std::size_t sort_nelems, // size of each array to sort (length of rows,
-                             // i.e. number of columns)
-    std::size_t k,
-    const argTy *arg_tp,
-    argTy *vals_tp,
-    IndexTy *inds_tp,
-    const CompT &comp,
-    const std::vector<sycl::event> &depends)
+sycl::event
+topk_full_merge_sort_impl(sycl::queue &exec_q,
+                          std::size_t iter_nelems, // number of sub-arrays
+                          std::size_t axis_nelems, // size of each sub-array
+                          std::size_t k,
+                          const argTy *arg_tp,
+                          argTy *vals_tp,
+                          IndexTy *inds_tp,
+                          const CompT &comp,
+                          const std::vector<sycl::event> &depends)
{
    IndexTy *index_data =
-        sycl::malloc_device<IndexTy>(iter_nelems * sort_nelems, exec_q);
+        sycl::malloc_device<IndexTy>(iter_nelems * axis_nelems, exec_q);
    if (index_data == nullptr) {
        throw std::runtime_error("Unable to allocate device_memory");
    }
@@ -103,10 +98,10 @@ sycl::event topk_full_sort_impl(
        exec_q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(depends);

-            auto const &range = sycl::range<1>(iter_nelems * sort_nelems);
+            auto const &range = sycl::range<1>(iter_nelems * axis_nelems);

            using KernelName =
-                populate_index_data_full_sort_krn<argTy, IndexTy, CompT>;
+                topk_populate_index_data_krn<argTy, IndexTy, CompT>;

            cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
                std::size_t i = id[0];
@@ -118,34 +113,33 @@ sycl::event topk_full_sort_impl(
    // Sort segments of the array
    sycl::event base_sort_ev =
        merge_sort_detail::sort_over_work_group_contig_impl(
-            exec_q, iter_nelems, sort_nelems, index_data, index_data, comp,
+            exec_q, iter_nelems, axis_nelems, index_data, index_data, comp,
            sorted_block_size, // modified in place with size of sorted block
                               // size
            {populate_indexed_data_ev});

    // Merge segments in parallel until all elements are sorted
    sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl(
-        exec_q, iter_nelems, sort_nelems, index_data, comp, sorted_block_size,
+        exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size,
        {base_sort_ev});

    sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(merges_ev);

-        using KernelName =
-            topk_map_to_rows_full_sort_krn<argTy, IndexTy, CompT>;
+        using KernelName = topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;

        cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
            std::size_t gid = id[0];

            std::size_t iter_gid = gid / k;
            std::size_t axis_gid = gid - (iter_gid * k);

-            std::size_t src_idx = iter_gid * sort_nelems + axis_gid;
+            std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
            std::size_t dst_idx = iter_gid * k + axis_gid;

            auto res_ind = index_data[src_idx];
            vals_tp[dst_idx] = arg_tp[res_ind];
-            inds_tp[dst_idx] = res_ind % sort_nelems;
+            inds_tp[dst_idx] = res_ind % axis_nelems;
        });
    });

@@ -162,29 +156,32 @@ sycl::event topk_full_sort_impl(
    return cleanup_host_task_event;
};

+template <typename T1, typename T2, typename T3>
+class topk_partial_merge_map_back_krn;
+
template <typename T1, typename T2, typename Comp>
class topk_over_work_group_krn;

template <typename argTy,
          typename IndexTy,
          typename ValueComp = std::less<argTy>>
-sycl::event
-topk_impl(sycl::queue &exec_q,
-          std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows
-                                   // in a matrix when sorting over rows)
-          std::size_t axis_nelems, // size of each array to sort (length of
-                                   // rows, i.e. number of columns)
-          std::size_t k,
-          const char *arg_cp,
-          char *vals_cp,
-          char *inds_cp,
-          dpctl::tensor::ssize_t iter_arg_offset,
-          dpctl::tensor::ssize_t iter_vals_offset,
-          dpctl::tensor::ssize_t iter_inds_offset,
-          dpctl::tensor::ssize_t axis_arg_offset,
-          dpctl::tensor::ssize_t axis_vals_offset,
-          dpctl::tensor::ssize_t axis_inds_offset,
-          const std::vector<sycl::event> &depends)
+sycl::event topk_merge_impl(
+    sycl::queue &exec_q,
+    std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows
+                             // in a matrix when sorting over rows)
+    std::size_t axis_nelems, // size of each array to sort (length of
+                             // rows, i.e. number of columns)
+    std::size_t k,
+    const char *arg_cp,
+    char *vals_cp,
+    char *inds_cp,
+    dpctl::tensor::ssize_t iter_arg_offset,
+    dpctl::tensor::ssize_t iter_vals_offset,
+    dpctl::tensor::ssize_t iter_inds_offset,
+    dpctl::tensor::ssize_t axis_arg_offset,
+    dpctl::tensor::ssize_t axis_vals_offset,
+    dpctl::tensor::ssize_t axis_inds_offset,
+    const std::vector<sycl::event> &depends)
{
    if (axis_nelems < k) {
        throw std::runtime_error("Invalid sort axis size for value of k");
@@ -201,8 +198,9 @@ topk_impl(sycl::queue &exec_q,
    const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};

    if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2) {
-        return topk_full_sort_impl(exec_q, iter_nelems, axis_nelems, k, arg_tp,
-                                   vals_tp, inds_tp, index_comp, depends);
+        return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems, k,
+                                         arg_tp, vals_tp, inds_tp, index_comp,
+                                         depends);
    }
    else {
        using PartialKernelName =
@@ -269,9 +267,9 @@ topk_impl(sycl::queue &exec_q,
        if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size ||
            alloc_len >= axis_nelems / 2)
        {
-            return topk_full_sort_impl(exec_q, iter_nelems, axis_nelems, k,
-                                       arg_tp, vals_tp, inds_tp, index_comp,
-                                       depends);
+            return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems,
+                                             k, arg_tp, vals_tp, inds_tp,
+                                             index_comp, depends);
        }

        IndexTy *index_data =
@@ -399,7 +397,8 @@ topk_impl(sycl::queue &exec_q,
        sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(merges_ev);

-            using KernelName = topk_map_to_rows_krn<argTy, IndexTy, ValueComp>;
+            using KernelName =
+                topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;

            cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
                std::size_t gid = id[0];
@@ -430,6 +429,109 @@ topk_impl(sycl::queue &exec_q,
    }
}

+template <typename T1, typename T2> class topk_iota_krn;
+
+template <typename T1, typename T2> class topk_radix_map_back_krn;
+
+template <typename argTy, typename IndexTy>
+sycl::event topk_radix_impl(sycl::queue &exec_q,
+                            std::size_t iter_nelems, // number of sub-arrays
+                            std::size_t axis_nelems, // size of each sub-array
+                            std::size_t k,
+                            bool ascending,
+                            const char *arg_cp,
+                            char *vals_cp,
+                            char *inds_cp,
+                            dpctl::tensor::ssize_t iter_arg_offset,
+                            dpctl::tensor::ssize_t iter_vals_offset,
+                            dpctl::tensor::ssize_t iter_inds_offset,
+                            dpctl::tensor::ssize_t axis_arg_offset,
+                            dpctl::tensor::ssize_t axis_vals_offset,
+                            dpctl::tensor::ssize_t axis_inds_offset,
+                            const std::vector<sycl::event> &depends)
+{
+    if (axis_nelems < k) {
+        throw std::runtime_error("Invalid sort axis size for value of k");
+    }
+
+    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
+                          iter_arg_offset + axis_arg_offset;
+    argTy *vals_tp = reinterpret_cast<argTy *>(vals_cp) + iter_vals_offset +
+                     axis_vals_offset;
+    IndexTy *inds_tp = reinterpret_cast<IndexTy *>(inds_cp) + iter_inds_offset +
+                       axis_inds_offset;
+
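+    // Comment added for clarity (not in the original commit): top-k here is an
+    // argsort over flattened indices; the radix sort below orders index values
+    // by the data they reference, and the first k sorted indices of each row
+    // are then gathered into the value and index outputs.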
+    const std::size_t total_nelems = iter_nelems * axis_nelems;
+    const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
+    IndexTy *workspace = sycl::malloc_device<IndexTy>(
+        padded_total_nelems + total_nelems, exec_q);
+
+    IndexTy *tmp_tp = sycl::malloc_device<IndexTy>(total_nelems, exec_q);
+
+    if (nullptr == workspace || nullptr == tmp_tp) {
+        throw std::runtime_error(
+            "Not enough device memory for radix sort topk");
+    }
+
+    using IdentityProjT = radix_sort_details::IdentityProj;
+    using IndexedProjT =
+        radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
+    const IndexedProjT proj_op{arg_tp, IdentityProjT{}};
+
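+    // Clarifying comment (not in the original commit): fill the workspace with
+    // the flattened index range [0, total_nelems)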
+    sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        using KernelName = topk_iota_krn<argTy, IndexTy>;
+
+        cgh.parallel_for<KernelName>(
+            sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
+                size_t i = id[0];
+                IndexTy sort_id = static_cast<IndexTy>(i);
+                workspace[i] = sort_id;
+            });
+    });
+
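+    // Clarifying comment (not in the original commit): sort each row's indices
+    // by the values they reference (proj_op maps an index to arg_tp[index]);
+    // the sorted indices are written to tmp_tp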
+    sycl::event radix_sort_ev =
+        radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
+            exec_q, iter_nelems, axis_nelems, workspace, tmp_tp, proj_op,
+            ascending, {iota_ev});
+
+    // Write out top k of the temporary
+    sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(radix_sort_ev);
+
+        using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;
+
+        cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
+            std::size_t gid = id[0];
+
+            std::size_t iter_gid = gid / k;
+            std::size_t axis_gid = gid - (iter_gid * k);
+
+            std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
+            std::size_t dst_idx = iter_gid * k + axis_gid;
+
+            IndexTy res_ind = tmp_tp[src_idx];
+            vals_tp[dst_idx] = arg_tp[res_ind];
+            inds_tp[dst_idx] = res_ind % axis_nelems;
+        });
+    });
+
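+    // Clarifying comment (not in the original commit): release the temporary
+    // USM allocations once the top-k write-back has completed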
+    sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(write_topk_ev);
+
+        const sycl::context &ctx = exec_q.get_context();
+
+        using dpctl::tensor::alloc_utils::sycl_free_noexcept;
+        cgh.host_task([ctx, workspace, tmp_tp] {
+            sycl_free_noexcept(workspace, ctx);
+            sycl_free_noexcept(tmp_tp, ctx);
+        });
+    });
+
+    return cleanup_ev;
+}
+
} // end of namespace kernels
} // end of namespace tensor
} // end of namespace dpctl