diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp index dc3da24315..9e82004dc3 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp @@ -1483,7 +1483,10 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, dev.template get_info(); constexpr std::uint16_t ref_wg_size = 64; - if (n_to_sort <= 16384 && ref_wg_size * 8 <= max_wg_size) { + constexpr bool enable_one_wg_radix_sort = true; + if (enable_one_wg_radix_sort && n_to_sort <= 16384 && + ref_wg_size * 8 <= max_wg_size) + { using _RadixSortKernel = OneWorkGroupRadixSortKernel; if (n_to_sort <= 64 && ref_wg_size <= max_wg_size) { diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp index 2674f877c9..43685f2ab0 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp @@ -498,10 +498,10 @@ sycl::event topk_radix_impl(sycl::queue &exec_q, exec_q, iter_nelems, axis_nelems, workspace, tmp_tp, proj_op, ascending, {iota_ev}); + radix_sort_ev.wait(); + // Write out top k of the temporary sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(radix_sort_ev); - using KernelName = topk_radix_map_back_krn; cgh.parallel_for(iter_nelems * k, [=](sycl::id<1> id) { @@ -519,9 +519,9 @@ sycl::event topk_radix_impl(sycl::queue &exec_q, }); }); - sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(write_topk_ev); + write_topk_ev.wait(); + sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { const sycl::context &ctx = exec_q.get_context(); using dpctl::tensor::alloc_utils::sycl_free_noexcept;