@@ -71,6 +71,41 @@ void scale_topk_params(const std::uint64_t nelems_per_slm,
71
71
throw std::runtime_error (" Could not construct top k kernel parameters" );
72
72
}
73
73
74
+ template <class KernelName , typename argTy, typename IndexTy>
75
+ sycl::event write_out_impl (sycl::queue &exec_q,
76
+ std::size_t iter_nelems,
77
+ std::size_t k,
78
+ const argTy *arg_tp,
79
+ const IndexTy *index_data,
80
+ std::size_t iter_index_stride,
81
+ std::size_t axis_nelems,
82
+ argTy *vals_tp,
83
+ IndexTy *inds_tp,
84
+ const std::vector<sycl::event> &depends)
85
+ {
86
+ sycl::event write_out_ev = exec_q.submit ([&](sycl::handler &cgh) {
87
+ cgh.depends_on (depends);
88
+
89
+ cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
90
+ const std::size_t gid = id[0 ];
91
+
92
+ const std::size_t iter_gid = gid / k;
93
+ const std::size_t axis_gid = gid - (iter_gid * k);
94
+
95
+ const std::size_t src_idx = iter_gid * iter_index_stride + axis_gid;
96
+ const std::size_t dst_idx = gid;
97
+
98
+ const IndexTy res_ind = index_data[src_idx];
99
+ const argTy v = arg_tp[res_ind];
100
+
101
+ vals_tp[dst_idx] = v;
102
+ inds_tp[dst_idx] = (res_ind % axis_nelems);
103
+ });
104
+ });
105
+
106
+ return write_out_ev;
107
+ }
108
+
74
109
} // namespace topk_detail
75
110
76
111
template <typename T1, typename T2, typename T3>
@@ -132,26 +167,13 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
132
167
exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size,
133
168
{base_sort_ev});
134
169
135
- sycl::event write_out_ev = exec_q.submit ([&](sycl::handler &cgh) {
136
- cgh.depends_on (merges_ev);
137
-
138
- using KernelName = topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
139
-
140
- cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
141
- std::size_t gid = id[0 ];
142
-
143
- std::size_t iter_gid = gid / k;
144
- std::size_t axis_gid = gid - (iter_gid * k);
170
+ using WriteOutKernelName =
171
+ topk_full_merge_map_back_krn<argTy, IndexTy, CompT>;
145
172
146
- std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
147
- std::size_t dst_idx = iter_gid * k + axis_gid;
148
-
149
- const IndexTy res_ind = index_data[src_idx];
150
- const argTy v = arg_tp[res_ind];
151
- vals_tp[dst_idx] = v;
152
- inds_tp[dst_idx] = res_ind % axis_nelems;
153
- });
154
- });
173
+ sycl::event write_out_ev =
174
+ topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
175
+ exec_q, iter_nelems, k, arg_tp, index_data, axis_nelems,
176
+ axis_nelems, vals_tp, inds_tp, {merges_ev});
155
177
156
178
sycl::event cleanup_host_task_event =
157
179
dpctl::tensor::alloc_utils::async_smart_free (exec_q, {write_out_ev},
@@ -399,27 +421,13 @@ sycl::event topk_merge_impl(
399
421
k_rounded, {base_sort_ev});
400
422
401
423
// Write out top k of the merge-sorted memory
402
- sycl::event write_topk_ev = exec_q.submit ([&](sycl::handler &cgh) {
403
- cgh.depends_on (merges_ev);
404
-
405
- using KernelName =
406
- topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
424
+ using WriteOutKernelName =
425
+ topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
407
426
408
- cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
409
- const std::size_t gid = id[0 ];
410
-
411
- const std::size_t iter_gid = gid / k;
412
- const std::size_t axis_gid = gid - (iter_gid * k);
413
-
414
- const std::size_t src_idx = iter_gid * alloc_len + axis_gid;
415
- const std::size_t dst_idx = gid;
416
-
417
- const IndexTy res_ind = index_data[src_idx];
418
- const argTy v = arg_tp[res_ind];
419
- vals_tp[dst_idx] = v;
420
- inds_tp[dst_idx] = (res_ind % axis_nelems);
421
- });
422
- });
427
+ sycl::event write_topk_ev =
428
+ topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
429
+ exec_q, iter_nelems, k, arg_tp, index_data, alloc_len,
430
+ axis_nelems, vals_tp, inds_tp, {merges_ev});
423
431
424
432
sycl::event cleanup_host_task_event =
425
433
dpctl::tensor::alloc_utils::async_smart_free (
@@ -502,26 +510,12 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
502
510
ascending, {iota_ev});
503
511
504
512
// Write out top k of the temporary
505
- sycl::event write_topk_ev = exec_q.submit ([&](sycl::handler &cgh) {
506
- cgh.depends_on (radix_sort_ev);
507
-
508
- using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;
513
+ using WriteOutKernelName = topk_radix_map_back_krn<argTy, IndexTy>;
509
514
510
- cgh.parallel_for <KernelName>(iter_nelems * k, [=](sycl::id<1 > id) {
511
- const std::size_t gid = id[0 ];
512
-
513
- const std::size_t iter_gid = gid / k;
514
- const std::size_t axis_gid = gid - (iter_gid * k);
515
-
516
- const std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
517
- const std::size_t dst_idx = gid;
518
-
519
- const IndexTy res_ind = tmp_tp[src_idx];
520
- const argTy v = arg_tp[res_ind];
521
- vals_tp[dst_idx] = v;
522
- inds_tp[dst_idx] = (res_ind % axis_nelems);
523
- });
524
- });
515
+ sycl::event write_topk_ev =
516
+ topk_detail::write_out_impl<WriteOutKernelName, argTy, IndexTy>(
517
+ exec_q, iter_nelems, k, arg_tp, tmp_tp, axis_nelems, axis_nelems,
518
+ vals_tp, inds_tp, {radix_sort_ev});
525
519
526
520
sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free (
527
521
exec_q, {write_topk_ev}, workspace_owner);
0 commit comments