32
32
#include < vector>
33
33
34
34
#include " kernels/dpctl_tensor_types.hpp"
35
- #include " kernels/sorting/sort_detail .hpp"
35
+ #include " kernels/sorting/search_sorted_detail .hpp"
36
36
37
37
namespace dpctl
38
38
{
@@ -41,9 +41,11 @@ namespace tensor
41
41
namespace kernels
42
42
{
43
43
44
- namespace sort_detail
44
+ namespace merge_sort_detail
45
45
{
46
46
47
+ using namespace dpctl ::tensor::kernels::search_sorted_detail;
48
+
47
49
/*! @brief Merge two contiguous sorted segments */
48
50
template <typename InAcc, typename OutAcc, typename Compare>
49
51
void merge_impl (const std::size_t offset,
@@ -699,7 +701,7 @@ merge_sorted_block_contig_impl(sycl::queue &q,
699
701
return dep_ev;
700
702
}
701
703
702
- } // end of namespace sort_detail
704
+ } // end of namespace merge_sort_detail
703
705
704
706
typedef sycl::event (*sort_contig_fn_ptr_t )(sycl::queue &,
705
707
size_t ,
@@ -741,8 +743,8 @@ sycl::event stable_sort_axis1_contig_impl(
741
743
if (sort_nelems < sequential_sorting_threshold) {
742
744
// each work-item sorts an entire row
743
745
sycl::event sequential_sorting_ev =
744
- sort_detail ::sort_base_step_contig_impl<const argTy *, argTy *,
745
- Comp>(
746
+ merge_sort_detail ::sort_base_step_contig_impl<const argTy *,
747
+ argTy *, Comp>(
746
748
exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
747
749
sort_nelems, depends);
748
750
@@ -753,16 +755,16 @@ sycl::event stable_sort_axis1_contig_impl(
753
755
754
756
// Sort segments of the array
755
757
sycl::event base_sort_ev =
756
- sort_detail ::sort_over_work_group_contig_impl<const argTy *,
757
- argTy *, Comp>(
758
+ merge_sort_detail ::sort_over_work_group_contig_impl<const argTy *,
759
+ argTy *, Comp>(
758
760
exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
759
761
sorted_block_size, // modified in place with size of sorted
760
762
// block
761
763
depends);
762
764
763
765
// Merge segments in parallel until all elements are sorted
764
766
sycl::event merges_ev =
765
- sort_detail ::merge_sorted_block_contig_impl<argTy *, Comp>(
767
+ merge_sort_detail ::merge_sorted_block_contig_impl<argTy *, Comp>(
766
768
exec_q, iter_nelems, sort_nelems, res_tp, comp,
767
769
sorted_block_size, {base_sort_ev});
768
770
@@ -837,21 +839,24 @@ sycl::event stable_argsort_axis1_contig_impl(
837
839
});
838
840
839
841
// Sort segments of the array
840
- sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl (
841
- exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp,
842
- sorted_block_size, // modified in place with size of sorted block size
843
- {populate_indexed_data_ev});
842
+ sycl::event base_sort_ev =
843
+ merge_sort_detail::sort_over_work_group_contig_impl (
844
+ exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp,
845
+ sorted_block_size, // modified in place with size of sorted block
846
+ // size
847
+ {populate_indexed_data_ev});
844
848
845
849
// Merge segments in parallel until all elements are sorted
846
- sycl::event merges_ev = sort_detail ::merge_sorted_block_contig_impl (
850
+ sycl::event merges_ev = merge_sort_detail ::merge_sorted_block_contig_impl (
847
851
exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
848
852
{base_sort_ev});
849
853
850
854
sycl::event write_out_ev = exec_q.submit ([&](sycl::handler &cgh) {
851
855
cgh.depends_on (merges_ev);
852
856
853
857
auto temp_acc =
854
- sort_detail::GetReadOnlyAccess<decltype (res_tp)>{}(res_tp, cgh);
858
+ merge_sort_detail::GetReadOnlyAccess<decltype (res_tp)>{}(res_tp,
859
+ cgh);
855
860
856
861
using KernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
857
862
0 commit comments