Skip to content

Commit dbe62d0

Browse files
Added Python API to exercise radix sort functions
1 parent 3517745 commit dbe62d0

File tree

11 files changed

+708
-207
lines changed

11 files changed

+708
-207
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ set(_reduction_sources
114114
set(_sorting_sources
115115
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/sort.cpp
116116
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/argsort.cpp
117+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
118+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp
117119
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
118120
)
119121
set(_static_lib_sources

dpctl/tensor/libtensor/source/sorting/argsort.cpp

Lines changed: 3 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@
3232
#include "utils/output_validation.hpp"
3333
#include "utils/type_dispatch.hpp"
3434

35-
#include "argsort.hpp"
3635
#include "kernels/sorting/merge_sort.hpp"
3736
#include "rich_comparisons.hpp"
3837

38+
#include "argsort.hpp"
39+
#include "py_argsort_common.hpp"
40+
3941
namespace td_ns = dpctl::tensor::type_dispatch;
4042

4143
namespace dpctl
@@ -45,112 +47,6 @@ namespace tensor
4547
namespace py_internal
4648
{
4749

48-
template <typename sorting_contig_impl_fnT>
49-
std::pair<sycl::event, sycl::event>
50-
py_argsort(const dpctl::tensor::usm_ndarray &src,
51-
const int trailing_dims_to_sort,
52-
const dpctl::tensor::usm_ndarray &dst,
53-
sycl::queue &exec_q,
54-
const std::vector<sycl::event> &depends,
55-
const sorting_contig_impl_fnT &stable_sort_contig_fns)
56-
{
57-
int src_nd = src.get_ndim();
58-
int dst_nd = dst.get_ndim();
59-
if (src_nd != dst_nd) {
60-
throw py::value_error("The input and output arrays must have "
61-
"the same array ranks");
62-
}
63-
int iteration_nd = src_nd - trailing_dims_to_sort;
64-
if (trailing_dims_to_sort <= 0 || iteration_nd < 0) {
65-
throw py::value_error("Trailing_dim_to_sort must be positive, but no "
66-
"greater than rank of the array being sorted");
67-
}
68-
69-
const py::ssize_t *src_shape_ptr = src.get_shape_raw();
70-
const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
71-
72-
bool same_shapes = true;
73-
size_t iter_nelems(1);
74-
75-
for (int i = 0; same_shapes && (i < iteration_nd); ++i) {
76-
auto src_shape_i = src_shape_ptr[i];
77-
same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]);
78-
iter_nelems *= static_cast<size_t>(src_shape_i);
79-
}
80-
81-
size_t sort_nelems(1);
82-
for (int i = iteration_nd; same_shapes && (i < src_nd); ++i) {
83-
auto src_shape_i = src_shape_ptr[i];
84-
same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]);
85-
sort_nelems *= static_cast<size_t>(src_shape_i);
86-
}
87-
88-
if (!same_shapes) {
89-
throw py::value_error(
90-
"Destination shape does not match the input shape");
91-
}
92-
93-
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
94-
throw py::value_error(
95-
"Execution queue is not compatible with allocation queues");
96-
}
97-
98-
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
99-
100-
if ((iter_nelems == 0) || (sort_nelems == 0)) {
101-
// Nothing to do
102-
return std::make_pair(sycl::event(), sycl::event());
103-
}
104-
105-
// check that dst and src do not overlap
106-
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
107-
if (overlap(src, dst)) {
108-
throw py::value_error("Arrays index overlapping segments of memory");
109-
}
110-
111-
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
112-
dst, sort_nelems * iter_nelems);
113-
114-
int src_typenum = src.get_typenum();
115-
int dst_typenum = dst.get_typenum();
116-
117-
const auto &array_types = td_ns::usm_ndarray_types();
118-
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
119-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
120-
121-
if ((dst_typeid != static_cast<int>(td_ns::typenum_t::INT64)) &&
122-
(dst_typeid != static_cast<int>(td_ns::typenum_t::INT32)))
123-
{
124-
throw py::value_error(
125-
"Output index array must have data type int32 or int64");
126-
}
127-
128-
bool is_src_c_contig = src.is_c_contiguous();
129-
bool is_dst_c_contig = dst.is_c_contiguous();
130-
131-
if (is_src_c_contig && is_dst_c_contig) {
132-
static constexpr py::ssize_t zero_offset = py::ssize_t(0);
133-
134-
auto fn = stable_sort_contig_fns[src_typeid][dst_typeid];
135-
136-
if (fn == nullptr) {
137-
throw py::value_error("Not implemented for given index type");
138-
}
139-
140-
sycl::event comp_ev =
141-
fn(exec_q, iter_nelems, sort_nelems, src.get_data(), dst.get_data(),
142-
zero_offset, zero_offset, zero_offset, zero_offset, depends);
143-
144-
sycl::event keep_args_alive_ev =
145-
dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev});
146-
147-
return std::make_pair(keep_args_alive_ev, comp_ev);
148-
}
149-
150-
throw py::value_error(
151-
"Both source and destination arrays must be C-contiguous");
152-
}
153-
15450
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
15551
static sort_contig_fn_ptr_t
15652
ascending_argsort_contig_dispatch_table[td_ns::num_types][td_ns::num_types];
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
//
2+
// Data Parallel Control (dpctl)
3+
//
4+
// Copyright 2020-2024 Intel Corporation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
//
18+
//===--------------------------------------------------------------------===//
19+
///
20+
/// \file
21+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
22+
/// extension.
23+
//===--------------------------------------------------------------------===//
24+
25+
#include "dpctl4pybind11.hpp"
26+
#include <pybind11/pybind11.h>
27+
#include <pybind11/stl.h>
28+
#include <sycl/sycl.hpp>
29+
30+
#include "utils/math_utils.hpp"
31+
#include "utils/memory_overlap.hpp"
32+
#include "utils/output_validation.hpp"
33+
#include "utils/type_dispatch.hpp"
34+
35+
#include "argsort.hpp"
36+
#include "kernels/sorting/merge_sort.hpp"
37+
#include "rich_comparisons.hpp"
38+
39+
namespace td_ns = dpctl::tensor::type_dispatch;
40+
41+
namespace dpctl
42+
{
43+
namespace tensor
44+
{
45+
namespace py_internal
46+
{
47+
48+
template <typename sorting_contig_impl_fnT>
49+
std::pair<sycl::event, sycl::event>
50+
py_argsort(const dpctl::tensor::usm_ndarray &src,
51+
const int trailing_dims_to_sort,
52+
const dpctl::tensor::usm_ndarray &dst,
53+
sycl::queue &exec_q,
54+
const std::vector<sycl::event> &depends,
55+
const sorting_contig_impl_fnT &stable_sort_contig_fns)
56+
{
57+
int src_nd = src.get_ndim();
58+
int dst_nd = dst.get_ndim();
59+
if (src_nd != dst_nd) {
60+
throw py::value_error("The input and output arrays must have "
61+
"the same array ranks");
62+
}
63+
int iteration_nd = src_nd - trailing_dims_to_sort;
64+
if (trailing_dims_to_sort <= 0 || iteration_nd < 0) {
65+
throw py::value_error("Trailing_dim_to_sort must be positive, but no "
66+
"greater than rank of the array being sorted");
67+
}
68+
69+
const py::ssize_t *src_shape_ptr = src.get_shape_raw();
70+
const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
71+
72+
bool same_shapes = true;
73+
size_t iter_nelems(1);
74+
75+
for (int i = 0; same_shapes && (i < iteration_nd); ++i) {
76+
auto src_shape_i = src_shape_ptr[i];
77+
same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]);
78+
iter_nelems *= static_cast<size_t>(src_shape_i);
79+
}
80+
81+
size_t sort_nelems(1);
82+
for (int i = iteration_nd; same_shapes && (i < src_nd); ++i) {
83+
auto src_shape_i = src_shape_ptr[i];
84+
same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]);
85+
sort_nelems *= static_cast<size_t>(src_shape_i);
86+
}
87+
88+
if (!same_shapes) {
89+
throw py::value_error(
90+
"Destination shape does not match the input shape");
91+
}
92+
93+
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
94+
throw py::value_error(
95+
"Execution queue is not compatible with allocation queues");
96+
}
97+
98+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
99+
100+
if ((iter_nelems == 0) || (sort_nelems == 0)) {
101+
// Nothing to do
102+
return std::make_pair(sycl::event(), sycl::event());
103+
}
104+
105+
// check that dst and src do not overlap
106+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
107+
if (overlap(src, dst)) {
108+
throw py::value_error("Arrays index overlapping segments of memory");
109+
}
110+
111+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
112+
dst, sort_nelems * iter_nelems);
113+
114+
int src_typenum = src.get_typenum();
115+
int dst_typenum = dst.get_typenum();
116+
117+
const auto &array_types = td_ns::usm_ndarray_types();
118+
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
119+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
120+
121+
if ((dst_typeid != static_cast<int>(td_ns::typenum_t::INT64)) &&
122+
(dst_typeid != static_cast<int>(td_ns::typenum_t::INT32)))
123+
{
124+
throw py::value_error(
125+
"Output index array must have data type int32 or int64");
126+
}
127+
128+
bool is_src_c_contig = src.is_c_contiguous();
129+
bool is_dst_c_contig = dst.is_c_contiguous();
130+
131+
if (is_src_c_contig && is_dst_c_contig) {
132+
static constexpr py::ssize_t zero_offset = py::ssize_t(0);
133+
134+
auto fn = stable_sort_contig_fns[src_typeid][dst_typeid];
135+
136+
if (fn == nullptr) {
137+
throw py::value_error("Not implemented for given index type");
138+
}
139+
140+
sycl::event comp_ev =
141+
fn(exec_q, iter_nelems, sort_nelems, src.get_data(), dst.get_data(),
142+
zero_offset, zero_offset, zero_offset, zero_offset, depends);
143+
144+
sycl::event keep_args_alive_ev =
145+
dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev});
146+
147+
return std::make_pair(keep_args_alive_ev, comp_ev);
148+
}
149+
150+
throw py::value_error(
151+
"Both source and destination arrays must be C-contiguous");
152+
}
153+
154+
} // end of namespace py_internal
155+
} // end of namespace tensor
156+
} // end of namespace dpctl

0 commit comments

Comments
 (0)