|
36 | 36 | #include "kernels/dpctl_tensor_types.hpp" |
37 | 37 | #include "merge_sort.hpp" |
38 | 38 | #include "radix_sort.hpp" |
| 39 | +#include "search_sorted_detail.hpp" |
39 | 40 | #include "utils/sycl_alloc_utils.hpp" |
40 | 41 | #include <sycl/ext/oneapi/sub_group_mask.hpp> |
41 | 42 |
|
@@ -247,14 +248,16 @@ sycl::event topk_merge_impl( |
247 | 248 | // This assumption permits doing away with using a loop |
248 | 249 | assert(sorted_block_size % lws == 0); |
249 | 250 |
|
| 251 | + using search_sorted_detail::quotient_ceil; |
250 | 252 | const std::size_t n_segments = |
251 | | - merge_sort_detail::quotient_ceil<std::size_t>(axis_nelems, |
252 | | - sorted_block_size); |
| 253 | + quotient_ceil<std::size_t>(axis_nelems, sorted_block_size); |
253 | 254 |
|
254 | | - // round k up for the later merge kernel |
| 255 | + // round k up for the later merge kernel if necessary |
| 256 | + const std::size_t round_k_to = dev.has(sycl::aspect::cpu) ? 32 : 4; |
255 | 257 | std::size_t k_rounded = |
256 | | - merge_sort_detail::quotient_ceil<std::size_t>(k, elems_per_wi) * |
257 | | - elems_per_wi; |
| 258 | + (k < round_k_to) |
| 259 | + ? k |
| 260 | + : quotient_ceil<std::size_t>(k, round_k_to) * round_k_to; |
258 | 261 |
|
259 | 262 | // get length of tail for alloc size |
260 | 263 | auto rem = axis_nelems % sorted_block_size; |
@@ -322,8 +325,7 @@ sycl::event topk_merge_impl( |
322 | 325 | sycl::group_barrier(it.get_group()); |
323 | 326 |
|
324 | 327 | const std::size_t chunk = |
325 | | - merge_sort_detail::quotient_ceil<std::size_t>( |
326 | | - sorted_block_size, lws); |
| 328 | + quotient_ceil<std::size_t>(sorted_block_size, lws); |
327 | 329 |
|
328 | 330 | const std::size_t chunk_start_idx = lid * chunk; |
329 | 331 | const std::size_t chunk_end_idx = |
|
0 commit comments