@@ -71,6 +71,41 @@ void scale_topk_params(const std::uint64_t nelems_per_slm,
7171 throw std::runtime_error (" Could not construct top k kernel parameters" );
7272}
7373
74+ template <class KernelName , typename argTy, typename IndexTy>
75+ sycl::event write_out_impl (sycl::queue &exec_q,
76+ std::size_t iter_nelems,
77+ std::size_t k,
78+ const argTy *arg_tp,
79+ const IndexTy *index_data,
80+ std::size_t iter_index_stride,
81+ std::size_t axis_nelems,
82+ argTy *vals_tp,
83+ IndexTy *inds_tp,
84+ const std::vector<sycl::event> &depends)
85+ {
86+ sycl::event write_out_ev = exec_q.submit ([&](sycl::handler &cgh) {
87+ cgh.depends_on (depends);
88+
89+ cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
90+ const std::size_t gid = id[0 ];
91+
92+ const std::size_t iter_gid = gid / k;
93+ const std::size_t axis_gid = gid - (iter_gid * k);
94+
95+ const std::size_t src_idx = iter_gid * iter_index_stride + axis_gid;
96+ const std::size_t dst_idx = gid;
97+
98+ const IndexTy res_ind = index_data[src_idx];
99+ const argTy v = arg_tp[res_ind];
100+
101+ vals_tp[dst_idx] = v;
102+ inds_tp[dst_idx] = (res_ind % axis_nelems);
103+ });
104+ });
105+
106+ return write_out_ev;
107+ }
108+
74109} // namespace topk_detail
75110
76111template <typename T1, typename T2, typename T3>
@@ -132,26 +167,13 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
132167 exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size,
133168 {base_sort_ev});
134169
135- sycl::event write_out_ev = exec_q.submit ([&](sycl::handler &cgh) {
136- cgh.depends_on (merges_ev);
137-
138- using KernelName = topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
139-
140- cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
141- std::size_t gid = id[0 ];
142-
143- std::size_t iter_gid = gid / k;
144- std::size_t axis_gid = gid - (iter_gid * k);
170+ using WriteOutKernelName =
171+ topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
145172
146- std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
147- std::size_t dst_idx = iter_gid * k + axis_gid;
148-
149- const IndexTy res_ind = index_data[src_idx];
150- const argTy v = arg_tp[res_ind];
151- vals_tp[dst_idx] = v;
152- inds_tp[dst_idx] = res_ind % axis_nelems;
153- });
154- });
173+ sycl::event write_out_ev =
174+ topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
175+ exec_q, iter_nelems, k, arg_tp, index_data, axis_nelems,
176+ axis_nelems, vals_tp, inds_tp, {merges_ev});
155177
156178 sycl::event cleanup_host_task_event =
157179 dpctl::tensor::alloc_utils::async_smart_free (exec_q, {write_out_ev},
@@ -399,27 +421,13 @@ sycl::event topk_merge_impl(
399421 k_rounded, {base_sort_ev});
400422
401423 // Write out top k of the merge-sorted memory
402- sycl::event write_topk_ev = exec_q.submit ([&](sycl::handler &cgh) {
403- cgh.depends_on (merges_ev);
404-
405- using KernelName =
406- topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
424+ using WriteOutKernelName =
425+ topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
407426
408- cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
409- const std::size_t gid = id[0 ];
410-
411- const std::size_t iter_gid = gid / k;
412- const std::size_t axis_gid = gid - (iter_gid * k);
413-
414- const std::size_t src_idx = iter_gid * alloc_len + axis_gid;
415- const std::size_t dst_idx = gid;
416-
417- const IndexTy res_ind = index_data[src_idx];
418- const argTy v = arg_tp[res_ind];
419- vals_tp[dst_idx] = v;
420- inds_tp[dst_idx] = (res_ind % axis_nelems);
421- });
422- });
427+ sycl::event write_topk_ev =
428+ topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
429+ exec_q, iter_nelems, k, arg_tp, index_data, alloc_len,
430+ axis_nelems, vals_tp, inds_tp, {merges_ev});
423431
424432 sycl::event cleanup_host_task_event =
425433 dpctl::tensor::alloc_utils::async_smart_free (
@@ -502,26 +510,12 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
502510 ascending, {iota_ev});
503511
504512 // Write out top k of the temporary
505- sycl::event write_topk_ev = exec_q.submit ([&](sycl::handler &cgh) {
506- cgh.depends_on (radix_sort_ev);
507-
508- using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;
513+ using WriteOutKernelName = topk_radix_map_back_krn<argTy, IndexTy>;
509514
510- cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
511- const std::size_t gid = id[0 ];
512-
513- const std::size_t iter_gid = gid / k;
514- const std::size_t axis_gid = gid - (iter_gid * k);
515-
516- const std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
517- const std::size_t dst_idx = gid;
518-
519- const IndexTy res_ind = tmp_tp[src_idx];
520- const argTy v = arg_tp[res_ind];
521- vals_tp[dst_idx] = v;
522- inds_tp[dst_idx] = (res_ind % axis_nelems);
523- });
524- });
515+ sycl::event write_topk_ev =
516+ topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
517+ exec_q, iter_nelems, k, arg_tp, tmp_tp, axis_nelems, axis_nelems,
518+ vals_tp, inds_tp, {radix_sort_ev});
525519
526520 sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free (
527521 exec_q, {write_topk_ev}, workspace_owner);
0 commit comments