Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,10 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
dev.template get_info<sycl::info::device::max_work_group_size>();

constexpr std::uint16_t ref_wg_size = 64;
if (n_to_sort <= 16384 && ref_wg_size * 8 <= max_wg_size) {
constexpr bool enable_one_wg_radix_sort = true;
if (enable_one_wg_radix_sort && n_to_sort <= 16384 &&
ref_wg_size * 8 <= max_wg_size)
{
using _RadixSortKernel = OneWorkGroupRadixSortKernel<ValueT, ProjT>;

if (n_to_sort <= 64 && ref_wg_size <= max_wg_size) {
Expand Down
8 changes: 4 additions & 4 deletions dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,10 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
exec_q, iter_nelems, axis_nelems, workspace, tmp_tp, proj_op,
ascending, {iota_ev});

radix_sort_ev.wait();

// Write out top k of the temporary
sycl::event write_topk_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(radix_sort_ev);

using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;

cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
Expand All @@ -519,9 +519,9 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
});
});

sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(write_topk_ev);
write_topk_ev.wait();

sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
Expand Down
Loading