Skip to content

Commit 830f428

Browse files
Renamed some files in sorting folders, in preparation for introduction of radix sort
1 parent 08605c4 commit 830f428

File tree

5 files changed: 26 additions and 21 deletions

dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp renamed to dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
#include <vector>
3333

3434
#include "kernels/dpctl_tensor_types.hpp"
35-
#include "kernels/sorting/sort_detail.hpp"
35+
#include "kernels/sorting/search_sorted_detail.hpp"
3636

3737
namespace dpctl
3838
{
@@ -41,9 +41,11 @@ namespace tensor
4141
namespace kernels
4242
{
4343

44-
namespace sort_detail
44+
namespace merge_sort_detail
4545
{
4646

47+
using namespace dpctl::tensor::kernels::search_sorted_detail;
48+
4749
/*! @brief Merge two contiguous sorted segments */
4850
template <typename InAcc, typename OutAcc, typename Compare>
4951
void merge_impl(const std::size_t offset,
@@ -699,7 +701,7 @@ merge_sorted_block_contig_impl(sycl::queue &q,
699701
return dep_ev;
700702
}
701703

702-
} // end of namespace sort_detail
704+
} // end of namespace merge_sort_detail
703705

704706
typedef sycl::event (*sort_contig_fn_ptr_t)(sycl::queue &,
705707
size_t,
@@ -741,8 +743,8 @@ sycl::event stable_sort_axis1_contig_impl(
741743
if (sort_nelems < sequential_sorting_threshold) {
742744
// equal work-item sorts entire row
743745
sycl::event sequential_sorting_ev =
744-
sort_detail::sort_base_step_contig_impl<const argTy *, argTy *,
745-
Comp>(
746+
merge_sort_detail::sort_base_step_contig_impl<const argTy *,
747+
argTy *, Comp>(
746748
exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
747749
sort_nelems, depends);
748750

@@ -753,16 +755,16 @@ sycl::event stable_sort_axis1_contig_impl(
753755

754756
// Sort segments of the array
755757
sycl::event base_sort_ev =
756-
sort_detail::sort_over_work_group_contig_impl<const argTy *,
757-
argTy *, Comp>(
758+
merge_sort_detail::sort_over_work_group_contig_impl<const argTy *,
759+
argTy *, Comp>(
758760
exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
759761
sorted_block_size, // modified in place with size of sorted
760762
// block size
761763
depends);
762764

763765
// Merge segments in parallel until all elements are sorted
764766
sycl::event merges_ev =
765-
sort_detail::merge_sorted_block_contig_impl<argTy *, Comp>(
767+
merge_sort_detail::merge_sorted_block_contig_impl<argTy *, Comp>(
766768
exec_q, iter_nelems, sort_nelems, res_tp, comp,
767769
sorted_block_size, {base_sort_ev});
768770

@@ -837,21 +839,24 @@ sycl::event stable_argsort_axis1_contig_impl(
837839
});
838840

839841
// Sort segments of the array
840-
sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl(
841-
exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp,
842-
sorted_block_size, // modified in place with size of sorted block size
843-
{populate_indexed_data_ev});
842+
sycl::event base_sort_ev =
843+
merge_sort_detail::sort_over_work_group_contig_impl(
844+
exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp,
845+
sorted_block_size, // modified in place with size of sorted block
846+
// size
847+
{populate_indexed_data_ev});
844848

845849
// Merge segments in parallel until all elements are sorted
846-
sycl::event merges_ev = sort_detail::merge_sorted_block_contig_impl(
850+
sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl(
847851
exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
848852
{base_sort_ev});
849853

850854
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
851855
cgh.depends_on(merges_ev);
852856

853857
auto temp_acc =
854-
sort_detail::GetReadOnlyAccess<decltype(res_tp)>{}(res_tp, cgh);
858+
merge_sort_detail::GetReadOnlyAccess<decltype(res_tp)>{}(res_tp,
859+
cgh);
855860

856861
using KernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
857862

dpctl/tensor/libtensor/include/kernels/sorting/sort_detail.hpp renamed to dpctl/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ namespace tensor
3535
namespace kernels
3636
{
3737

38-
namespace sort_detail
38+
namespace search_sorted_detail
3939
{
4040

4141
/*! @brief Evaluates (n + m - 1) / m; for non-negative n and positive m
 *         this is the quotient n / m rounded up to a whole number */
template <typename T> T quotient_ceil(T n, T m)
{
    // keep the intermediate in the promoted arithmetic type of the
    // expression; conversion back to T happens only on return
    const auto padded_numerator = n + m - 1;
    return padded_numerator / m;
}
@@ -111,7 +111,7 @@ std::size_t upper_bound_indexed_impl(const Acc acc,
111111
acc_indexer);
112112
}
113113

114-
} // namespace sort_detail
114+
} // namespace search_sorted_detail
115115

116116
} // namespace kernels
117117
} // namespace tensor

dpctl/tensor/libtensor/include/kernels/sorting/searchsorted.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
#include <vector>
3232

3333
#include "kernels/dpctl_tensor_types.hpp"
34-
#include "kernels/sorting/sort_detail.hpp"
34+
#include "kernels/sorting/search_sorted_detail.hpp"
3535
#include "utils/offset_utils.hpp"
3636

3737
namespace dpctl
@@ -91,7 +91,7 @@ struct SearchSortedFunctor
9191

9292
// lower_bound returns the first pos such that bool(hay[pos] <
9393
// needle_v) is false, i.e. needle_v <= hay[pos]
94-
pos = sort_detail::lower_bound_indexed_impl(
94+
pos = search_sorted_detail::lower_bound_indexed_impl(
9595
hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer);
9696
}
9797
else {
@@ -100,7 +100,7 @@ struct SearchSortedFunctor
100100

101101
// upper_bound returns the first pos such that bool(needle_v <
102102
// hay[pos]) is true, i.e. needle_v < hay[pos]
103-
pos = sort_detail::upper_bound_indexed_impl(
103+
pos = search_sorted_detail::upper_bound_indexed_impl(
104104
hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer);
105105
}
106106

dpctl/tensor/libtensor/source/sorting/argsort.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
#include "utils/type_dispatch.hpp"
3434

3535
#include "argsort.hpp"
36-
#include "kernels/sorting/sort.hpp"
36+
#include "kernels/sorting/merge_sort.hpp"
3737
#include "rich_comparisons.hpp"
3838

3939
namespace td_ns = dpctl::tensor::type_dispatch;

dpctl/tensor/libtensor/source/sorting/sort.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
#include "utils/output_validation.hpp"
3434
#include "utils/type_dispatch.hpp"
3535

36-
#include "kernels/sorting/sort.hpp"
36+
#include "kernels/sorting/merge_sort.hpp"
3737
#include "rich_comparisons.hpp"
3838
#include "sort.hpp"
3939

0 commit comments

Comments
 (0)