@@ -54,6 +54,22 @@ static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
5454
5555namespace py = pybind11;
5656
57+ /*
58+ Returns an std::unique_ptr wrapping a USM allocation and deleter.
59+
60+ Must still be manually freed by host_task when allocation is needed
61+ for duration of asynchronous kernel execution.
62+ */
63+ template <typename T>
64+ auto usm_unique_ptr (std::size_t sz, sycl::queue &q)
65+ {
66+ using dpctl::tensor::alloc_utils::sycl_free_noexcept;
67+ auto deleter = [&q](T *usm) { sycl_free_noexcept (usm, q); };
68+
69+ return std::unique_ptr<T, decltype (deleter)>(sycl::malloc_device<T>(sz, q),
70+ deleter);
71+ }
72+
5773std::vector<sycl::event>
5874 _populate_choose_kernel_params (sycl::queue &exec_q,
5975 std::vector<sycl::event> &host_task_events,
@@ -279,9 +295,16 @@ std::pair<sycl::event, sycl::event>
279295 chc_offsets.push_back (py::ssize_t (0 ));
280296 }
281297
282- char **packed_chc_ptrs = sycl::malloc_device<char *>(n_chcs, exec_q);
298+ auto fn = mode ? choose_clip_dispatch_table[src_type_id][chc_type_id]
299+ : choose_wrap_dispatch_table[src_type_id][chc_type_id];
283300
284- if (packed_chc_ptrs == nullptr ) {
301+ if (fn == nullptr ) {
302+ throw std::runtime_error (" Indices must be integer type, got " +
303+ std::to_string (src_type_id));
304+ }
305+
306+ auto packed_chc_ptrs = usm_unique_ptr<char *>(n_chcs, exec_q);
307+ if (packed_chc_ptrs.get () == nullptr ) {
285308 throw std::runtime_error (
286309 " Unable to allocate packed_chc_ptrs device memory" );
287310 }
@@ -292,23 +315,15 @@ std::pair<sycl::event, sycl::event>
292315 // chcs[0].strides,
293316 // ...,
294317 // chcs[n_chcs].strides]
295- py::ssize_t *packed_shapes_strides =
296- sycl::malloc_device<py::ssize_t >((3 + n_chcs) * sh_nelems, exec_q);
297-
298- if (packed_shapes_strides == nullptr ) {
299- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
300- sycl_free_noexcept (packed_chc_ptrs, exec_q);
318+ auto packed_shapes_strides =
319+ usm_unique_ptr<py::ssize_t >((3 + n_chcs) * sh_nelems, exec_q);
320+ if (packed_shapes_strides.get () == nullptr ) {
301321 throw std::runtime_error (
302322 " Unable to allocate packed_shapes_strides device memory" );
303323 }
304324
305- py::ssize_t *packed_chc_offsets =
306- sycl::malloc_device<py::ssize_t >(n_chcs, exec_q);
307-
308- if (packed_chc_offsets == nullptr ) {
309- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
310- sycl_free_noexcept (packed_chc_ptrs, exec_q);
311- sycl_free_noexcept (packed_shapes_strides, exec_q);
325+ auto packed_chc_offsets = usm_unique_ptr<py::ssize_t >(n_chcs, exec_q);
326+ if (packed_chc_offsets.get () == nullptr ) {
312327 throw std::runtime_error (
313328 " Unable to allocate packed_chc_offsets device memory" );
314329 }
@@ -320,44 +335,37 @@ std::pair<sycl::event, sycl::event>
320335 host_task_events.reserve (2 );
321336
322337 std::vector<sycl::event> pack_deps = _populate_choose_kernel_params (
323- exec_q, host_task_events, packed_chc_ptrs, packed_shapes_strides,
324- packed_chc_offsets, src_shape, sh_nelems, src_strides, dst_strides,
325- chc_strides, chc_ptrs, chc_offsets, n_chcs);
338+ exec_q, host_task_events, packed_chc_ptrs.get (),
339+ packed_shapes_strides.get (), packed_chc_offsets.get (), src_shape,
340+ sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs, chc_offsets,
341+ n_chcs);
326342
327343 std::vector<sycl::event> all_deps;
328344 all_deps.reserve (depends.size () + pack_deps.size ());
329345 all_deps.insert (std::end (all_deps), std::begin (pack_deps),
330346 std::end (pack_deps));
331347 all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
332348
333- auto fn = mode ? choose_clip_dispatch_table[src_type_id][chc_type_id]
334- : choose_wrap_dispatch_table[src_type_id][chc_type_id];
335-
336- if (fn == nullptr ) {
337- sycl::event::wait (host_task_events);
338- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
339- sycl_free_noexcept (packed_chc_ptrs, exec_q);
340- sycl_free_noexcept (packed_shapes_strides, exec_q);
341- sycl_free_noexcept (packed_chc_offsets, exec_q);
342- throw std::runtime_error (" Indices must be integer type, got " +
343- std::to_string (src_type_id));
344- }
345-
346349 sycl::event choose_generic_ev =
347- fn (exec_q, nelems, n_chcs, sh_nelems, packed_shapes_strides, src_data,
348- dst_data, packed_chc_ptrs, src_offset, dst_offset,
349- packed_chc_offsets, all_deps);
350+ fn (exec_q, nelems, n_chcs, sh_nelems, packed_shapes_strides.get (),
351+ src_data, dst_data, packed_chc_ptrs.get (), src_offset, dst_offset,
352+ packed_chc_offsets.get (), all_deps);
353+
354+ // release usm_unique_ptrs
355+ auto chc_ptrs_ = packed_chc_ptrs.release ();
356+ auto shapes_strides_ = packed_shapes_strides.release ();
357+ auto chc_offsets_ = packed_chc_offsets.release ();
350358
351359 // free packed temporaries
352360 sycl::event temporaries_cleanup_ev = exec_q.submit ([&](sycl::handler &cgh) {
353361 cgh.depends_on (choose_generic_ev);
354362 const auto &ctx = exec_q.get_context ();
363+
355364 using dpctl::tensor::alloc_utils::sycl_free_noexcept;
356- cgh.host_task ([packed_shapes_strides, packed_chc_ptrs,
357- packed_chc_offsets, ctx]() {
358- sycl_free_noexcept (packed_shapes_strides, ctx);
359- sycl_free_noexcept (packed_chc_ptrs, ctx);
360- sycl_free_noexcept (packed_chc_offsets, ctx);
365+ cgh.host_task ([chc_ptrs_, shapes_strides_, chc_offsets_, ctx]() {
366+ sycl_free_noexcept (chc_ptrs_, ctx);
367+ sycl_free_noexcept (shapes_strides_, ctx);
368+ sycl_free_noexcept (chc_offsets_, ctx);
361369 });
362370 });
363371
0 commit comments