Skip to content

Commit a54f420

Browse files
committed
Use smart pointer utilities from dpctl
1 parent 763fc25 commit a54f420

File tree

1 file changed

+12
-48
lines changed

1 file changed

+12
-48
lines changed

dpnp/backend/extensions/indexing/choose.cpp

Lines changed: 12 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -54,22 +54,6 @@ static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
5454

5555
namespace py = pybind11;
5656

57-
/*
58-
Returns an std::unique_ptr wrapping a USM allocation and deleter.
59-
60-
Must still be manually freed by host_task when allocation is needed
61-
for duration of asynchronous kernel execution.
62-
*/
63-
template <typename T>
64-
auto usm_unique_ptr(std::size_t sz, sycl::queue &q)
65-
{
66-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
67-
auto deleter = [&q](T *usm) { sycl_free_noexcept(usm, q); };
68-
69-
return std::unique_ptr<T, decltype(deleter)>(sycl::malloc_device<T>(sz, q),
70-
deleter);
71-
}
72-
7357
std::vector<sycl::event>
7458
_populate_choose_kernel_params(sycl::queue &exec_q,
7559
std::vector<sycl::event> &host_task_events,
@@ -305,11 +289,8 @@ std::pair<sycl::event, sycl::event>
305289
std::to_string(src_type_id));
306290
}
307291

308-
auto packed_chc_ptrs = usm_unique_ptr<char *>(n_chcs, exec_q);
309-
if (packed_chc_ptrs.get() == nullptr) {
310-
throw std::runtime_error(
311-
"Unable to allocate packed_chc_ptrs device memory");
312-
}
292+
auto packed_chc_ptrs =
293+
dpctl::tensor::alloc_utils::smart_malloc_device<char *>(n_chcs, exec_q);
313294

314295
// packed_shapes_strides = [common shape,
315296
// src.strides,
@@ -318,17 +299,12 @@ std::pair<sycl::event, sycl::event>
318299
// ...,
319300
// chcs[n_chcs].strides]
320301
auto packed_shapes_strides =
321-
usm_unique_ptr<py::ssize_t>((3 + n_chcs) * sh_nelems, exec_q);
322-
if (packed_shapes_strides.get() == nullptr) {
323-
throw std::runtime_error(
324-
"Unable to allocate packed_shapes_strides device memory");
325-
}
302+
dpctl::tensor::alloc_utils::smart_malloc_device<py::ssize_t>(
303+
(3 + n_chcs) * sh_nelems, exec_q);
326304

327-
auto packed_chc_offsets = usm_unique_ptr<py::ssize_t>(n_chcs, exec_q);
328-
if (packed_chc_offsets.get() == nullptr) {
329-
throw std::runtime_error(
330-
"Unable to allocate packed_chc_offsets device memory");
331-
}
305+
auto packed_chc_offsets =
306+
dpctl::tensor::alloc_utils::smart_malloc_device<py::ssize_t>(n_chcs,
307+
exec_q);
332308

333309
std::vector<sycl::event> host_task_events;
334310
host_task_events.reserve(2);
@@ -370,23 +346,11 @@ std::pair<sycl::event, sycl::event>
370346
src_data, dst_data, packed_chc_ptrs.get(), src_offset, dst_offset,
371347
packed_chc_offsets.get(), all_deps);
372348

373-
// release usm_unique_ptrs
374-
auto chc_ptrs_ = packed_chc_ptrs.release();
375-
auto shapes_strides_ = packed_shapes_strides.release();
376-
auto chc_offsets_ = packed_chc_offsets.release();
377-
378-
// free packed temporaries
379-
sycl::event temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
380-
cgh.depends_on(choose_generic_ev);
381-
const auto &ctx = exec_q.get_context();
382-
383-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
384-
cgh.host_task([chc_ptrs_, shapes_strides_, chc_offsets_, ctx]() {
385-
sycl_free_noexcept(chc_ptrs_, ctx);
386-
sycl_free_noexcept(shapes_strides_, ctx);
387-
sycl_free_noexcept(chc_offsets_, ctx);
388-
});
389-
});
349+
// async_smart_free releases owners
350+
sycl::event temporaries_cleanup_ev =
351+
dpctl::tensor::alloc_utils::async_smart_free(
352+
exec_q, {choose_generic_ev}, packed_chc_ptrs, packed_shapes_strides,
353+
packed_chc_offsets);
390354

391355
host_task_events.push_back(temporaries_cleanup_ev);
392356

0 commit comments

Comments
 (0)