Skip to content

Commit 84ba81a

Browse files
Ensure that indices are also kept alive
1 parent 51e0fbb commit 84ba81a

File tree

2 files changed

+57
-24
lines changed

2 files changed

+57
-24
lines changed

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
#include <algorithm>
2828
#include <complex>
2929
#include <cstdint>
30-
#include <iostream>
3130
#include <pybind11/complex.h>
3231
#include <pybind11/pybind11.h>
3332
#include <pybind11/stl.h>
@@ -280,15 +279,50 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
280279
}
281280
}
282281

282+
/* Utility to parse python object py_ind into vector of `usm_ndarray`s */
283+
std::vector<dpctl::tensor::usm_ndarray> parse_py_ind(const sycl::queue &q,
284+
py::object py_ind)
285+
{
286+
size_t ind_count = py::len(py_ind);
287+
std::vector<dpctl::tensor::usm_ndarray> res;
288+
res.reserve(ind_count);
289+
290+
bool acquired = false;
291+
int nd = -1;
292+
for (size_t i = 0; i < ind_count; ++i) {
293+
auto el_i = py_ind[py::cast(i)];
294+
auto arr_i = py::cast<dpctl::tensor::usm_ndarray>(el_i);
295+
if (!dpctl::utils::queues_are_compatible(q, {arr_i})) {
296+
throw py::value_error("Index allocation queue is not compatible "
297+
"with execution queue");
298+
}
299+
if (acquired) {
300+
if (nd != arr_i.get_ndim()) {
301+
throw py::value_error(
302+
"Indices must have the same number of dimensions.");
303+
}
304+
}
305+
else {
306+
acquired = true;
307+
nd = arr_i.get_ndim();
308+
}
309+
res.push_back(arr_i);
310+
}
311+
312+
return res;
313+
}
314+
283315
std::pair<sycl::event, sycl::event>
284316
usm_ndarray_take(dpctl::tensor::usm_ndarray src,
285-
std::vector<dpctl::tensor::usm_ndarray> ind,
317+
py::object py_ind,
286318
dpctl::tensor::usm_ndarray dst,
287319
int axis_start,
288320
uint8_t mode,
289321
sycl::queue exec_q,
290322
const std::vector<sycl::event> &depends)
291323
{
324+
std::vector<dpctl::tensor::usm_ndarray> ind = parse_py_ind(exec_q, py_ind);
325+
292326
int k = ind.size();
293327

294328
if (k == 0) {
@@ -636,15 +670,12 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
636670
std::to_string(ind_type_id));
637671
}
638672

639-
std::cout << "Submitting take" << std::endl;
640673
sycl::event take_generic_ev =
641674
fn(exec_q, orthog_nelems, ind_nelems, orthog_nd, ind_nd, k,
642675
packed_shapes_strides, packed_axes_shapes_strides,
643676
packed_ind_shapes_strides, src_data, dst_data, packed_ind_ptrs,
644677
src_offset, dst_offset, packed_ind_offsets, all_deps);
645678

646-
std::cout << "Submitting take clean-up host task" << std::endl;
647-
648679
// free packed temporaries
649680
auto ctx = exec_q.get_context();
650681
exec_q.submit([&](sycl::handler &cgh) {
@@ -661,19 +692,20 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
661692
});
662693

663694
return std::make_pair(
664-
keep_args_alive(exec_q, {src, dst}, {take_generic_ev}),
695+
keep_args_alive(exec_q, {src, py_ind, dst}, {take_generic_ev}),
665696
take_generic_ev);
666697
}
667698

668699
std::pair<sycl::event, sycl::event>
669700
usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
670-
std::vector<dpctl::tensor::usm_ndarray> ind,
701+
py::object py_ind,
671702
dpctl::tensor::usm_ndarray val,
672703
int axis_start,
673704
uint8_t mode,
674705
sycl::queue exec_q,
675706
const std::vector<sycl::event> &depends)
676707
{
708+
std::vector<dpctl::tensor::usm_ndarray> ind = parse_py_ind(exec_q, py_ind);
677709
int k = ind.size();
678710

679711
if (k == 0) {
@@ -1046,8 +1078,9 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
10461078
});
10471079
});
10481080

1049-
return std::make_pair(keep_args_alive(exec_q, {dst, val}, {put_generic_ev}),
1050-
put_generic_ev);
1081+
return std::make_pair(
1082+
keep_args_alive(exec_q, {dst, py_ind, val}, {put_generic_ev}),
1083+
put_generic_ev);
10511084
}
10521085

10531086
void init_advanced_indexing_dispatch_tables(void)

dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,24 @@ namespace py_internal
3838
{
3939

4040
extern std::pair<sycl::event, sycl::event>
41-
usm_ndarray_take(dpctl::tensor::usm_ndarray src,
42-
std::vector<dpctl::tensor::usm_ndarray> ind,
43-
dpctl::tensor::usm_ndarray dst,
44-
int axis_start,
45-
uint8_t mode,
46-
sycl::queue exec_q,
47-
const std::vector<sycl::event> &depends = {});
41+
usm_ndarray_take(dpctl::tensor::usm_ndarray,
42+
py::object,
43+
dpctl::tensor::usm_ndarray,
44+
int,
45+
uint8_t,
46+
sycl::queue,
47+
const std::vector<sycl::event> & = {});
4848

4949
extern std::pair<sycl::event, sycl::event>
50-
usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
51-
std::vector<dpctl::tensor::usm_ndarray> ind,
52-
dpctl::tensor::usm_ndarray val,
53-
int axis_start,
54-
uint8_t mode,
55-
sycl::queue exec_q,
56-
const std::vector<sycl::event> &depends = {});
50+
usm_ndarray_put(dpctl::tensor::usm_ndarray,
51+
py::object,
52+
dpctl::tensor::usm_ndarray,
53+
int,
54+
uint8_t,
55+
sycl::queue,
56+
const std::vector<sycl::event> & = {});
5757

58-
extern void init_advanced_indexing_dispatch_tables();
58+
extern void init_advanced_indexing_dispatch_tables(void);
5959

6060
} // namespace py_internal
6161
} // namespace tensor

0 commit comments

Comments
 (0)