
Commit c37f1de

Reduce code duplication by using std::optional in py_topk directly
Instead of using a separate overload to handle the `axis=None` case, make `py_topk` take a `std::optional` directly and branch on whether `trailing_dims_to_search` has a value in the validation logic.
1 parent b4d7ba4 commit c37f1de
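
The pattern this commit adopts, distilled into a minimal standalone sketch (illustrative only, using nothing beyond the standard library; it is not the dpctl code itself): a single function takes std::optional and branches on has_value() inside its validation logic, so no second overload is needed for the axis=None case.

#include <iostream>
#include <optional>
#include <stdexcept>

// Hypothetical validator mirroring the shape of the py_topk change:
// one signature serves both "axis given" and "axis=None" callers.
void validate(std::optional<int> trailing_dims_to_search, int src_nd)
{
    if (trailing_dims_to_search.has_value()) {
        // Axis was given: searched dimensions must fit within the array rank.
        const int trailing_dims = trailing_dims_to_search.value();
        if (trailing_dims <= 0 || src_nd - trailing_dims < 0) {
            throw std::invalid_argument("trailing_dims_to_search out of range");
        }
    }
    else {
        // axis=None: the flattened array is searched; nothing to check here.
    }
}

int main()
{
    validate(2, 3);            // search the last 2 of 3 dimensions
    validate(std::nullopt, 3); // axis=None handled by the same function
    std::cout << "validation passed\n";
}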

1 file changed: +71 additions, −168 deletions

dpctl/tensor/libtensor/source/sorting/topk.cpp

Lines changed: 71 additions & 168 deletions
@@ -94,24 +94,21 @@ struct use_radix_sort<
 };
 
 template <typename argTy, typename IndexTy>
-sycl::event
-topk_caller(sycl::queue &exec_q,
-            std::size_t iter_nelems, // number of sub-arrays to sort (num. of
-                                     // rows in a matrix when sorting over rows)
-            std::size_t axis_nelems, // size of each array to sort (length of
-                                     // rows, i.e. number of columns)
-            std::size_t k,
-            bool largest,
-            const char *arg_cp,
-            char *vals_cp,
-            char *inds_cp,
-            py::ssize_t iter_arg_offset,
-            py::ssize_t iter_vals_offset,
-            py::ssize_t iter_inds_offset,
-            py::ssize_t axis_arg_offset,
-            py::ssize_t axis_vals_offset,
-            py::ssize_t axis_inds_offset,
-            const std::vector<sycl::event> &depends)
+sycl::event topk_caller(sycl::queue &exec_q,
+                        std::size_t iter_nelems, // number of sub-arrays
+                        std::size_t axis_nelems, // size of each sub-array
+                        std::size_t k,
+                        bool largest,
+                        const char *arg_cp,
+                        char *vals_cp,
+                        char *inds_cp,
+                        py::ssize_t iter_arg_offset,
+                        py::ssize_t iter_vals_offset,
+                        py::ssize_t iter_inds_offset,
+                        py::ssize_t axis_arg_offset,
+                        py::ssize_t axis_vals_offset,
+                        py::ssize_t axis_inds_offset,
+                        const std::vector<sycl::event> &depends)
 {
     if constexpr (use_radix_sort<argTy>::value) {
         using dpctl::tensor::kernels::topk_radix_impl;
@@ -147,7 +144,7 @@ topk_caller(sycl::queue &exec_q,
 
 std::pair<sycl::event, sycl::event>
 py_topk(const dpctl::tensor::usm_ndarray &src,
-        const int trailing_dims_to_search,
+        std::optional<const int> trailing_dims_to_search,
         const std::size_t k,
         const bool largest,
         const dpctl::tensor::usm_ndarray &vals,
@@ -158,48 +155,70 @@ py_topk(const dpctl::tensor::usm_ndarray &src,
     int src_nd = src.get_ndim();
     int vals_nd = vals.get_ndim();
     int inds_nd = inds.get_ndim();
-    if (src_nd != vals_nd || src_nd != inds_nd) {
-        throw py::value_error("The input and output arrays must have "
-                              "the same array ranks");
-    }
-    int iteration_nd = src_nd - trailing_dims_to_search;
-    if (trailing_dims_to_search <= 0 || iteration_nd < 0) {
-        throw py::value_error(
-            "trailing_dims_to_search must be positive, but no "
-            "greater than rank of the array being searched");
-    }
 
     const py::ssize_t *src_shape_ptr = src.get_shape_raw();
     const py::ssize_t *vals_shape_ptr = vals.get_shape_raw();
     const py::ssize_t *inds_shape_ptr = inds.get_shape_raw();
 
-    bool same_shapes = true;
+    std::size_t axis_nelems(1);
     std::size_t iter_nelems(1);
-    for (int i = 0; same_shapes && (i < iteration_nd); ++i) {
-        auto src_shape_i = src_shape_ptr[i];
-        same_shapes = same_shapes && (src_shape_i == vals_shape_ptr[i] &&
-                                      src_shape_i == inds_shape_ptr[i]);
-        iter_nelems *= static_cast<std::size_t>(src_shape_i);
-    }
+    if (trailing_dims_to_search.has_value()) {
+        if (src_nd != vals_nd || src_nd != inds_nd) {
+            throw py::value_error("The input and output arrays must have "
+                                  "the same array ranks");
+        }
 
-    if (!same_shapes) {
-        throw py::value_error(
-            "Destination shape does not match the input shape");
-    }
+        auto trailing_dims = trailing_dims_to_search.value();
+        int iter_nd = src_nd - trailing_dims;
+        if (trailing_dims <= 0 || iter_nd < 0) {
+            throw py::value_error(
+                "trailing_dims_to_search must be positive, but no "
+                "greater than rank of the array being searched");
+        }
 
-    std::size_t vals_k(1);
-    std::size_t inds_k(1);
-    std::size_t axis_nelems(1);
-    for (int i = iteration_nd; i < src_nd; ++i) {
-        axis_nelems *= static_cast<std::size_t>(src_shape_ptr[i]);
-        vals_k *= static_cast<std::size_t>(vals_shape_ptr[i]);
-        inds_k *= static_cast<std::size_t>(inds_shape_ptr[i]);
+        bool same_shapes = true;
+        for (int i = 0; same_shapes && (i < iter_nd); ++i) {
+            auto src_shape_i = src_shape_ptr[i];
+            same_shapes = same_shapes && (src_shape_i == vals_shape_ptr[i] &&
+                                          src_shape_i == inds_shape_ptr[i]);
+            iter_nelems *= static_cast<std::size_t>(src_shape_i);
+        }
+
+        if (!same_shapes) {
+            throw py::value_error(
+                "Destination shape does not match the input shape");
+        }
+
+        std::size_t vals_k(1);
+        std::size_t inds_k(1);
+        for (int i = iter_nd; i < src_nd; ++i) {
+            axis_nelems *= static_cast<std::size_t>(src_shape_ptr[i]);
+            vals_k *= static_cast<std::size_t>(vals_shape_ptr[i]);
+            inds_k *= static_cast<std::size_t>(inds_shape_ptr[i]);
+        }
+
+        bool valid_k = (vals_k == k && inds_k == k && axis_nelems >= k);
+        if (!valid_k) {
+            throw py::value_error("The value of k is invalid for the input and "
+                                  "destination arrays");
+        }
     }
+    else {
+        if (vals_nd != 1 || inds_nd != 1) {
+            throw py::value_error("Output arrays must be one-dimensional");
+        }
 
-    bool valid_k = (vals_k == k && inds_k == k && axis_nelems >= k);
-    if (!valid_k) {
-        throw py::value_error(
-            "The value of k is invalid for the input and destination arrays");
+        for (int i = 0; i < src_nd; ++i) {
+            axis_nelems *= static_cast<std::size_t>(src_shape_ptr[i]);
+        }
+
+        bool valid_k = (axis_nelems >= k &&
+                        static_cast<std::size_t>(vals_shape_ptr[0]) == k &&
+                        static_cast<std::size_t>(inds_shape_ptr[0]) == k);
+        if (!valid_k) {
+            throw py::value_error("The value of k is invalid for the input and "
+                                  "destination arrays");
+        }
     }
 
     if (!dpctl::utils::queues_are_compatible(exec_q, {src, vals, inds})) {
@@ -267,103 +286,6 @@ py_topk(const dpctl::tensor::usm_ndarray &src,
     return std::make_pair(sycl::event(), sycl::event());
 }
 
-std::pair<sycl::event, sycl::event>
-py_topk(const dpctl::tensor::usm_ndarray &src,
-        const std::size_t k,
-        const bool largest,
-        const dpctl::tensor::usm_ndarray &vals,
-        const dpctl::tensor::usm_ndarray &inds,
-        sycl::queue &exec_q,
-        const std::vector<sycl::event> &depends)
-{
-    int src_nd = src.get_ndim();
-    int vals_nd = vals.get_ndim();
-    int inds_nd = inds.get_ndim();
-    if (vals_nd != 1 || inds_nd != 1) {
-        throw py::value_error("Output arrays must be one-dimensional");
-    }
-
-    const py::ssize_t *src_shape_ptr = src.get_shape_raw();
-    const py::ssize_t *vals_shape_ptr = vals.get_shape_raw();
-    const py::ssize_t *inds_shape_ptr = inds.get_shape_raw();
-
-    std::size_t axis_nelems(1);
-    for (int i = 0; i < src_nd; ++i) {
-        axis_nelems *= static_cast<std::size_t>(src_shape_ptr[i]);
-    }
-
-    bool valid_k =
-        (axis_nelems >= k && static_cast<std::size_t>(vals_shape_ptr[0]) == k &&
-         static_cast<std::size_t>(inds_shape_ptr[0]) == k);
-    if (!valid_k) {
-        throw py::value_error(
-            "The value of k is invalid for the input and destination arrays");
-    }
-
-    if (!dpctl::utils::queues_are_compatible(exec_q, {src, vals, inds})) {
-        throw py::value_error(
-            "Execution queue is not compatible with allocation queues");
-    }
-
-    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vals);
-    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(inds);
-
-    if (axis_nelems == 0) {
-        // Nothing to do
-        return std::make_pair(sycl::event(), sycl::event());
-    }
-
-    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
-    if (overlap(src, vals) || overlap(src, inds)) {
-        throw py::value_error("Arrays index overlapping segments of memory");
-    }
-
-    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vals, k);
-
-    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(inds, k);
-
-    int src_typenum = src.get_typenum();
-    int vals_typenum = vals.get_typenum();
-    int inds_typenum = inds.get_typenum();
-
-    const auto &array_types = td_ns::usm_ndarray_types();
-    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
-    int vals_typeid = array_types.typenum_to_lookup_id(vals_typenum);
-    int inds_typeid = array_types.typenum_to_lookup_id(inds_typenum);
-
-    if (src_typeid != vals_typeid) {
-        throw py::value_error("Input array and vals array must have "
-                              "the same data type");
-    }
-
-    if (inds_typeid != static_cast<int>(td_ns::typenum_t::INT64)) {
-        throw py::value_error("Inds array must have data type int64");
-    }
-
-    bool is_src_c_contig = src.is_c_contiguous();
-    bool is_vals_c_contig = vals.is_c_contiguous();
-    bool is_inds_c_contig = inds.is_c_contiguous();
-
-    if (is_src_c_contig && is_vals_c_contig && is_inds_c_contig) {
-        static constexpr py::ssize_t zero_offset = py::ssize_t(0);
-        static constexpr std::size_t iter_nelems = 1;
-
-        auto fn = topk_dispatch_vector[src_typeid];
-
-        sycl::event comp_ev =
-            fn(exec_q, iter_nelems, axis_nelems, k, largest, src.get_data(),
-               vals.get_data(), inds.get_data(), zero_offset, zero_offset,
-               zero_offset, zero_offset, zero_offset, zero_offset, depends);
-
-        sycl::event keep_args_alive_ev =
-            dpctl::utils::keep_args_alive(exec_q, {src, vals, inds}, {comp_ev});
-
-        return std::make_pair(keep_args_alive_ev, comp_ev);
-    }
-
-    return std::make_pair(sycl::event(), sycl::event());
-}
-
 template <typename fnT, typename T> struct TopKFactory
 {
     fnT get()
@@ -385,26 +307,7 @@ void init_topk_functions(py::module_ m)
 {
     dpctl::tensor::py_internal::init_topk_dispatch_vectors();
 
-    auto py_topk = [](const dpctl::tensor::usm_ndarray &src,
-                      std::optional<const int> trailing_dims_to_search,
-                      const std::size_t k, const bool largest,
-                      const dpctl::tensor::usm_ndarray &vals,
-                      const dpctl::tensor::usm_ndarray &inds,
-                      sycl::queue &exec_q,
-                      const std::vector<sycl::event> &depends)
-        -> std::pair<sycl::event, sycl::event> {
-        if (trailing_dims_to_search) {
-            return dpctl::tensor::py_internal::py_topk(
-                src, trailing_dims_to_search.value(), k, largest, vals, inds,
-                exec_q, depends);
-        }
-        else {
-            return dpctl::tensor::py_internal::py_topk(src, k, largest, vals,
-                                                       inds, exec_q, depends);
-        }
-    };
-
-    m.def("_topk", py_topk, py::arg("src"), py::arg("trailing_dims_to_search"),
+    m.def("_topk", &py_topk, py::arg("src"), py::arg("trailing_dims_to_search"),
           py::arg("k"), py::arg("largest"), py::arg("vals"), py::arg("inds"),
           py::arg("sycl_queue"), py::arg("depends") = py::list());
 }
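
A note on why the lambda wrapper in init_topk_functions could be dropped: when pybind11/stl.h is included, pybind11 converts Python None to std::nullopt for std::optional parameters, so py_topk can be bound with m.def directly. Below is a minimal, self-contained sketch of that conversion (hypothetical module and function names, not part of this commit):

#include <optional>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // enables the std::optional <-> None type caster

namespace py = pybind11;

// Hypothetical function: returns the wrapped value, or -1 when None is passed.
int value_or_default(std::optional<int> trailing_dims)
{
    return trailing_dims.has_value() ? trailing_dims.value() : -1;
}

PYBIND11_MODULE(optional_demo, m)
{
    m.def("value_or_default", &value_or_default, py::arg("trailing_dims"));
    // In Python: optional_demo.value_or_default(None) -> -1
    //            optional_demo.value_or_default(2)    -> 2
}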
