Skip to content

Commit 8793b54

Browse files
committed
Use unique_ptrs for temporary device allocations in choose
Based on suggestions by @AlexanderKalistratov. Creates a unique_ptr that wraps a device allocation: the allocation still needs to be manually freed after the kernel run, but it is deallocated automatically if an error occurs during the validation leading up to the kernel launch.
1 parent 6b96da3 commit 8793b54

File tree

1 file changed

+47
-39
lines changed

1 file changed

+47
-39
lines changed

dpnp/backend/extensions/indexing/choose.cpp

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,22 @@ static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
5454

5555
namespace py = pybind11;
5656

57+
/*
58+
Returns an std::unique_ptr wrapping a USM allocation and deleter.
59+
60+
Must still be manually freed by host_task when allocation is needed
61+
for duration of asynchronous kernel execution.
62+
*/
63+
template <typename T>
64+
auto usm_unique_ptr(std::size_t sz, sycl::queue &q)
65+
{
66+
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
67+
auto deleter = [&q](T *usm) { sycl_free_noexcept(usm, q); };
68+
69+
return std::unique_ptr<T, decltype(deleter)>(sycl::malloc_device<T>(sz, q),
70+
deleter);
71+
}
72+
5773
std::vector<sycl::event>
5874
_populate_choose_kernel_params(sycl::queue &exec_q,
5975
std::vector<sycl::event> &host_task_events,
@@ -279,9 +295,16 @@ std::pair<sycl::event, sycl::event>
279295
chc_offsets.push_back(py::ssize_t(0));
280296
}
281297

282-
char **packed_chc_ptrs = sycl::malloc_device<char *>(n_chcs, exec_q);
298+
auto fn = mode ? choose_clip_dispatch_table[src_type_id][chc_type_id]
299+
: choose_wrap_dispatch_table[src_type_id][chc_type_id];
283300

284-
if (packed_chc_ptrs == nullptr) {
301+
if (fn == nullptr) {
302+
throw std::runtime_error("Indices must be integer type, got " +
303+
std::to_string(src_type_id));
304+
}
305+
306+
auto packed_chc_ptrs = usm_unique_ptr<char *>(n_chcs, exec_q);
307+
if (packed_chc_ptrs.get() == nullptr) {
285308
throw std::runtime_error(
286309
"Unable to allocate packed_chc_ptrs device memory");
287310
}
@@ -292,23 +315,15 @@ std::pair<sycl::event, sycl::event>
292315
// chcs[0].strides,
293316
// ...,
294317
// chcs[n_chcs].strides]
295-
py::ssize_t *packed_shapes_strides =
296-
sycl::malloc_device<py::ssize_t>((3 + n_chcs) * sh_nelems, exec_q);
297-
298-
if (packed_shapes_strides == nullptr) {
299-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
300-
sycl_free_noexcept(packed_chc_ptrs, exec_q);
318+
auto packed_shapes_strides =
319+
usm_unique_ptr<py::ssize_t>((3 + n_chcs) * sh_nelems, exec_q);
320+
if (packed_shapes_strides.get() == nullptr) {
301321
throw std::runtime_error(
302322
"Unable to allocate packed_shapes_strides device memory");
303323
}
304324

305-
py::ssize_t *packed_chc_offsets =
306-
sycl::malloc_device<py::ssize_t>(n_chcs, exec_q);
307-
308-
if (packed_chc_offsets == nullptr) {
309-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
310-
sycl_free_noexcept(packed_chc_ptrs, exec_q);
311-
sycl_free_noexcept(packed_shapes_strides, exec_q);
325+
auto packed_chc_offsets = usm_unique_ptr<py::ssize_t>(n_chcs, exec_q);
326+
if (packed_chc_offsets.get() == nullptr) {
312327
throw std::runtime_error(
313328
"Unable to allocate packed_chc_offsets device memory");
314329
}
@@ -320,44 +335,37 @@ std::pair<sycl::event, sycl::event>
320335
host_task_events.reserve(2);
321336

322337
std::vector<sycl::event> pack_deps = _populate_choose_kernel_params(
323-
exec_q, host_task_events, packed_chc_ptrs, packed_shapes_strides,
324-
packed_chc_offsets, src_shape, sh_nelems, src_strides, dst_strides,
325-
chc_strides, chc_ptrs, chc_offsets, n_chcs);
338+
exec_q, host_task_events, packed_chc_ptrs.get(),
339+
packed_shapes_strides.get(), packed_chc_offsets.get(), src_shape,
340+
sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs, chc_offsets,
341+
n_chcs);
326342

327343
std::vector<sycl::event> all_deps;
328344
all_deps.reserve(depends.size() + pack_deps.size());
329345
all_deps.insert(std::end(all_deps), std::begin(pack_deps),
330346
std::end(pack_deps));
331347
all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends));
332348

333-
auto fn = mode ? choose_clip_dispatch_table[src_type_id][chc_type_id]
334-
: choose_wrap_dispatch_table[src_type_id][chc_type_id];
335-
336-
if (fn == nullptr) {
337-
sycl::event::wait(host_task_events);
338-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
339-
sycl_free_noexcept(packed_chc_ptrs, exec_q);
340-
sycl_free_noexcept(packed_shapes_strides, exec_q);
341-
sycl_free_noexcept(packed_chc_offsets, exec_q);
342-
throw std::runtime_error("Indices must be integer type, got " +
343-
std::to_string(src_type_id));
344-
}
345-
346349
sycl::event choose_generic_ev =
347-
fn(exec_q, nelems, n_chcs, sh_nelems, packed_shapes_strides, src_data,
348-
dst_data, packed_chc_ptrs, src_offset, dst_offset,
349-
packed_chc_offsets, all_deps);
350+
fn(exec_q, nelems, n_chcs, sh_nelems, packed_shapes_strides.get(),
351+
src_data, dst_data, packed_chc_ptrs.get(), src_offset, dst_offset,
352+
packed_chc_offsets.get(), all_deps);
353+
354+
// release usm_unique_ptrs
355+
auto chc_ptrs_ = packed_chc_ptrs.release();
356+
auto shapes_strides_ = packed_shapes_strides.release();
357+
auto chc_offsets_ = packed_chc_offsets.release();
350358

351359
// free packed temporaries
352360
sycl::event temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
353361
cgh.depends_on(choose_generic_ev);
354362
const auto &ctx = exec_q.get_context();
363+
355364
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
356-
cgh.host_task([packed_shapes_strides, packed_chc_ptrs,
357-
packed_chc_offsets, ctx]() {
358-
sycl_free_noexcept(packed_shapes_strides, ctx);
359-
sycl_free_noexcept(packed_chc_ptrs, ctx);
360-
sycl_free_noexcept(packed_chc_offsets, ctx);
365+
cgh.host_task([chc_ptrs_, shapes_strides_, chc_offsets_, ctx]() {
366+
sycl_free_noexcept(chc_ptrs_, ctx);
367+
sycl_free_noexcept(shapes_strides_, ctx);
368+
sycl_free_noexcept(chc_offsets_, ctx);
361369
});
362370
});
363371

0 commit comments

Comments
 (0)