@@ -79,7 +79,7 @@ using value_type_of_t = typename value_type_of<T>::type;
7979typedef sycl::event (*nan_to_num_fn_ptr_t )(sycl::queue &,
8080 int ,
8181 size_t ,
82- py::ssize_t *,
82+ const py::ssize_t *,
8383 const py::object &,
8484 const py::object &,
8585 const py::object &,
@@ -93,7 +93,7 @@ template <typename T>
9393sycl::event nan_to_num_call (sycl::queue &exec_q,
9494 int nd,
9595 size_t nelems,
96- py::ssize_t *shape_strides,
96+ const py::ssize_t *shape_strides,
9797 const py::object &py_nan,
9898 const py::object &py_posinf,
9999 const py::object &py_neginf,
@@ -302,15 +302,12 @@ std::pair<sycl::event, sycl::event>
302302 std::vector<sycl::event> host_tasks{};
303303 host_tasks.reserve (2 );
304304
305- const auto & ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t >(
305+ auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t >(
306306 q, host_tasks, simplified_shape, simplified_src_strides,
307307 simplified_dst_strides);
308- py:: ssize_t *shape_strides = std::get<0 >(ptr_size_event_triple_);
308+ auto shape_strides_owner = std::move (std:: get<0 >(ptr_size_event_triple_) );
309309 const sycl::event ©_shape_ev = std::get<2 >(ptr_size_event_triple_);
310-
311- if (shape_strides == nullptr ) {
312- throw std::runtime_error (" Device memory allocation failed" );
313- }
310+ const py::ssize_t *shape_strides = shape_strides_owner.get ();
314311
315312 std::vector<sycl::event> all_deps;
316313 all_deps.reserve (depends.size () + 1 );
@@ -322,13 +319,9 @@ std::pair<sycl::event, sycl::event>
322319 src_offset, dst_data, dst_offset, all_deps);
323320
324321 // async free of shape_strides temporary
325- auto ctx = q.get_context ();
326- sycl::event tmp_cleanup_ev = q.submit ([&](sycl::handler &cgh) {
327- cgh.depends_on (comp_ev);
328- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
329- cgh.host_task (
330- [ctx, shape_strides]() { sycl_free_noexcept (shape_strides, ctx); });
331- });
322+ sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free (
323+ q, {comp_ev}, shape_strides_owner);
324+
332325 host_tasks.push_back (tmp_cleanup_ev);
333326
334327 return std::make_pair (
0 commit comments