@@ -330,17 +330,34 @@ std::pair<sycl::event, sycl::event>
330330 " Unable to allocate packed_chc_offsets device memory" );
331331 }
332332
333- auto src_strides = src.get_strides_vector ();
334- auto dst_strides = dst.get_strides_vector ();
335-
336333 std::vector<sycl::event> host_task_events;
337334 host_task_events.reserve (2 );
338335
339- std::vector<sycl::event> pack_deps = _populate_choose_kernel_params (
340- exec_q, host_task_events, packed_chc_ptrs.get (),
341- packed_shapes_strides.get (), packed_chc_offsets.get (), src_shape,
342- sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs, chc_offsets,
343- n_chcs);
336+ std::vector<sycl::event> pack_deps;
337+ if (nd == 0 ) {
338+ // special case where all inputs are scalars
339+ // need to pass src, dst shape=1 and strides=0
340+ // chc_strides already initialized to 0 so ignore
341+ std::array<py::ssize_t , 1 > scalar_sh{1 };
342+ std::vector<py::ssize_t > src_strides{0 };
343+ std::vector<py::ssize_t > dst_strides{0 };
344+
345+ pack_deps = _populate_choose_kernel_params (
346+ exec_q, host_task_events, packed_chc_ptrs.get (),
347+ packed_shapes_strides.get (), packed_chc_offsets.get (),
348+ scalar_sh.data (), sh_nelems, src_strides, dst_strides, chc_strides,
349+ chc_ptrs, chc_offsets, n_chcs);
350+ }
351+ else {
352+ auto src_strides = src.get_strides_vector ();
353+ auto dst_strides = dst.get_strides_vector ();
354+
355+ pack_deps = _populate_choose_kernel_params (
356+ exec_q, host_task_events, packed_chc_ptrs.get (),
357+ packed_shapes_strides.get (), packed_chc_offsets.get (), src_shape,
358+ sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs,
359+ chc_offsets, n_chcs);
360+ }
344361
345362 std::vector<sycl::event> all_deps;
346363 all_deps.reserve (depends.size () + pack_deps.size ());
0 commit comments