Skip to content

Commit 6d33867

Browse files
Add license headers
Add common typedef for sort_contig_fn_ptr_t from merge_sort header and radix_sort header files into a new file. Used it in cpp files.
1 parent dbe62d0 commit 6d33867

File tree

9 files changed

+179
-56
lines changed

9 files changed

+179
-56
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -703,17 +703,6 @@ merge_sorted_block_contig_impl(sycl::queue &q,
703703

704704
} // end of namespace merge_sort_detail
705705

706-
typedef sycl::event (*sort_contig_fn_ptr_t)(sycl::queue &,
707-
size_t,
708-
size_t,
709-
const char *,
710-
char *,
711-
ssize_t,
712-
ssize_t,
713-
ssize_t,
714-
ssize_t,
715-
const std::vector<sycl::event> &);
716-
717706
template <typename argTy, typename Comp = std::less<argTy>>
718707
sycl::event stable_sort_axis1_contig_impl(
719708
sycl::queue &exec_q,

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,30 @@
1+
//
2+
// Data Parallel Control (dpctl)
3+
//
4+
// Copyright 2020-2024 Intel Corporation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
//
18+
//===--------------------------------------------------------------------===//
19+
///
20+
/// \file
21+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
22+
/// extension.
23+
//===--------------------------------------------------------------------===//
24+
25+
// Implementation in this file were adapted from oneDPL's radix sort
26+
// implementation, license Apache-2.0 WITH LLVM-exception
27+
128
#pragma once
229

330
#include <cstdint>
@@ -217,7 +244,7 @@ radix_sort_count_submit(sycl::queue &exec_q,
217244

218245
// iteration space info
219246
const std::size_t n = n_values;
220-
// Each segment is processed by a work-group
247+
// each segment is processed by a work-group
221248
const std::size_t elems_per_segment = (n + n_segments - 1) / n_segments;
222249
const std::size_t no_op_flag_id = n_counts - 1;
223250

@@ -372,8 +399,7 @@ sycl::event radix_sort_scan_submit(sycl::queue &exec_q,
372399
const auto lid = ndit.get_local_linear_id();
373400

374401
// NB: No race condition here, because the condition may ever be
375-
// true
376-
// for only on one WG, one WI.
402+
// true for only on one WG, one WI.
377403
if ((lid == wg_size - 1) && (begin_ptr[scan_size - 1] == n_values))
378404
{
379405
// set flag, since all the values got into one
@@ -613,7 +639,6 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
613639
const ProjT &proj_op,
614640
const std::vector<sycl::event> dependency_events)
615641
{
616-
// typedefs
617642
using ValueT = InputT;
618643
using PeerHelper = peer_prefix_helper<OffsetT, PeerAlgo>;
619644

@@ -649,7 +674,6 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
649674

650675
sycl::nd_range<1> ndRange{gRange, lRange};
651676

652-
// Each work-group processes one segment ?
653677
cgh.parallel_for<KernelName>(ndRange, [=](sycl::nd_item<1> ndit) {
654678
const std::size_t group_id = ndit.get_group(0);
655679
const std::size_t iter_id = group_id / n_segments;
@@ -670,9 +694,9 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
670694
return;
671695
}
672696

673-
// 1. create a private array for storing offset values
674-
// and add total offset and offset for compute unit for a certain
675-
// radix state
697+
// create a private array for storing offset values
698+
// and add total offset and offset for compute unit
699+
// for a certain radix state
676700
std::array<OffsetT, radix_states> offset_arr{};
677701
const std::size_t scan_size = n_segments + 1;
678702

@@ -688,7 +712,7 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
688712
const std::uint32_t local_offset_id =
689713
segment_id + scan_size * radix_state_id;
690714

691-
// scan bins (serial)
715+
// scan bins serially
692716
const std::size_t last_segment_bucket_id =
693717
radix_state_id * scan_size - 1;
694718
scanned_bin += b_offset_ptr[last_segment_bucket_id];
@@ -739,7 +763,7 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
739763
if (tail_size > 0) {
740764
ValueT in_val;
741765

742-
// greater than any actual radix state
766+
// default: is greater than any actual radix state
743767
std::uint32_t bucket_id = radix_states;
744768
if (lid < tail_size) {
745769
in_val = std::move(b_input_ptr[seg_end + lid]);
@@ -749,6 +773,7 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
749773
bucket_id =
750774
get_bucket_id<radix_mask>(mapped_val, radix_offset);
751775
}
776+
752777
OffsetT new_offset_id = 0;
753778
for (std::uint32_t radix_state_id = 0;
754779
radix_state_id < radix_states; ++radix_state_id)
@@ -761,6 +786,7 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
761786

762787
offset_arr[radix_state_id] += sg_total_offset;
763788
}
789+
764790
if (lid < tail_size) {
765791
b_output_ptr[new_offset_id] = std::move(in_val);
766792
}
@@ -1589,19 +1615,6 @@ template <typename IndexT, typename ValueT, typename ProjT> struct IndexedProj
15891615

15901616
} // end of namespace radix_sort_details
15911617

1592-
// same signature as sort_contig_fn_ptr_t
1593-
typedef sycl::event (*radix_sort_contig_fn_ptr_t)(
1594-
sycl::queue &,
1595-
size_t,
1596-
size_t,
1597-
const char *,
1598-
char *,
1599-
ssize_t,
1600-
ssize_t,
1601-
ssize_t,
1602-
ssize_t,
1603-
const std::vector<sycl::event> &);
1604-
16051618
template <typename argTy, bool sort_ascending>
16061619
sycl::event
16071620
radix_sort_axis1_contig_impl(sycl::queue &exec_q,
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//
2+
// Data Parallel Control (dpctl)
3+
//
4+
// Copyright 2020-2024 Intel Corporation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
//
18+
//===--------------------------------------------------------------------===//
19+
///
20+
/// \file
21+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
22+
/// extension.
23+
//===--------------------------------------------------------------------===//
24+
25+
#pragma once
26+
27+
#include <sycl/sycl.hpp>
28+
#include <vector>
29+
30+
namespace dpctl
31+
{
32+
namespace tensor
33+
{
34+
namespace kernels
35+
{
36+
37+
typedef sycl::event (*sort_contig_fn_ptr_t)(sycl::queue &,
38+
size_t,
39+
size_t,
40+
const char *,
41+
char *,
42+
ssize_t,
43+
ssize_t,
44+
ssize_t,
45+
ssize_t,
46+
const std::vector<sycl::event> &);
47+
48+
}
49+
} // namespace tensor
50+
} // namespace dpctl

dpctl/tensor/libtensor/source/sorting/argsort.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "utils/type_dispatch.hpp"
3434

3535
#include "kernels/sorting/merge_sort.hpp"
36+
#include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
3637
#include "rich_comparisons.hpp"
3738

3839
#include "argsort.hpp"

dpctl/tensor/libtensor/source/sorting/py_argsort_common.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@
3232
#include "utils/output_validation.hpp"
3333
#include "utils/type_dispatch.hpp"
3434

35-
#include "argsort.hpp"
36-
#include "kernels/sorting/merge_sort.hpp"
37-
#include "rich_comparisons.hpp"
38-
3935
namespace td_ns = dpctl::tensor::type_dispatch;
4036

4137
namespace dpctl

dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,27 @@
1+
//
2+
// Data Parallel Control (dpctl)
3+
//
4+
// Copyright 2020-2024 Intel Corporation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
//
18+
//===--------------------------------------------------------------------===//
19+
///
20+
/// \file
21+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
22+
/// extension.
23+
//===--------------------------------------------------------------------===//
24+
125
#include <cstdint>
226
#include <utility>
327
#include <vector>
@@ -15,10 +39,11 @@
1539
#include "utils/type_dispatch.hpp"
1640

1741
#include "kernels/sorting/radix_sort.hpp"
18-
#include "radix_sort_support.hpp"
42+
#include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
1943

2044
#include "py_argsort_common.hpp"
2145
#include "radix_argsort.hpp"
46+
#include "radix_sort_support.hpp"
2247

2348
namespace dpctl
2449
{
@@ -30,12 +55,12 @@ namespace py_internal
3055
namespace td_ns = dpctl::tensor::type_dispatch;
3156
namespace impl_ns = dpctl::tensor::kernels::radix_sort_details;
3257

33-
using dpctl::tensor::kernels::radix_sort_contig_fn_ptr_t;
58+
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
3459

35-
static radix_sort_contig_fn_ptr_t
60+
static sort_contig_fn_ptr_t
3661
ascending_radix_argsort_contig_dispatch_table[td_ns::num_types]
3762
[td_ns::num_types];
38-
static radix_sort_contig_fn_ptr_t
63+
static sort_contig_fn_ptr_t
3964
descending_radix_argsort_contig_dispatch_table[td_ns::num_types]
4065
[td_ns::num_types];
4166

@@ -79,15 +104,15 @@ struct DescendingRadixArgSortContigFactory
79104

80105
void init_radix_argsort_dispatch_tables(void)
81106
{
82-
using dpctl::tensor::kernels::radix_sort_contig_fn_ptr_t;
107+
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
83108

84-
td_ns::DispatchTableBuilder<radix_sort_contig_fn_ptr_t,
109+
td_ns::DispatchTableBuilder<sort_contig_fn_ptr_t,
85110
AscendingRadixArgSortContigFactory,
86111
td_ns::num_types>
87112
dtb1;
88113
dtb1.populate_dispatch_table(ascending_radix_argsort_contig_dispatch_table);
89114

90-
td_ns::DispatchTableBuilder<radix_sort_contig_fn_ptr_t,
115+
td_ns::DispatchTableBuilder<sort_contig_fn_ptr_t,
91116
DescendingRadixArgSortContigFactory,
92117
td_ns::num_types>
93118
dtb2;

dpctl/tensor/libtensor/source/sorting/radix_sort.cpp

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,27 @@
1+
//
2+
// Data Parallel Control (dpctl)
3+
//
4+
// Copyright 2020-2024 Intel Corporation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
//
18+
//===--------------------------------------------------------------------===//
19+
///
20+
/// \file
21+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
22+
/// extension.
23+
//===--------------------------------------------------------------------===//
24+
125
#include <cstdint>
226
#include <utility>
327
#include <vector>
@@ -15,10 +39,11 @@
1539
#include "utils/type_dispatch.hpp"
1640

1741
#include "kernels/sorting/radix_sort.hpp"
18-
#include "radix_sort_support.hpp"
42+
#include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
1943

2044
#include "py_sort_common.hpp"
2145
#include "radix_sort.hpp"
46+
#include "radix_sort_support.hpp"
2247

2348
namespace dpctl
2449
{
@@ -30,10 +55,10 @@ namespace py_internal
3055
namespace td_ns = dpctl::tensor::type_dispatch;
3156
namespace impl_ns = dpctl::tensor::kernels::radix_sort_details;
3257

33-
using dpctl::tensor::kernels::radix_sort_contig_fn_ptr_t;
34-
static radix_sort_contig_fn_ptr_t
58+
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
59+
static sort_contig_fn_ptr_t
3560
ascending_radix_sort_contig_dispatch_vector[td_ns::num_types];
36-
static radix_sort_contig_fn_ptr_t
61+
static sort_contig_fn_ptr_t
3762
descending_radix_sort_contig_dispatch_vector[td_ns::num_types];
3863

3964
template <typename fnT, typename argTy> struct AscendingRadixSortContigFactory
@@ -66,15 +91,14 @@ template <typename fnT, typename argTy> struct DescendingRadixSortContigFactory
6691

6792
void init_radix_sort_dispatch_vectors(void)
6893
{
69-
using dpctl::tensor::kernels::radix_sort_contig_fn_ptr_t;
94+
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
7095

71-
td_ns::DispatchVectorBuilder<radix_sort_contig_fn_ptr_t,
72-
AscendingRadixSortContigFactory,
73-
td_ns::num_types>
96+
td_ns::DispatchVectorBuilder<
97+
sort_contig_fn_ptr_t, AscendingRadixSortContigFactory, td_ns::num_types>
7498
dtv1;
7599
dtv1.populate_dispatch_vector(ascending_radix_sort_contig_dispatch_vector);
76100

77-
td_ns::DispatchVectorBuilder<radix_sort_contig_fn_ptr_t,
101+
td_ns::DispatchVectorBuilder<sort_contig_fn_ptr_t,
78102
DescendingRadixSortContigFactory,
79103
td_ns::num_types>
80104
dtv2;

dpctl/tensor/libtensor/source/sorting/radix_sort_support.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,27 @@
1+
//
2+
// Data Parallel Control (dpctl)
3+
//
4+
// Copyright 2020-2024 Intel Corporation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
//
18+
//===--------------------------------------------------------------------===//
19+
///
20+
/// \file
21+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
22+
/// extension.
23+
//===--------------------------------------------------------------------===//
24+
125
#pragma once
226

327
#include <type_traits>

0 commit comments

Comments
 (0)