Skip to content

Commit 24d7839

Browse files
committed
Prevent dangling host tasks in indexing functions
- Host tasks are now collected and kept alive
1 parent 56bb65f commit 24d7839

File tree

1 file changed

+45
-28
lines changed

1 file changed

+45
-28
lines changed

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ using dpctl::utils::keep_args_alive;
7070

7171
std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
7272
sycl::queue exec_q,
73+
std::vector<sycl::event> &host_task_events,
7374
py::ssize_t *device_orthog_shapes_strides,
7475
py::ssize_t *device_axes_shapes_strides,
7576
const py::ssize_t *inp_shape,
@@ -210,20 +211,21 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
210211
exec_q.copy<py::ssize_t>(packed_host_shapes_strides_shp->data(),
211212
device_orthog_shapes_strides,
212213
packed_host_shapes_strides_shp->size());
213-
exec_q.submit([&](sycl::handler &cgh) {
214-
cgh.depends_on(device_orthog_shapes_strides_copy_ev);
215-
cgh.host_task([packed_host_shapes_strides_shp] {});
216-
});
217214

218215
sycl::event device_axes_shapes_strides_copy_ev =
219216
exec_q.copy<py::ssize_t>(
220217
packed_host_axes_shapes_strides_shp->data(),
221218
device_axes_shapes_strides,
222219
packed_host_axes_shapes_strides_shp->size());
223-
exec_q.submit([&](sycl::handler &cgh) {
224-
cgh.depends_on(device_axes_shapes_strides_copy_ev);
225-
cgh.host_task([packed_host_axes_shapes_strides_shp]() {});
226-
});
220+
221+
sycl::event clean_up_host_task_ev =
222+
exec_q.submit([&](sycl::handler &cgh) {
223+
cgh.depends_on(device_axes_shapes_strides_copy_ev);
224+
cgh.depends_on(device_orthog_shapes_strides_copy_ev);
225+
cgh.host_task([packed_host_axes_shapes_strides_shp,
226+
packed_host_shapes_strides_shp]() {});
227+
});
228+
host_task_events.push_back(clean_up_host_task_ev);
227229

228230
std::vector<sycl::event> v = {device_orthog_shapes_strides_copy_ev,
229231
device_axes_shapes_strides_copy_ev};
@@ -268,10 +270,13 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
268270
packed_host_axes_shapes_strides_shp->data(),
269271
device_axes_shapes_strides,
270272
packed_host_axes_shapes_strides_shp->size());
271-
exec_q.submit([&](sycl::handler &cgh) {
272-
cgh.depends_on(device_axes_shapes_strides_copy_ev);
273-
cgh.host_task([packed_host_axes_shapes_strides_shp]() {});
274-
});
273+
274+
sycl::event clean_up_host_task_ev =
275+
exec_q.submit([&](sycl::handler &cgh) {
276+
cgh.depends_on(device_axes_shapes_strides_copy_ev);
277+
cgh.host_task([packed_host_axes_shapes_strides_shp]() {});
278+
});
279+
host_task_events.push_back(clean_up_host_task_ev);
275280

276281
std::vector<sycl::event> v = {device_orthog_shapes_strides_fill_ev,
277282
device_axes_shapes_strides_copy_ev};
@@ -590,28 +595,33 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
590595
std::copy(ind_offsets.begin(), ind_offsets.end(),
591596
host_ind_offsets_shp->begin());
592597

598+
std::vector<sycl::event> host_task_events(5);
599+
593600
sycl::event packed_ind_ptrs_copy_ev = exec_q.copy<char *>(
594601
host_ind_ptrs_shp->data(), packed_ind_ptrs, host_ind_ptrs_shp->size());
595-
exec_q.submit([&](sycl::handler &cgh) {
602+
sycl::event ind_ptrs_host_task = exec_q.submit([&](sycl::handler &cgh) {
596603
cgh.depends_on(packed_ind_ptrs_copy_ev);
597604
cgh.host_task([host_ind_ptrs_shp]() {});
598605
});
606+
host_task_events.push_back(ind_ptrs_host_task);
599607

600608
sycl::event packed_ind_shapes_strides_copy_ev = exec_q.copy<py::ssize_t>(
601609
host_ind_shapes_strides_shp->data(), packed_ind_shapes_strides,
602610
host_ind_shapes_strides_shp->size());
603-
exec_q.submit([&](sycl::handler &cgh) {
611+
sycl::event ind_sh_st_host_task = exec_q.submit([&](sycl::handler &cgh) {
604612
cgh.depends_on(packed_ind_shapes_strides_copy_ev);
605613
cgh.host_task([host_ind_shapes_strides_shp]() {});
606614
});
615+
host_task_events.push_back(ind_sh_st_host_task);
607616

608617
sycl::event packed_ind_offsets_copy_ev = exec_q.copy<py::ssize_t>(
609618
host_ind_offsets_shp->data(), packed_ind_offsets,
610619
host_ind_offsets_shp->size());
611-
exec_q.submit([&](sycl::handler &cgh) {
620+
sycl::event ind_offsets_host_task = exec_q.submit([&](sycl::handler &cgh) {
612621
cgh.depends_on(packed_ind_offsets_copy_ev);
613622
cgh.host_task([host_ind_offsets_shp]() {});
614623
});
624+
host_task_events.push_back(ind_offsets_host_task);
615625

616626
std::vector<sycl::event> ind_pack_depends{packed_ind_ptrs_copy_ev,
617627
packed_ind_shapes_strides_copy_ev,
@@ -650,10 +660,10 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
650660

651661
std::vector<sycl::event> src_dst_pack_deps =
652662
_populate_packed_shapes_strides_for_indexing(
653-
exec_q, packed_shapes_strides, packed_axes_shapes_strides,
654-
src_shape, src_strides, is_src_c_contig, is_src_f_contig, dst_shape,
655-
dst_strides, is_dst_c_contig, is_dst_f_contig, axis_start, k,
656-
ind_nd, src_nd, dst_nd);
663+
exec_q, host_task_events, packed_shapes_strides,
664+
packed_axes_shapes_strides, src_shape, src_strides, is_src_c_contig,
665+
is_src_f_contig, dst_shape, dst_strides, is_dst_c_contig,
666+
is_dst_f_contig, axis_start, k, ind_nd, src_nd, dst_nd);
657667

658668
std::vector<sycl::event> all_deps(depends.size() + ind_pack_depends.size() +
659669
src_dst_pack_deps.size());
@@ -690,9 +700,10 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
690700
sycl::free(packed_ind_offsets, ctx);
691701
});
692702
});
703+
host_task_events.push_back(take_generic_ev);
693704

694705
sycl::event host_task_ev =
695-
keep_args_alive(exec_q, {src, py_ind, dst}, {take_generic_ev});
706+
keep_args_alive(exec_q, {src, py_ind, dst}, host_task_events);
696707

697708
return std::make_pair(host_task_ev, take_generic_ev);
698709
}
@@ -977,28 +988,33 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
977988
std::copy(ind_offsets.begin(), ind_offsets.end(),
978989
host_ind_offsets_shp->begin());
979990

991+
std::vector<sycl::event> host_task_events(5);
992+
980993
sycl::event device_ind_ptrs_copy_ev = exec_q.copy<char *>(
981994
host_ind_ptrs_shp->data(), packed_ind_ptrs, host_ind_ptrs_shp->size());
982-
exec_q.submit([&](sycl::handler &cgh) {
995+
sycl::event ind_ptrs_host_task = exec_q.submit([&](sycl::handler &cgh) {
983996
cgh.depends_on(device_ind_ptrs_copy_ev);
984997
cgh.host_task([host_ind_ptrs_shp]() {});
985998
});
999+
host_task_events.push_back(ind_ptrs_host_task);
9861000

9871001
sycl::event device_ind_shapes_strides_copy_ev = exec_q.copy<py::ssize_t>(
9881002
host_ind_shapes_strides_shp->data(), packed_ind_shapes_strides,
9891003
host_ind_shapes_strides_shp->size());
990-
exec_q.submit([&](sycl::handler &cgh) {
1004+
sycl::event ind_sh_st_host_task = exec_q.submit([&](sycl::handler &cgh) {
9911005
cgh.depends_on(device_ind_shapes_strides_copy_ev);
9921006
cgh.host_task([host_ind_shapes_strides_shp]() {});
9931007
});
1008+
host_task_events.push_back(ind_sh_st_host_task);
9941009

9951010
sycl::event device_ind_offsets_copy_ev = exec_q.copy<py::ssize_t>(
9961011
host_ind_offsets_shp->data(), packed_ind_offsets,
9971012
host_ind_offsets_shp->size());
998-
exec_q.submit([&](sycl::handler &cgh) {
1013+
sycl::event ind_offsets_host_task = exec_q.submit([&](sycl::handler &cgh) {
9991014
cgh.depends_on(device_ind_offsets_copy_ev);
10001015
cgh.host_task([host_ind_offsets_shp]() {});
10011016
});
1017+
host_task_events.push_back(ind_offsets_host_task);
10021018

10031019
std::vector<sycl::event> ind_pack_depends{device_ind_ptrs_copy_ev,
10041020
device_ind_shapes_strides_copy_ev,
@@ -1037,10 +1053,10 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
10371053

10381054
std::vector<sycl::event> copy_shapes_strides_deps =
10391055
_populate_packed_shapes_strides_for_indexing(
1040-
exec_q, packed_shapes_strides, packed_axes_shapes_strides,
1041-
dst_shape, dst_strides, is_dst_c_contig, is_dst_f_contig, val_shape,
1042-
val_strides, is_val_c_contig, is_val_f_contig, axis_start, k,
1043-
ind_nd, dst_nd, val_nd);
1056+
exec_q, host_task_events, packed_shapes_strides,
1057+
packed_axes_shapes_strides, dst_shape, dst_strides, is_dst_c_contig,
1058+
is_dst_f_contig, val_shape, val_strides, is_val_c_contig,
1059+
is_val_f_contig, axis_start, k, ind_nd, dst_nd, val_nd);
10441060

10451061
std::vector<sycl::event> all_deps(depends.size() +
10461062
copy_shapes_strides_deps.size() +
@@ -1078,9 +1094,10 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
10781094
sycl::free(packed_ind_offsets, ctx);
10791095
});
10801096
});
1097+
host_task_events.push_back(put_generic_ev);
10811098

10821099
return std::make_pair(
1083-
keep_args_alive(exec_q, {dst, py_ind, val}, {put_generic_ev}),
1100+
keep_args_alive(exec_q, {dst, py_ind, val}, host_task_events),
10841101
put_generic_ev);
10851102
}
10861103

0 commit comments

Comments
 (0)