@@ -54,77 +54,112 @@ static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
5454
5555namespace py = pybind11;
5656
57- std::vector<sycl::event>
58- _populate_choose_kernel_params (sycl::queue &exec_q,
59- std::vector<sycl::event> &host_task_events,
60- char **device_chc_ptrs,
61- py::ssize_t *device_shape_strides,
62- py::ssize_t *device_chc_offsets,
63- const py::ssize_t *shape,
64- int shape_len,
65- std::vector<py::ssize_t > &inp_strides,
66- std::vector<py::ssize_t > &dst_strides,
67- std::vector<py::ssize_t > &chc_strides,
68- std::vector<char *> &chc_ptrs,
69- std::vector<py::ssize_t > &chc_offsets,
70- py::ssize_t n_chcs)
57+ namespace detail
7158{
72- using ptr_host_allocator_T =
73- dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
74- using ptrT = std::vector<char *, ptr_host_allocator_T>;
7559
76- ptr_host_allocator_T ptr_allocator (exec_q);
77- std::shared_ptr<ptrT> host_chc_ptrs_shp =
78- std::make_shared<ptrT>(n_chcs, ptr_allocator);
60+ using host_ptrs_allocator_t =
61+ dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
62+ using ptrs_t = std::vector<char *, host_ptrs_allocator_t >;
63+ using host_ptrs_shp_t = std::shared_ptr<ptrs_t >;
7964
80- using usm_host_allocatorT =
81- dpctl::tensor::alloc_utils::usm_host_allocator<py::ssize_t >;
82- using shT = std::vector<py::ssize_t , usm_host_allocatorT>;
65+ host_ptrs_shp_t make_host_ptrs (sycl::queue &exec_q,
66+ const std::vector<char *> &ptrs)
67+ {
68+ host_ptrs_allocator_t ptrs_allocator (exec_q);
69+ host_ptrs_shp_t host_ptrs_shp =
70+ std::make_shared<ptrs_t >(ptrs.size (), ptrs_allocator);
71+
72+ std::copy (ptrs.begin (), ptrs.end (), host_ptrs_shp->begin ());
73+
74+ return host_ptrs_shp;
75+ }
76+
77+ using host_sz_allocator_t =
78+ dpctl::tensor::alloc_utils::usm_host_allocator<py::ssize_t >;
79+ using sz_t = std::vector<py::ssize_t , host_sz_allocator_t >;
80+ using host_sz_shp_t = std::shared_ptr<sz_t >;
8381
84- usm_host_allocatorT sz_allocator (exec_q);
85- std::shared_ptr<shT> host_shape_strides_shp =
86- std::make_shared<shT>(shape_len * (3 + n_chcs), sz_allocator);
82+ host_sz_shp_t make_host_offsets (sycl::queue &exec_q,
83+ const std::vector<py::ssize_t > &offsets)
84+ {
85+ host_sz_allocator_t offsets_allocator (exec_q);
86+ host_sz_shp_t host_offsets_shp =
87+ std::make_shared<sz_t >(offsets.size (), offsets_allocator);
88+
89+ std::copy (offsets.begin (), offsets.end (), host_offsets_shp->begin ());
8790
88- std::shared_ptr<shT> host_chc_offsets_shp =
89- std::make_shared<shT>(n_chcs, sz_allocator);
91+ return host_offsets_shp;
92+ }
93+
94+ host_sz_shp_t make_host_shape_strides (sycl::queue &exec_q,
95+ py::ssize_t n_chcs,
96+ std::vector<py::ssize_t > &shape,
97+ std::vector<py::ssize_t > &inp_strides,
98+ std::vector<py::ssize_t > &dst_strides,
99+ std::vector<py::ssize_t > &chc_strides)
100+ {
101+ auto nelems = shape.size ();
102+ host_sz_allocator_t shape_strides_allocator (exec_q);
103+ host_sz_shp_t host_shape_strides_shp =
104+ std::make_shared<sz_t >(nelems * (3 + n_chcs), shape_strides_allocator);
90105
91- std::copy (shape, shape + shape_len , host_shape_strides_shp->begin ());
106+ std::copy (shape. begin () , shape. end () , host_shape_strides_shp->begin ());
92107 std::copy (inp_strides.begin (), inp_strides.end (),
93- host_shape_strides_shp->begin () + shape_len );
108+ host_shape_strides_shp->begin () + nelems );
94109 std::copy (dst_strides.begin (), dst_strides.end (),
95- host_shape_strides_shp->begin () + 2 * shape_len );
110+ host_shape_strides_shp->begin () + 2 * nelems );
96111 std::copy (chc_strides.begin (), chc_strides.end (),
97- host_shape_strides_shp->begin () + 3 * shape_len );
112+ host_shape_strides_shp->begin () + 3 * nelems );
98113
99- std::copy (chc_ptrs.begin (), chc_ptrs.end (), host_chc_ptrs_shp->begin ());
100- std::copy (chc_offsets.begin (), chc_offsets.end (),
101- host_chc_offsets_shp->begin ());
114+ return host_shape_strides_shp;
115+ }
102116
103- const sycl::event &device_chc_ptrs_copy_ev = exec_q.copy <char *>(
104- host_chc_ptrs_shp->data (), device_chc_ptrs, host_chc_ptrs_shp->size ());
117+ /* This function expects a queue and a non-trivial number of
118+ std::pairs of raw device pointers and host shared pointers
119+ (structured as <device_ptr, shared_ptr>),
120+ then enqueues a copy of the host shared pointer data into
121+ the device pointer.
122+
123+ Assumes the device pointer addresses sufficient memory for
124+ the size of the host memory.
125+ */
126+ template <typename ... DevHostPairs>
127+ std::vector<sycl::event> batched_copy (sycl::queue &exec_q,
128+ DevHostPairs &&...dev_host_pairs)
129+ {
130+ constexpr std::size_t n = sizeof ...(DevHostPairs);
131+ static_assert (n > 0 , " batched_copy requires at least one argument" );
105132
106- const sycl::event &device_shape_strides_copy_ev = exec_q.copy <py::ssize_t >(
107- host_shape_strides_shp->data (), device_shape_strides,
108- host_shape_strides_shp->size ());
133+ std::vector<sycl::event> copy_evs;
134+ copy_evs.reserve (n);
135+ (copy_evs.emplace_back (exec_q.copy (dev_host_pairs.second ->data (),
136+ dev_host_pairs.first ,
137+ dev_host_pairs.second ->size ())),
138+ ...);
109139
110- const sycl::event &device_chc_offsets_copy_ev = exec_q.copy <py::ssize_t >(
111- host_chc_offsets_shp->data (), device_chc_offsets,
112- host_chc_offsets_shp->size ());
140+ return copy_evs;
141+ }
142+
143+ /* This function takes as input a queue, sycl::event dependencies,
144+ and a non-trivial number of shared_ptrs and moves them into
145+ a host_task lambda capture, ensuring their lifetime until the
146+ host_task executes.
147+ */
148+ template <typename ... Shps>
149+ sycl::event async_shp_free (sycl::queue &exec_q,
150+ const std::vector<sycl::event> &depends,
151+ Shps &&...shps)
152+ {
153+ constexpr std::size_t n = sizeof ...(Shps);
154+ static_assert (n > 0 , " async_shp_free requires at least one argument" );
113155
114156 const sycl::event &shared_ptr_cleanup_ev =
115157 exec_q.submit ([&](sycl::handler &cgh) {
116- cgh.depends_on ({device_chc_offsets_copy_ev,
117- device_shape_strides_copy_ev,
118- device_chc_ptrs_copy_ev});
119- cgh.host_task ([host_chc_offsets_shp, host_shape_strides_shp,
120- host_chc_ptrs_shp]() {});
158+ cgh.depends_on (depends);
159+ cgh.host_task ([capture = std::tuple (std::move (shps)...)]() {});
121160 });
122- host_task_events.push_back (shared_ptr_cleanup_ev);
123161
124- std::vector<sycl::event> param_pack_deps{device_chc_ptrs_copy_ev,
125- device_shape_strides_copy_ev,
126- device_chc_offsets_copy_ev};
127- return param_pack_deps;
162+ return shared_ptr_cleanup_ev;
128163}
129164
130165// copied from dpctl, remove if a similar utility is ever exposed
@@ -149,6 +184,8 @@ std::vector<dpctl::tensor::usm_ndarray> parse_py_chcs(const sycl::queue &q,
149184 return res;
150185}
151186
187+ } // namespace detail
188+
152189std::pair<sycl::event, sycl::event>
153190 py_choose (const dpctl::tensor::usm_ndarray &src,
154191 const py::object &py_chcs,
@@ -158,7 +195,7 @@ std::pair<sycl::event, sycl::event>
158195 const std::vector<sycl::event> &depends)
159196{
160197 std::vector<dpctl::tensor::usm_ndarray> chcs =
161- parse_py_chcs (exec_q, py_chcs);
198+ detail:: parse_py_chcs (exec_q, py_chcs);
162199
163200 // Python list max size must fit into py_ssize_t
164201 py::ssize_t n_chcs = chcs.size ();
@@ -310,31 +347,37 @@ std::pair<sycl::event, sycl::event>
310347 host_task_events.reserve (2 );
311348
312349 std::vector<sycl::event> pack_deps;
350+ std::vector<py::ssize_t > common_shape;
351+ std::vector<py::ssize_t > src_strides;
352+ std::vector<py::ssize_t > dst_strides;
313353 if (nd == 0 ) {
314354 // special case where all inputs are scalars
315355 // need to pass src, dst shape=1 and strides=0
316356 // chc_strides already initialized to 0 so ignore
317- std::array<py::ssize_t , 1 > scalar_sh{1 };
318- std::vector<py::ssize_t > src_strides{0 };
319- std::vector<py::ssize_t > dst_strides{0 };
320-
321- pack_deps = _populate_choose_kernel_params (
322- exec_q, host_task_events, packed_chc_ptrs.get (),
323- packed_shapes_strides.get (), packed_chc_offsets.get (),
324- scalar_sh.data (), sh_nelems, src_strides, dst_strides, chc_strides,
325- chc_ptrs, chc_offsets, n_chcs);
357+ common_shape = {1 };
358+ src_strides = {0 };
359+ dst_strides = {0 };
326360 }
327361 else {
328- auto src_strides = src.get_strides_vector ();
329- auto dst_strides = dst.get_strides_vector ();
330-
331- pack_deps = _populate_choose_kernel_params (
332- exec_q, host_task_events, packed_chc_ptrs.get (),
333- packed_shapes_strides.get (), packed_chc_offsets.get (), src_shape,
334- sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs,
335- chc_offsets, n_chcs);
362+ common_shape = src.get_shape_vector ();
363+ src_strides = src.get_strides_vector ();
364+ dst_strides = dst.get_strides_vector ();
336365 }
337366
367+ auto host_chc_ptrs = detail::make_host_ptrs (exec_q, chc_ptrs);
368+ auto host_chc_offsets = detail::make_host_offsets (exec_q, chc_offsets);
369+ auto host_shape_strides = detail::make_host_shape_strides (
370+ exec_q, n_chcs, common_shape, src_strides, dst_strides, chc_strides);
371+
372+ pack_deps = detail::batched_copy (
373+ exec_q, std::make_pair (packed_chc_ptrs.get (), host_chc_ptrs),
374+ std::make_pair (packed_chc_offsets.get (), host_chc_offsets),
375+ std::make_pair (packed_shapes_strides.get (), host_shape_strides));
376+
377+ host_task_events.push_back (
378+ detail::async_shp_free (exec_q, pack_deps, host_chc_ptrs,
379+ host_chc_offsets, host_shape_strides));
380+
338381 std::vector<sycl::event> all_deps;
339382 all_deps.reserve (depends.size () + pack_deps.size ());
340383 all_deps.insert (std::end (all_deps), std::begin (pack_deps),
0 commit comments