@@ -94,24 +94,21 @@ struct use_radix_sort<
9494};
9595
9696template <typename argTy, typename IndexTy>
97- sycl::event
98- topk_caller (sycl::queue &exec_q,
99- std::size_t iter_nelems, // number of sub-arrays to sort (num. of
100- // rows in a matrix when sorting over rows)
101- std::size_t axis_nelems, // size of each array to sort (length of
102- // rows, i.e. number of columns)
103- std::size_t k,
104- bool largest,
105- const char *arg_cp,
106- char *vals_cp,
107- char *inds_cp,
108- py::ssize_t iter_arg_offset,
109- py::ssize_t iter_vals_offset,
110- py::ssize_t iter_inds_offset,
111- py::ssize_t axis_arg_offset,
112- py::ssize_t axis_vals_offset,
113- py::ssize_t axis_inds_offset,
114- const std::vector<sycl::event> &depends)
97+ sycl::event topk_caller (sycl::queue &exec_q,
98+ std::size_t iter_nelems, // number of sub-arrays
99+ std::size_t axis_nelems, // size of each sub-array
100+ std::size_t k,
101+ bool largest,
102+ const char *arg_cp,
103+ char *vals_cp,
104+ char *inds_cp,
105+ py::ssize_t iter_arg_offset,
106+ py::ssize_t iter_vals_offset,
107+ py::ssize_t iter_inds_offset,
108+ py::ssize_t axis_arg_offset,
109+ py::ssize_t axis_vals_offset,
110+ py::ssize_t axis_inds_offset,
111+ const std::vector<sycl::event> &depends)
115112{
116113 if constexpr (use_radix_sort<argTy>::value) {
117114 using dpctl::tensor::kernels::topk_radix_impl;
@@ -147,7 +144,7 @@ topk_caller(sycl::queue &exec_q,
147144
148145std::pair<sycl::event, sycl::event>
149146py_topk (const dpctl::tensor::usm_ndarray &src,
150- const int trailing_dims_to_search,
147+ std::optional< const int > trailing_dims_to_search,
151148 const std::size_t k,
152149 const bool largest,
153150 const dpctl::tensor::usm_ndarray &vals,
@@ -158,48 +155,70 @@ py_topk(const dpctl::tensor::usm_ndarray &src,
158155 int src_nd = src.get_ndim ();
159156 int vals_nd = vals.get_ndim ();
160157 int inds_nd = inds.get_ndim ();
161- if (src_nd != vals_nd || src_nd != inds_nd) {
162- throw py::value_error (" The input and output arrays must have "
163- " the same array ranks" );
164- }
165- int iteration_nd = src_nd - trailing_dims_to_search;
166- if (trailing_dims_to_search <= 0 || iteration_nd < 0 ) {
167- throw py::value_error (
168- " trailing_dims_to_search must be positive, but no "
169- " greater than rank of the array being searched" );
170- }
171158
172159 const py::ssize_t *src_shape_ptr = src.get_shape_raw ();
173160 const py::ssize_t *vals_shape_ptr = vals.get_shape_raw ();
174161 const py::ssize_t *inds_shape_ptr = inds.get_shape_raw ();
175162
176- bool same_shapes = true ;
163+ std:: size_t axis_nelems ( 1 ) ;
177164 std::size_t iter_nelems (1 );
178- for (int i = 0 ; same_shapes && (i < iteration_nd); ++i) {
179- auto src_shape_i = src_shape_ptr[i];
180- same_shapes = same_shapes && (src_shape_i == vals_shape_ptr[i] &&
181- src_shape_i == inds_shape_ptr[i]);
182- iter_nelems *= static_cast <std::size_t >(src_shape_i);
183- }
165+ if (trailing_dims_to_search.has_value ()) {
166+ if (src_nd != vals_nd || src_nd != inds_nd) {
167+ throw py::value_error (" The input and output arrays must have "
168+ " the same array ranks" );
169+ }
184170
185- if (!same_shapes) {
186- throw py::value_error (
187- " Destination shape does not match the input shape" );
188- }
171+ auto trailing_dims = trailing_dims_to_search.value ();
172+ int iter_nd = src_nd - trailing_dims;
173+ if (trailing_dims <= 0 || iter_nd < 0 ) {
174+ throw py::value_error (
175+ " trailing_dims_to_search must be positive, but no "
176+ " greater than rank of the array being searched" );
177+ }
189178
190- std::size_t vals_k (1 );
191- std::size_t inds_k (1 );
192- std::size_t axis_nelems (1 );
193- for (int i = iteration_nd; i < src_nd; ++i) {
194- axis_nelems *= static_cast <std::size_t >(src_shape_ptr[i]);
195- vals_k *= static_cast <std::size_t >(vals_shape_ptr[i]);
196- inds_k *= static_cast <std::size_t >(inds_shape_ptr[i]);
179+ bool same_shapes = true ;
180+ for (int i = 0 ; same_shapes && (i < iter_nd); ++i) {
181+ auto src_shape_i = src_shape_ptr[i];
182+ same_shapes = same_shapes && (src_shape_i == vals_shape_ptr[i] &&
183+ src_shape_i == inds_shape_ptr[i]);
184+ iter_nelems *= static_cast <std::size_t >(src_shape_i);
185+ }
186+
187+ if (!same_shapes) {
188+ throw py::value_error (
189+ " Destination shape does not match the input shape" );
190+ }
191+
192+ std::size_t vals_k (1 );
193+ std::size_t inds_k (1 );
194+ for (int i = iter_nd; i < src_nd; ++i) {
195+ axis_nelems *= static_cast <std::size_t >(src_shape_ptr[i]);
196+ vals_k *= static_cast <std::size_t >(vals_shape_ptr[i]);
197+ inds_k *= static_cast <std::size_t >(inds_shape_ptr[i]);
198+ }
199+
200+ bool valid_k = (vals_k == k && inds_k == k && axis_nelems >= k);
201+ if (!valid_k) {
202+ throw py::value_error (" The value of k is invalid for the input and "
203+ " destination arrays" );
204+ }
197205 }
206+ else {
207+ if (vals_nd != 1 || inds_nd != 1 ) {
208+ throw py::value_error (" Output arrays must be one-dimensional" );
209+ }
198210
199- bool valid_k = (vals_k == k && inds_k == k && axis_nelems >= k);
200- if (!valid_k) {
201- throw py::value_error (
202- " The value of k is invalid for the input and destination arrays" );
211+ for (int i = 0 ; i < src_nd; ++i) {
212+ axis_nelems *= static_cast <std::size_t >(src_shape_ptr[i]);
213+ }
214+
215+ bool valid_k = (axis_nelems >= k &&
216+ static_cast <std::size_t >(vals_shape_ptr[0 ]) == k &&
217+ static_cast <std::size_t >(inds_shape_ptr[0 ]) == k);
218+ if (!valid_k) {
219+ throw py::value_error (" The value of k is invalid for the input and "
220+ " destination arrays" );
221+ }
203222 }
204223
205224 if (!dpctl::utils::queues_are_compatible (exec_q, {src, vals, inds})) {
@@ -267,103 +286,6 @@ py_topk(const dpctl::tensor::usm_ndarray &src,
267286 return std::make_pair (sycl::event (), sycl::event ());
268287}
269288
270- std::pair<sycl::event, sycl::event>
271- py_topk (const dpctl::tensor::usm_ndarray &src,
272- const std::size_t k,
273- const bool largest,
274- const dpctl::tensor::usm_ndarray &vals,
275- const dpctl::tensor::usm_ndarray &inds,
276- sycl::queue &exec_q,
277- const std::vector<sycl::event> &depends)
278- {
279- int src_nd = src.get_ndim ();
280- int vals_nd = vals.get_ndim ();
281- int inds_nd = inds.get_ndim ();
282- if (vals_nd != 1 || inds_nd != 1 ) {
283- throw py::value_error (" Output arrays must be one-dimensional" );
284- }
285-
286- const py::ssize_t *src_shape_ptr = src.get_shape_raw ();
287- const py::ssize_t *vals_shape_ptr = vals.get_shape_raw ();
288- const py::ssize_t *inds_shape_ptr = inds.get_shape_raw ();
289-
290- std::size_t axis_nelems (1 );
291- for (int i = 0 ; i < src_nd; ++i) {
292- axis_nelems *= static_cast <std::size_t >(src_shape_ptr[i]);
293- }
294-
295- bool valid_k =
296- (axis_nelems >= k && static_cast <std::size_t >(vals_shape_ptr[0 ]) == k &&
297- static_cast <std::size_t >(inds_shape_ptr[0 ]) == k);
298- if (!valid_k) {
299- throw py::value_error (
300- " The value of k is invalid for the input and destination arrays" );
301- }
302-
303- if (!dpctl::utils::queues_are_compatible (exec_q, {src, vals, inds})) {
304- throw py::value_error (
305- " Execution queue is not compatible with allocation queues" );
306- }
307-
308- dpctl::tensor::validation::CheckWritable::throw_if_not_writable (vals);
309- dpctl::tensor::validation::CheckWritable::throw_if_not_writable (inds);
310-
311- if (axis_nelems == 0 ) {
312- // Nothing to do
313- return std::make_pair (sycl::event (), sycl::event ());
314- }
315-
316- auto const &overlap = dpctl::tensor::overlap::MemoryOverlap ();
317- if (overlap (src, vals) || overlap (src, inds)) {
318- throw py::value_error (" Arrays index overlapping segments of memory" );
319- }
320-
321- dpctl::tensor::validation::AmpleMemory::throw_if_not_ample (vals, k);
322-
323- dpctl::tensor::validation::AmpleMemory::throw_if_not_ample (inds, k);
324-
325- int src_typenum = src.get_typenum ();
326- int vals_typenum = vals.get_typenum ();
327- int inds_typenum = inds.get_typenum ();
328-
329- const auto &array_types = td_ns::usm_ndarray_types ();
330- int src_typeid = array_types.typenum_to_lookup_id (src_typenum);
331- int vals_typeid = array_types.typenum_to_lookup_id (vals_typenum);
332- int inds_typeid = array_types.typenum_to_lookup_id (inds_typenum);
333-
334- if (src_typeid != vals_typeid) {
335- throw py::value_error (" Input array and vals array must have "
336- " the same data type" );
337- }
338-
339- if (inds_typeid != static_cast <int >(td_ns::typenum_t ::INT64)) {
340- throw py::value_error (" Inds array must have data type int64" );
341- }
342-
343- bool is_src_c_contig = src.is_c_contiguous ();
344- bool is_vals_c_contig = vals.is_c_contiguous ();
345- bool is_inds_c_contig = inds.is_c_contiguous ();
346-
347- if (is_src_c_contig && is_vals_c_contig && is_inds_c_contig) {
348- static constexpr py::ssize_t zero_offset = py::ssize_t (0 );
349- static constexpr std::size_t iter_nelems = 1 ;
350-
351- auto fn = topk_dispatch_vector[src_typeid];
352-
353- sycl::event comp_ev =
354- fn (exec_q, iter_nelems, axis_nelems, k, largest, src.get_data (),
355- vals.get_data (), inds.get_data (), zero_offset, zero_offset,
356- zero_offset, zero_offset, zero_offset, zero_offset, depends);
357-
358- sycl::event keep_args_alive_ev =
359- dpctl::utils::keep_args_alive (exec_q, {src, vals, inds}, {comp_ev});
360-
361- return std::make_pair (keep_args_alive_ev, comp_ev);
362- }
363-
364- return std::make_pair (sycl::event (), sycl::event ());
365- }
366-
367289template <typename fnT, typename T> struct TopKFactory
368290{
369291 fnT get ()
@@ -385,26 +307,7 @@ void init_topk_functions(py::module_ m)
385307{
386308 dpctl::tensor::py_internal::init_topk_dispatch_vectors ();
387309
388- auto py_topk = [](const dpctl::tensor::usm_ndarray &src,
389- std::optional<const int > trailing_dims_to_search,
390- const std::size_t k, const bool largest,
391- const dpctl::tensor::usm_ndarray &vals,
392- const dpctl::tensor::usm_ndarray &inds,
393- sycl::queue &exec_q,
394- const std::vector<sycl::event> &depends)
395- -> std::pair<sycl::event, sycl::event> {
396- if (trailing_dims_to_search) {
397- return dpctl::tensor::py_internal::py_topk (
398- src, trailing_dims_to_search.value (), k, largest, vals, inds,
399- exec_q, depends);
400- }
401- else {
402- return dpctl::tensor::py_internal::py_topk (src, k, largest, vals,
403- inds, exec_q, depends);
404- }
405- };
406-
407- m.def (" _topk" , py_topk, py::arg (" src" ), py::arg (" trailing_dims_to_search" ),
310+ m.def (" _topk" , &py_topk, py::arg (" src" ), py::arg (" trailing_dims_to_search" ),
408311 py::arg (" k" ), py::arg (" largest" ), py::arg (" vals" ), py::arg (" inds" ),
409312 py::arg (" sycl_queue" ), py::arg (" depends" ) = py::list ());
410313}
0 commit comments