Skip to content

Commit cd1243f

Browse files
oleksandr-pavlyk authored and ndgrigorian committed
Simplify write-out kernels in topk implementation (avoid recomputing gid)
1 parent 275827a commit cd1243f

File tree

1 file changed

+14
-14
lines changed
  • dpctl/tensor/libtensor/include/kernels/sorting

1 file changed

+14
-14
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -410,17 +410,17 @@ sycl::event topk_merge_impl(
410410
topk_partial_merge_map_back_krn<argTy, IndexTy, ValueComp>;
411411

412412
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
413-
std::size_t gid = id[0];
413+
const std::size_t gid = id[0];
414414

415-
std::size_t iter_gid = gid / k;
416-
std::size_t axis_gid = gid - (iter_gid * k);
415+
const std::size_t iter_gid = gid / k;
416+
const std::size_t axis_gid = gid - (iter_gid * k);
417417

418-
std::size_t src_idx = iter_gid * alloc_len + axis_gid;
419-
std::size_t dst_idx = iter_gid * k + axis_gid;
418+
const std::size_t src_idx = iter_gid * alloc_len + axis_gid;
419+
const std::size_t dst_idx = gid;
420420

421-
auto res_ind = index_data[src_idx];
421+
const auto res_ind = index_data[src_idx];
422422
vals_tp[dst_idx] = arg_tp[res_ind];
423-
inds_tp[dst_idx] = res_ind % axis_nelems;
423+
inds_tp[dst_idx] = (res_ind % axis_nelems);
424424
});
425425
});
426426

@@ -519,17 +519,17 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
519519
using KernelName = topk_radix_map_back_krn<argTy, IndexTy>;
520520

521521
cgh.parallel_for<KernelName>(iter_nelems * k, [=](sycl::id<1> id) {
522-
std::size_t gid = id[0];
522+
const std::size_t gid = id[0];
523523

524-
std::size_t iter_gid = gid / k;
525-
std::size_t axis_gid = gid - (iter_gid * k);
524+
const std::size_t iter_gid = gid / k;
525+
const std::size_t axis_gid = gid - (iter_gid * k);
526526

527-
std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
528-
std::size_t dst_idx = iter_gid * k + axis_gid;
527+
const std::size_t src_idx = iter_gid * axis_nelems + axis_gid;
528+
const std::size_t dst_idx = gid;
529529

530-
IndexTy res_ind = tmp_tp[src_idx];
530+
const IndexTy res_ind = tmp_tp[src_idx];
531531
vals_tp[dst_idx] = arg_tp[res_ind];
532-
inds_tp[dst_idx] = res_ind % axis_nelems;
532+
inds_tp[dst_idx] = (res_ind % axis_nelems);
533533
});
534534
});
535535

0 commit comments

Comments
 (0)