@@ -54,22 +54,6 @@ static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
5454
5555namespace py = pybind11;
5656
57- /*
58- Returns an std::unique_ptr wrapping a USM allocation and deleter.
59-
60- Must still be manually freed by host_task when allocation is needed
61- for duration of asynchronous kernel execution.
62- */
63- template <typename T>
64- auto usm_unique_ptr (std::size_t sz, sycl::queue &q)
65- {
66- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
67- auto deleter = [&q](T *usm) { sycl_free_noexcept (usm, q); };
68-
69- return std::unique_ptr<T, decltype (deleter)>(sycl::malloc_device<T>(sz, q),
70- deleter);
71- }
72-
7357std::vector<sycl::event>
7458 _populate_choose_kernel_params (sycl::queue &exec_q,
7559 std::vector<sycl::event> &host_task_events,
@@ -305,11 +289,8 @@ std::pair<sycl::event, sycl::event>
305289 std::to_string (src_type_id));
306290 }
307291
308- auto packed_chc_ptrs = usm_unique_ptr<char *>(n_chcs, exec_q);
309- if (packed_chc_ptrs.get () == nullptr ) {
310- throw std::runtime_error (
311- " Unable to allocate packed_chc_ptrs device memory" );
312- }
292+ auto packed_chc_ptrs =
293+ dpctl::tensor::alloc_utils::smart_malloc_device<char *>(n_chcs, exec_q);
313294
314295 // packed_shapes_strides = [common shape,
315296 // src.strides,
@@ -318,17 +299,12 @@ std::pair<sycl::event, sycl::event>
318299 // ...,
319300 // chcs[n_chcs].strides]
320301 auto packed_shapes_strides =
321- usm_unique_ptr<py::ssize_t >((3 + n_chcs) * sh_nelems, exec_q);
322- if (packed_shapes_strides.get () == nullptr ) {
323- throw std::runtime_error (
324- " Unable to allocate packed_shapes_strides device memory" );
325- }
302+ dpctl::tensor::alloc_utils::smart_malloc_device<py::ssize_t >(
303+ (3 + n_chcs) * sh_nelems, exec_q);
326304
327- auto packed_chc_offsets = usm_unique_ptr<py::ssize_t >(n_chcs, exec_q);
328- if (packed_chc_offsets.get () == nullptr ) {
329- throw std::runtime_error (
330- " Unable to allocate packed_chc_offsets device memory" );
331- }
305+ auto packed_chc_offsets =
306+ dpctl::tensor::alloc_utils::smart_malloc_device<py::ssize_t >(n_chcs,
307+ exec_q);
332308
333309 std::vector<sycl::event> host_task_events;
334310 host_task_events.reserve (2 );
@@ -370,23 +346,11 @@ std::pair<sycl::event, sycl::event>
370346 src_data, dst_data, packed_chc_ptrs.get (), src_offset, dst_offset,
371347 packed_chc_offsets.get (), all_deps);
372348
373- // release usm_unique_ptrs
374- auto chc_ptrs_ = packed_chc_ptrs.release ();
375- auto shapes_strides_ = packed_shapes_strides.release ();
376- auto chc_offsets_ = packed_chc_offsets.release ();
377-
378- // free packed temporaries
379- sycl::event temporaries_cleanup_ev = exec_q.submit ([&](sycl::handler &cgh) {
380- cgh.depends_on (choose_generic_ev);
381- const auto &ctx = exec_q.get_context ();
382-
383- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
384- cgh.host_task ([chc_ptrs_, shapes_strides_, chc_offsets_, ctx]() {
385- sycl_free_noexcept (chc_ptrs_, ctx);
386- sycl_free_noexcept (shapes_strides_, ctx);
387- sycl_free_noexcept (chc_offsets_, ctx);
388- });
389- });
349+ // async_smart_free releases owners
350+ sycl::event temporaries_cleanup_ev =
351+ dpctl::tensor::alloc_utils::async_smart_free (
352+ exec_q, {choose_generic_ev}, packed_chc_ptrs, packed_shapes_strides,
353+ packed_chc_offsets);
390354
391355 host_task_events.push_back (temporaries_cleanup_ev);
392356
0 commit comments