@@ -86,7 +86,7 @@ sycl::event topk_full_sort_impl(
8686 // matrix when sorting over rows)
8787 std::size_t sort_nelems, // size of each array to sort (length of rows,
8888 // i.e. number of columns)
89- dpctl::tensor:: ssize_t k,
89+ std:: size_t k,
9090 const argTy *arg_tp,
9191 argTy *vals_tp,
9292 IndexTy *inds_tp,
@@ -174,7 +174,7 @@ topk_impl(sycl::queue &exec_q,
174174 // in a matrix when sorting over rows)
175175 std::size_t axis_nelems, // size of each array to sort (length of
176176 // rows, i.e. number of columns)
177- dpctl::tensor:: ssize_t k,
177+ std:: size_t k,
178178 const char *arg_cp,
179179 char *vals_cp,
180180 char *inds_cp,
@@ -186,7 +186,7 @@ topk_impl(sycl::queue &exec_q,
186186 dpctl::tensor::ssize_t axis_inds_offset,
187187 const std::vector<sycl::event> &depends)
188188{
189- if (axis_nelems < static_cast <std:: size_t >(k) ) {
189+ if (axis_nelems < k ) {
190190 throw std::runtime_error (" Invalid sort axis size for value of k" );
191191 }
192192
@@ -200,9 +200,7 @@ topk_impl(sycl::queue &exec_q,
200200 using dpctl::tensor::kernels::IndexComp;
201201 const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
202202
203- if (axis_nelems <= 512 || k >= 1024 ||
204- static_cast <std::size_t >(k) > axis_nelems / 2 )
205- {
203+ if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2 ) {
206204 return topk_full_sort_impl (exec_q, iter_nelems, axis_nelems, k, arg_tp,
207205 vals_tp, inds_tp, index_comp, depends);
208206 }
@@ -256,22 +254,19 @@ topk_impl(sycl::queue &exec_q,
256254 sorted_block_size);
257255
258256 // round k up for the later merge kernel
259- const dpctl::tensor::ssize_t round_k_to = elems_per_wi;
260- dpctl::tensor::ssize_t k_rounded =
261- merge_sort_detail::quotient_ceil<dpctl::tensor::ssize_t >(
262- k, round_k_to) *
263- round_k_to;
257+ std::size_t k_rounded =
258+ merge_sort_detail::quotient_ceil<std::size_t >(k, elems_per_wi) *
259+ elems_per_wi;
264260
265261 // get length of tail for alloc size
266262 auto rem = axis_nelems % sorted_block_size;
267- auto alloc_len = (rem && rem < static_cast <std:: size_t >( k_rounded) )
263+ auto alloc_len = (rem && rem < k_rounded)
268264 ? rem + k_rounded * (n_segments - 1 )
269265 : k_rounded * n_segments;
270266
271267 // if allocation would be sufficiently large or k is larger than
272268 // elements processed, use full sort
273- if (static_cast <std::size_t >(k_rounded) >= axis_nelems ||
274- static_cast <std::size_t >(k_rounded) >= sorted_block_size ||
269+ if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size ||
275270 alloc_len >= axis_nelems / 2 )
276271 {
277272 return topk_full_sort_impl (exec_q, iter_nelems, axis_nelems, k,
@@ -386,7 +381,7 @@ topk_impl(sycl::queue &exec_q,
386381 for (std::size_t array_id = k_segment_start_idx + lid;
387382 array_id < k_segment_end_idx; array_id += lws)
388383 {
389- if (lid < static_cast <std:: size_t >( k_rounded) ) {
384+ if (lid < k_rounded) {
390385 index_data[iter_id * alloc_len + array_id] =
391386 out_src[array_id - k_segment_start_idx];
392387 }
0 commit comments