
Commit dad41e3

Break up _populate_choose_kernel_params
py_choose now uses multiple helper functions to move host data for the choose kernel parameters to the device.
1 parent 4e6c864 commit dad41e3
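
For context, the staging pattern the new helpers factor out looks roughly like the sketch below. It is a minimal, self-contained illustration using plain SYCL 2020 USM APIs rather than dpctl's usm_host_allocator; copy_params and the other names here are illustrative, not from the commit.

// Minimal sketch of the host-to-device staging pattern, assuming plain
// SYCL 2020 USM APIs (the commit itself goes through
// dpctl::tensor::alloc_utils::usm_host_allocator); names are illustrative.
#include <sycl/sycl.hpp>
#include <cstdint>
#include <memory>
#include <vector>

using host_alloc_t = sycl::usm_allocator<std::int64_t, sycl::usm::alloc::host>;
using host_vec_t = std::vector<std::int64_t, host_alloc_t>;

// Stage values in host USM, enqueue an async copy into device memory,
// and keep the host buffer alive via a host_task that captures it.
sycl::event copy_params(sycl::queue &q,
                        const std::vector<std::int64_t> &vals,
                        std::int64_t *dev_ptr,
                        std::vector<sycl::event> &cleanup_events)
{
    auto host_shp = std::make_shared<host_vec_t>(vals.begin(), vals.end(),
                                                 host_alloc_t(q));
    sycl::event copy_ev = q.copy(host_shp->data(), dev_ptr, host_shp->size());

    // the host_task exists only to extend host_shp's lifetime past the copy
    cleanup_events.push_back(q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(copy_ev);
        cgh.host_task([host_shp]() {});
    }));
    return copy_ev;
}

The diff below applies this idea three times (choice pointers, choice offsets, and the packed shape/strides), which is why the packing is split into make_host_ptrs, make_host_offsets, and make_host_shape_strides plus two generic helpers.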

1 file changed: +114 −71 lines
dpnp/backend/extensions/indexing/choose.cpp

Lines changed: 114 additions & 71 deletions
@@ -54,77 +54,112 @@ static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
 
 namespace py = pybind11;
 
-std::vector<sycl::event>
-_populate_choose_kernel_params(sycl::queue &exec_q,
-                               std::vector<sycl::event> &host_task_events,
-                               char **device_chc_ptrs,
-                               py::ssize_t *device_shape_strides,
-                               py::ssize_t *device_chc_offsets,
-                               const py::ssize_t *shape,
-                               int shape_len,
-                               std::vector<py::ssize_t> &inp_strides,
-                               std::vector<py::ssize_t> &dst_strides,
-                               std::vector<py::ssize_t> &chc_strides,
-                               std::vector<char *> &chc_ptrs,
-                               std::vector<py::ssize_t> &chc_offsets,
-                               py::ssize_t n_chcs)
+namespace detail
 {
-    using ptr_host_allocator_T =
-        dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
-    using ptrT = std::vector<char *, ptr_host_allocator_T>;
 
-    ptr_host_allocator_T ptr_allocator(exec_q);
-    std::shared_ptr<ptrT> host_chc_ptrs_shp =
-        std::make_shared<ptrT>(n_chcs, ptr_allocator);
+using host_ptrs_allocator_t =
+    dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
+using ptrs_t = std::vector<char *, host_ptrs_allocator_t>;
+using host_ptrs_shp_t = std::shared_ptr<ptrs_t>;
 
-    using usm_host_allocatorT =
-        dpctl::tensor::alloc_utils::usm_host_allocator<py::ssize_t>;
-    using shT = std::vector<py::ssize_t, usm_host_allocatorT>;
+host_ptrs_shp_t make_host_ptrs(sycl::queue &exec_q,
+                               const std::vector<char *> &ptrs)
+{
+    host_ptrs_allocator_t ptrs_allocator(exec_q);
+    host_ptrs_shp_t host_ptrs_shp =
+        std::make_shared<ptrs_t>(ptrs.size(), ptrs_allocator);
+
+    std::copy(ptrs.begin(), ptrs.end(), host_ptrs_shp->begin());
+
+    return host_ptrs_shp;
+}
+
+using host_sz_allocator_t =
+    dpctl::tensor::alloc_utils::usm_host_allocator<py::ssize_t>;
+using sz_t = std::vector<py::ssize_t, host_sz_allocator_t>;
+using host_sz_shp_t = std::shared_ptr<sz_t>;
 
-    usm_host_allocatorT sz_allocator(exec_q);
-    std::shared_ptr<shT> host_shape_strides_shp =
-        std::make_shared<shT>(shape_len * (3 + n_chcs), sz_allocator);
+host_sz_shp_t make_host_offsets(sycl::queue &exec_q,
+                                const std::vector<py::ssize_t> &offsets)
+{
+    host_sz_allocator_t offsets_allocator(exec_q);
+    host_sz_shp_t host_offsets_shp =
+        std::make_shared<sz_t>(offsets.size(), offsets_allocator);
+
+    std::copy(offsets.begin(), offsets.end(), host_offsets_shp->begin());
 
-    std::shared_ptr<shT> host_chc_offsets_shp =
-        std::make_shared<shT>(n_chcs, sz_allocator);
+    return host_offsets_shp;
+}
+
+host_sz_shp_t make_host_shape_strides(sycl::queue &exec_q,
+                                      py::ssize_t n_chcs,
+                                      std::vector<py::ssize_t> &shape,
+                                      std::vector<py::ssize_t> &inp_strides,
+                                      std::vector<py::ssize_t> &dst_strides,
+                                      std::vector<py::ssize_t> &chc_strides)
+{
+    auto nelems = shape.size();
+    host_sz_allocator_t shape_strides_allocator(exec_q);
+    host_sz_shp_t host_shape_strides_shp =
+        std::make_shared<sz_t>(nelems * (3 + n_chcs), shape_strides_allocator);
 
-    std::copy(shape, shape + shape_len, host_shape_strides_shp->begin());
+    std::copy(shape.begin(), shape.end(), host_shape_strides_shp->begin());
     std::copy(inp_strides.begin(), inp_strides.end(),
-              host_shape_strides_shp->begin() + shape_len);
+              host_shape_strides_shp->begin() + nelems);
     std::copy(dst_strides.begin(), dst_strides.end(),
-              host_shape_strides_shp->begin() + 2 * shape_len);
+              host_shape_strides_shp->begin() + 2 * nelems);
     std::copy(chc_strides.begin(), chc_strides.end(),
-              host_shape_strides_shp->begin() + 3 * shape_len);
+              host_shape_strides_shp->begin() + 3 * nelems);
 
-    std::copy(chc_ptrs.begin(), chc_ptrs.end(), host_chc_ptrs_shp->begin());
-    std::copy(chc_offsets.begin(), chc_offsets.end(),
-              host_chc_offsets_shp->begin());
+    return host_shape_strides_shp;
+}
 
-    const sycl::event &device_chc_ptrs_copy_ev = exec_q.copy<char *>(
-        host_chc_ptrs_shp->data(), device_chc_ptrs, host_chc_ptrs_shp->size());
+/* This function expects a queue and a non-trivial number of
+   std::pairs of raw device pointers and host shared pointers
+   (structured as <device_ptr, shared_ptr>),
+   then enqueues a copy of the host shared pointer data into
+   the device pointer.
+
+   Assumes the device pointer addresses sufficient memory for
+   the size of the host memory.
+*/
+template <typename... DevHostPairs>
+std::vector<sycl::event> batched_copy(sycl::queue &exec_q,
+                                      DevHostPairs &&...dev_host_pairs)
+{
+    constexpr std::size_t n = sizeof...(DevHostPairs);
+    static_assert(n > 0, "batched_copy requires at least one argument");
 
-    const sycl::event &device_shape_strides_copy_ev = exec_q.copy<py::ssize_t>(
-        host_shape_strides_shp->data(), device_shape_strides,
-        host_shape_strides_shp->size());
+    std::vector<sycl::event> copy_evs;
+    copy_evs.reserve(n);
+    (copy_evs.emplace_back(exec_q.copy(dev_host_pairs.second->data(),
+                                       dev_host_pairs.first,
+                                       dev_host_pairs.second->size())),
+     ...);
 
-    const sycl::event &device_chc_offsets_copy_ev = exec_q.copy<py::ssize_t>(
-        host_chc_offsets_shp->data(), device_chc_offsets,
-        host_chc_offsets_shp->size());
+    return copy_evs;
+}
+
+/* This function takes as input a queue, sycl::event dependencies,
+   and a non-trivial number of shared_ptrs and moves them into
+   a host_task lambda capture, ensuring their lifetime until the
+   host_task executes.
+*/
+template <typename... Shps>
+sycl::event async_shp_free(sycl::queue &exec_q,
+                           const std::vector<sycl::event> &depends,
+                           Shps &&...shps)
+{
+    constexpr std::size_t n = sizeof...(Shps);
+    static_assert(n > 0, "async_shp_free requires at least one argument");
 
     const sycl::event &shared_ptr_cleanup_ev =
         exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on({device_chc_offsets_copy_ev,
-                            device_shape_strides_copy_ev,
-                            device_chc_ptrs_copy_ev});
-            cgh.host_task([host_chc_offsets_shp, host_shape_strides_shp,
-                           host_chc_ptrs_shp]() {});
+            cgh.depends_on(depends);
+            cgh.host_task([capture = std::tuple(std::move(shps)...)]() {});
         });
-    host_task_events.push_back(shared_ptr_cleanup_ev);
 
-    std::vector<sycl::event> param_pack_deps{device_chc_ptrs_copy_ev,
-                                             device_shape_strides_copy_ev,
-                                             device_chc_offsets_copy_ev};
-    return param_pack_deps;
+    return shared_ptr_cleanup_ev;
 }
 
 // copied from dpctl, remove if a similar utility is ever exposed
@@ -149,6 +184,8 @@ std::vector<dpctl::tensor::usm_ndarray> parse_py_chcs(const sycl::queue &q,
     return res;
 }
 
+} // namespace detail
+
 std::pair<sycl::event, sycl::event>
 py_choose(const dpctl::tensor::usm_ndarray &src,
           const py::object &py_chcs,
@@ -158,7 +195,7 @@ std::pair<sycl::event, sycl::event>
           const std::vector<sycl::event> &depends)
 {
     std::vector<dpctl::tensor::usm_ndarray> chcs =
-        parse_py_chcs(exec_q, py_chcs);
+        detail::parse_py_chcs(exec_q, py_chcs);
 
     // Python list max size must fit into py_ssize_t
     py::ssize_t n_chcs = chcs.size();
@@ -310,31 +347,37 @@ std::pair<sycl::event, sycl::event>
     host_task_events.reserve(2);
 
     std::vector<sycl::event> pack_deps;
+    std::vector<py::ssize_t> common_shape;
+    std::vector<py::ssize_t> src_strides;
+    std::vector<py::ssize_t> dst_strides;
     if (nd == 0) {
         // special case where all inputs are scalars
         // need to pass src, dst shape=1 and strides=0
         // chc_strides already initialized to 0 so ignore
-        std::array<py::ssize_t, 1> scalar_sh{1};
-        std::vector<py::ssize_t> src_strides{0};
-        std::vector<py::ssize_t> dst_strides{0};
-
-        pack_deps = _populate_choose_kernel_params(
-            exec_q, host_task_events, packed_chc_ptrs.get(),
-            packed_shapes_strides.get(), packed_chc_offsets.get(),
-            scalar_sh.data(), sh_nelems, src_strides, dst_strides, chc_strides,
-            chc_ptrs, chc_offsets, n_chcs);
+        common_shape = {1};
+        src_strides = {0};
+        dst_strides = {0};
     }
     else {
-        auto src_strides = src.get_strides_vector();
-        auto dst_strides = dst.get_strides_vector();
-
-        pack_deps = _populate_choose_kernel_params(
-            exec_q, host_task_events, packed_chc_ptrs.get(),
-            packed_shapes_strides.get(), packed_chc_offsets.get(), src_shape,
-            sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs,
-            chc_offsets, n_chcs);
+        common_shape = src.get_shape_vector();
+        src_strides = src.get_strides_vector();
+        dst_strides = dst.get_strides_vector();
     }
 
+    auto host_chc_ptrs = detail::make_host_ptrs(exec_q, chc_ptrs);
+    auto host_chc_offsets = detail::make_host_offsets(exec_q, chc_offsets);
+    auto host_shape_strides = detail::make_host_shape_strides(
+        exec_q, n_chcs, common_shape, src_strides, dst_strides, chc_strides);
+
+    pack_deps = detail::batched_copy(
+        exec_q, std::make_pair(packed_chc_ptrs.get(), host_chc_ptrs),
+        std::make_pair(packed_chc_offsets.get(), host_chc_offsets),
+        std::make_pair(packed_shapes_strides.get(), host_shape_strides));
+
+    host_task_events.push_back(
+        detail::async_shp_free(exec_q, pack_deps, host_chc_ptrs,
+                               host_chc_offsets, host_shape_strides));
+
     std::vector<sycl::event> all_deps;
     all_deps.reserve(depends.size() + pack_deps.size());
     all_deps.insert(std::end(all_deps), std::begin(pack_deps),
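
The fold expression inside detail::batched_copy is the one non-obvious piece of the new code: a comma-operator fold enqueues one queue::copy per (device pointer, host shared_ptr) pair, in argument order. Below is a stripped-down, compilable analogue assuming standard SYCL 2020; the int/float payloads and the main driver are illustrative, while the committed version copies char * and py::ssize_t data through USM host allocations.

#include <sycl/sycl.hpp>
#include <memory>
#include <utility>
#include <vector>

// Stripped-down analogue of detail::batched_copy: a fold over
// (device_ptr, shared_ptr<host_vector>) pairs, one async copy per pair.
template <typename... DevHostPairs>
std::vector<sycl::event> batched_copy(sycl::queue &q,
                                      DevHostPairs &&...dev_host_pairs)
{
    static_assert(sizeof...(DevHostPairs) > 0,
                  "batched_copy requires at least one pair");
    std::vector<sycl::event> copy_evs;
    copy_evs.reserve(sizeof...(DevHostPairs));
    // comma-operator fold: one emplace_back per pair, in argument order
    (copy_evs.emplace_back(q.copy(dev_host_pairs.second->data(),
                                  dev_host_pairs.first,
                                  dev_host_pairs.second->size())),
     ...);
    return copy_evs;
}

int main()
{
    sycl::queue q;
    auto a = std::make_shared<std::vector<int>>(std::vector<int>{1, 2, 3});
    auto b = std::make_shared<std::vector<float>>(std::vector<float>{4.0f});
    int *dev_a = sycl::malloc_device<int>(a->size(), q);
    float *dev_b = sycl::malloc_device<float>(b->size(), q);

    auto evs = batched_copy(q, std::make_pair(dev_a, a),
                            std::make_pair(dev_b, b));
    // py_choose instead threads these events into the kernel's depends list
    sycl::event::wait(evs);

    sycl::free(dev_a, q);
    sycl::free(dev_b, q);
    return 0;
}

detail::async_shp_free is the companion: the variadic shared_ptrs are moved into a std::tuple captured by an empty host_task that depends on the copy events, so the host buffers stay alive until every copy has completed.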

0 commit comments