@@ -94,24 +94,21 @@ struct use_radix_sort<
94
94
};
95
95
96
96
template <typename argTy, typename IndexTy>
97
- sycl::event
98
- topk_caller (sycl::queue &exec_q,
99
- std::size_t iter_nelems, // number of sub-arrays to sort (num. of
100
- // rows in a matrix when sorting over rows)
101
- std::size_t axis_nelems, // size of each array to sort (length of
102
- // rows, i.e. number of columns)
103
- std::size_t k,
104
- bool largest,
105
- const char *arg_cp,
106
- char *vals_cp,
107
- char *inds_cp,
108
- py::ssize_t iter_arg_offset,
109
- py::ssize_t iter_vals_offset,
110
- py::ssize_t iter_inds_offset,
111
- py::ssize_t axis_arg_offset,
112
- py::ssize_t axis_vals_offset,
113
- py::ssize_t axis_inds_offset,
114
- const std::vector<sycl::event> &depends)
97
+ sycl::event topk_caller (sycl::queue &exec_q,
98
+ std::size_t iter_nelems, // number of sub-arrays
99
+ std::size_t axis_nelems, // size of each sub-array
100
+ std::size_t k,
101
+ bool largest,
102
+ const char *arg_cp,
103
+ char *vals_cp,
104
+ char *inds_cp,
105
+ py::ssize_t iter_arg_offset,
106
+ py::ssize_t iter_vals_offset,
107
+ py::ssize_t iter_inds_offset,
108
+ py::ssize_t axis_arg_offset,
109
+ py::ssize_t axis_vals_offset,
110
+ py::ssize_t axis_inds_offset,
111
+ const std::vector<sycl::event> &depends)
115
112
{
116
113
if constexpr (use_radix_sort<argTy>::value) {
117
114
using dpctl::tensor::kernels::topk_radix_impl;
@@ -147,7 +144,7 @@ topk_caller(sycl::queue &exec_q,
147
144
148
145
std::pair<sycl::event, sycl::event>
149
146
py_topk (const dpctl::tensor::usm_ndarray &src,
150
- const int trailing_dims_to_search,
147
+ std::optional< const int > trailing_dims_to_search,
151
148
const std::size_t k,
152
149
const bool largest,
153
150
const dpctl::tensor::usm_ndarray &vals,
@@ -158,48 +155,70 @@ py_topk(const dpctl::tensor::usm_ndarray &src,
158
155
int src_nd = src.get_ndim ();
159
156
int vals_nd = vals.get_ndim ();
160
157
int inds_nd = inds.get_ndim ();
161
- if (src_nd != vals_nd || src_nd != inds_nd) {
162
- throw py::value_error (" The input and output arrays must have "
163
- " the same array ranks" );
164
- }
165
- int iteration_nd = src_nd - trailing_dims_to_search;
166
- if (trailing_dims_to_search <= 0 || iteration_nd < 0 ) {
167
- throw py::value_error (
168
- " trailing_dims_to_search must be positive, but no "
169
- " greater than rank of the array being searched" );
170
- }
171
158
172
159
const py::ssize_t *src_shape_ptr = src.get_shape_raw ();
173
160
const py::ssize_t *vals_shape_ptr = vals.get_shape_raw ();
174
161
const py::ssize_t *inds_shape_ptr = inds.get_shape_raw ();
175
162
176
- bool same_shapes = true ;
163
+ std:: size_t axis_nelems ( 1 ) ;
177
164
std::size_t iter_nelems (1 );
178
- for (int i = 0 ; same_shapes && (i < iteration_nd); ++i) {
179
- auto src_shape_i = src_shape_ptr[i];
180
- same_shapes = same_shapes && (src_shape_i == vals_shape_ptr[i] &&
181
- src_shape_i == inds_shape_ptr[i]);
182
- iter_nelems *= static_cast <std::size_t >(src_shape_i);
183
- }
165
+ if (trailing_dims_to_search.has_value ()) {
166
+ if (src_nd != vals_nd || src_nd != inds_nd) {
167
+ throw py::value_error (" The input and output arrays must have "
168
+ " the same array ranks" );
169
+ }
184
170
185
- if (!same_shapes) {
186
- throw py::value_error (
187
- " Destination shape does not match the input shape" );
188
- }
171
+ auto trailing_dims = trailing_dims_to_search.value ();
172
+ int iter_nd = src_nd - trailing_dims;
173
+ if (trailing_dims <= 0 || iter_nd < 0 ) {
174
+ throw py::value_error (
175
+ " trailing_dims_to_search must be positive, but no "
176
+ " greater than rank of the array being searched" );
177
+ }
189
178
190
- std::size_t vals_k (1 );
191
- std::size_t inds_k (1 );
192
- std::size_t axis_nelems (1 );
193
- for (int i = iteration_nd; i < src_nd; ++i) {
194
- axis_nelems *= static_cast <std::size_t >(src_shape_ptr[i]);
195
- vals_k *= static_cast <std::size_t >(vals_shape_ptr[i]);
196
- inds_k *= static_cast <std::size_t >(inds_shape_ptr[i]);
179
+ bool same_shapes = true ;
180
+ for (int i = 0 ; same_shapes && (i < iter_nd); ++i) {
181
+ auto src_shape_i = src_shape_ptr[i];
182
+ same_shapes = same_shapes && (src_shape_i == vals_shape_ptr[i] &&
183
+ src_shape_i == inds_shape_ptr[i]);
184
+ iter_nelems *= static_cast <std::size_t >(src_shape_i);
185
+ }
186
+
187
+ if (!same_shapes) {
188
+ throw py::value_error (
189
+ " Destination shape does not match the input shape" );
190
+ }
191
+
192
+ std::size_t vals_k (1 );
193
+ std::size_t inds_k (1 );
194
+ for (int i = iter_nd; i < src_nd; ++i) {
195
+ axis_nelems *= static_cast <std::size_t >(src_shape_ptr[i]);
196
+ vals_k *= static_cast <std::size_t >(vals_shape_ptr[i]);
197
+ inds_k *= static_cast <std::size_t >(inds_shape_ptr[i]);
198
+ }
199
+
200
+ bool valid_k = (vals_k == k && inds_k == k && axis_nelems >= k);
201
+ if (!valid_k) {
202
+ throw py::value_error (" The value of k is invalid for the input and "
203
+ " destination arrays" );
204
+ }
197
205
}
206
+ else {
207
+ if (vals_nd != 1 || inds_nd != 1 ) {
208
+ throw py::value_error (" Output arrays must be one-dimensional" );
209
+ }
198
210
199
- bool valid_k = (vals_k == k && inds_k == k && axis_nelems >= k);
200
- if (!valid_k) {
201
- throw py::value_error (
202
- " The value of k is invalid for the input and destination arrays" );
211
+ for (int i = 0 ; i < src_nd; ++i) {
212
+ axis_nelems *= static_cast <std::size_t >(src_shape_ptr[i]);
213
+ }
214
+
215
+ bool valid_k = (axis_nelems >= k &&
216
+ static_cast <std::size_t >(vals_shape_ptr[0 ]) == k &&
217
+ static_cast <std::size_t >(inds_shape_ptr[0 ]) == k);
218
+ if (!valid_k) {
219
+ throw py::value_error (" The value of k is invalid for the input and "
220
+ " destination arrays" );
221
+ }
203
222
}
204
223
205
224
if (!dpctl::utils::queues_are_compatible (exec_q, {src, vals, inds})) {
@@ -267,103 +286,6 @@ py_topk(const dpctl::tensor::usm_ndarray &src,
267
286
return std::make_pair (sycl::event (), sycl::event ());
268
287
}
269
288
270
- std::pair<sycl::event, sycl::event>
271
- py_topk (const dpctl::tensor::usm_ndarray &src,
272
- const std::size_t k,
273
- const bool largest,
274
- const dpctl::tensor::usm_ndarray &vals,
275
- const dpctl::tensor::usm_ndarray &inds,
276
- sycl::queue &exec_q,
277
- const std::vector<sycl::event> &depends)
278
- {
279
- int src_nd = src.get_ndim ();
280
- int vals_nd = vals.get_ndim ();
281
- int inds_nd = inds.get_ndim ();
282
- if (vals_nd != 1 || inds_nd != 1 ) {
283
- throw py::value_error (" Output arrays must be one-dimensional" );
284
- }
285
-
286
- const py::ssize_t *src_shape_ptr = src.get_shape_raw ();
287
- const py::ssize_t *vals_shape_ptr = vals.get_shape_raw ();
288
- const py::ssize_t *inds_shape_ptr = inds.get_shape_raw ();
289
-
290
- std::size_t axis_nelems (1 );
291
- for (int i = 0 ; i < src_nd; ++i) {
292
- axis_nelems *= static_cast <std::size_t >(src_shape_ptr[i]);
293
- }
294
-
295
- bool valid_k =
296
- (axis_nelems >= k && static_cast <std::size_t >(vals_shape_ptr[0 ]) == k &&
297
- static_cast <std::size_t >(inds_shape_ptr[0 ]) == k);
298
- if (!valid_k) {
299
- throw py::value_error (
300
- " The value of k is invalid for the input and destination arrays" );
301
- }
302
-
303
- if (!dpctl::utils::queues_are_compatible (exec_q, {src, vals, inds})) {
304
- throw py::value_error (
305
- " Execution queue is not compatible with allocation queues" );
306
- }
307
-
308
- dpctl::tensor::validation::CheckWritable::throw_if_not_writable (vals);
309
- dpctl::tensor::validation::CheckWritable::throw_if_not_writable (inds);
310
-
311
- if (axis_nelems == 0 ) {
312
- // Nothing to do
313
- return std::make_pair (sycl::event (), sycl::event ());
314
- }
315
-
316
- auto const &overlap = dpctl::tensor::overlap::MemoryOverlap ();
317
- if (overlap (src, vals) || overlap (src, inds)) {
318
- throw py::value_error (" Arrays index overlapping segments of memory" );
319
- }
320
-
321
- dpctl::tensor::validation::AmpleMemory::throw_if_not_ample (vals, k);
322
-
323
- dpctl::tensor::validation::AmpleMemory::throw_if_not_ample (inds, k);
324
-
325
- int src_typenum = src.get_typenum ();
326
- int vals_typenum = vals.get_typenum ();
327
- int inds_typenum = inds.get_typenum ();
328
-
329
- const auto &array_types = td_ns::usm_ndarray_types ();
330
- int src_typeid = array_types.typenum_to_lookup_id (src_typenum);
331
- int vals_typeid = array_types.typenum_to_lookup_id (vals_typenum);
332
- int inds_typeid = array_types.typenum_to_lookup_id (inds_typenum);
333
-
334
- if (src_typeid != vals_typeid) {
335
- throw py::value_error (" Input array and vals array must have "
336
- " the same data type" );
337
- }
338
-
339
- if (inds_typeid != static_cast <int >(td_ns::typenum_t ::INT64)) {
340
- throw py::value_error (" Inds array must have data type int64" );
341
- }
342
-
343
- bool is_src_c_contig = src.is_c_contiguous ();
344
- bool is_vals_c_contig = vals.is_c_contiguous ();
345
- bool is_inds_c_contig = inds.is_c_contiguous ();
346
-
347
- if (is_src_c_contig && is_vals_c_contig && is_inds_c_contig) {
348
- static constexpr py::ssize_t zero_offset = py::ssize_t (0 );
349
- static constexpr std::size_t iter_nelems = 1 ;
350
-
351
- auto fn = topk_dispatch_vector[src_typeid];
352
-
353
- sycl::event comp_ev =
354
- fn (exec_q, iter_nelems, axis_nelems, k, largest, src.get_data (),
355
- vals.get_data (), inds.get_data (), zero_offset, zero_offset,
356
- zero_offset, zero_offset, zero_offset, zero_offset, depends);
357
-
358
- sycl::event keep_args_alive_ev =
359
- dpctl::utils::keep_args_alive (exec_q, {src, vals, inds}, {comp_ev});
360
-
361
- return std::make_pair (keep_args_alive_ev, comp_ev);
362
- }
363
-
364
- return std::make_pair (sycl::event (), sycl::event ());
365
- }
366
-
367
289
template <typename fnT, typename T> struct TopKFactory
368
290
{
369
291
fnT get ()
@@ -385,26 +307,7 @@ void init_topk_functions(py::module_ m)
385
307
{
386
308
dpctl::tensor::py_internal::init_topk_dispatch_vectors ();
387
309
388
- auto py_topk = [](const dpctl::tensor::usm_ndarray &src,
389
- std::optional<const int > trailing_dims_to_search,
390
- const std::size_t k, const bool largest,
391
- const dpctl::tensor::usm_ndarray &vals,
392
- const dpctl::tensor::usm_ndarray &inds,
393
- sycl::queue &exec_q,
394
- const std::vector<sycl::event> &depends)
395
- -> std::pair<sycl::event, sycl::event> {
396
- if (trailing_dims_to_search) {
397
- return dpctl::tensor::py_internal::py_topk (
398
- src, trailing_dims_to_search.value (), k, largest, vals, inds,
399
- exec_q, depends);
400
- }
401
- else {
402
- return dpctl::tensor::py_internal::py_topk (src, k, largest, vals,
403
- inds, exec_q, depends);
404
- }
405
- };
406
-
407
- m.def (" _topk" , py_topk, py::arg (" src" ), py::arg (" trailing_dims_to_search" ),
310
+ m.def (" _topk" , &py_topk, py::arg (" src" ), py::arg (" trailing_dims_to_search" ),
408
311
py::arg (" k" ), py::arg (" largest" ), py::arg (" vals" ), py::arg (" inds" ),
409
312
py::arg (" sycl_queue" ), py::arg (" depends" ) = py::list ());
410
313
}
0 commit comments