27
27
#include < algorithm>
28
28
#include < complex>
29
29
#include < cstdint>
30
- #include < iostream>
31
30
#include < pybind11/complex.h>
32
31
#include < pybind11/pybind11.h>
33
32
#include < pybind11/stl.h>
@@ -280,15 +279,50 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
280
279
}
281
280
}
282
281
282
+ /* Utility to parse python object py_ind into vector of `usm_ndarray`s */
283
+ std::vector<dpctl::tensor::usm_ndarray> parse_py_ind (const sycl::queue &q,
284
+ py::object py_ind)
285
+ {
286
+ size_t ind_count = py::len (py_ind);
287
+ std::vector<dpctl::tensor::usm_ndarray> res;
288
+ res.reserve (ind_count);
289
+
290
+ bool acquired = false ;
291
+ int nd = -1 ;
292
+ for (size_t i = 0 ; i < ind_count; ++i) {
293
+ auto el_i = py_ind[py::cast (i)];
294
+ auto arr_i = py::cast<dpctl::tensor::usm_ndarray>(el_i);
295
+ if (!dpctl::utils::queues_are_compatible (q, {arr_i})) {
296
+ throw py::value_error (" Index allocation queue is not compatible "
297
+ " with execution queue" );
298
+ }
299
+ if (acquired) {
300
+ if (nd != arr_i.get_ndim ()) {
301
+ throw py::value_error (
302
+ " Indices must have the same number of dimensions." );
303
+ }
304
+ }
305
+ else {
306
+ acquired = true ;
307
+ nd = arr_i.get_ndim ();
308
+ }
309
+ res.push_back (arr_i);
310
+ }
311
+
312
+ return res;
313
+ }
314
+
283
315
std::pair<sycl::event, sycl::event>
284
316
usm_ndarray_take (dpctl::tensor::usm_ndarray src,
285
- std::vector<dpctl::tensor::usm_ndarray> ind ,
317
+ py::object py_ind ,
286
318
dpctl::tensor::usm_ndarray dst,
287
319
int axis_start,
288
320
uint8_t mode,
289
321
sycl::queue exec_q,
290
322
const std::vector<sycl::event> &depends)
291
323
{
324
+ std::vector<dpctl::tensor::usm_ndarray> ind = parse_py_ind (exec_q, py_ind);
325
+
292
326
int k = ind.size ();
293
327
294
328
if (k == 0 ) {
@@ -636,15 +670,12 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
636
670
std::to_string (ind_type_id));
637
671
}
638
672
639
- std::cout << " Submitting take" << std::endl;
640
673
sycl::event take_generic_ev =
641
674
fn (exec_q, orthog_nelems, ind_nelems, orthog_nd, ind_nd, k,
642
675
packed_shapes_strides, packed_axes_shapes_strides,
643
676
packed_ind_shapes_strides, src_data, dst_data, packed_ind_ptrs,
644
677
src_offset, dst_offset, packed_ind_offsets, all_deps);
645
678
646
- std::cout << " Submitting take clean-up host task" << std::endl;
647
-
648
679
// free packed temporaries
649
680
auto ctx = exec_q.get_context ();
650
681
exec_q.submit ([&](sycl::handler &cgh) {
@@ -661,19 +692,20 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
661
692
});
662
693
663
694
return std::make_pair (
664
- keep_args_alive (exec_q, {src, dst}, {take_generic_ev}),
695
+ keep_args_alive (exec_q, {src, py_ind, dst}, {take_generic_ev}),
665
696
take_generic_ev);
666
697
}
667
698
668
699
std::pair<sycl::event, sycl::event>
669
700
usm_ndarray_put (dpctl::tensor::usm_ndarray dst,
670
- std::vector<dpctl::tensor::usm_ndarray> ind ,
701
+ py::object py_ind ,
671
702
dpctl::tensor::usm_ndarray val,
672
703
int axis_start,
673
704
uint8_t mode,
674
705
sycl::queue exec_q,
675
706
const std::vector<sycl::event> &depends)
676
707
{
708
+ std::vector<dpctl::tensor::usm_ndarray> ind = parse_py_ind (exec_q, py_ind);
677
709
int k = ind.size ();
678
710
679
711
if (k == 0 ) {
@@ -1046,8 +1078,9 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
1046
1078
});
1047
1079
});
1048
1080
1049
- return std::make_pair (keep_args_alive (exec_q, {dst, val}, {put_generic_ev}),
1050
- put_generic_ev);
1081
+ return std::make_pair (
1082
+ keep_args_alive (exec_q, {dst, py_ind, val}, {put_generic_ev}),
1083
+ put_generic_ev);
1051
1084
}
1052
1085
1053
1086
void init_advanced_indexing_dispatch_tables (void )
0 commit comments