Skip to content

Commit 2c13bcb

Browse files
committed
Fixes a bug for 0d inputs to choose and adds a test
Logic now handles 0d inputs to _populate_choose_kernel_params to avoid dereferencing empty shape and strides of input arrays
1 parent c3d57ee commit 2c13bcb

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

dpnp/backend/extensions/indexing/choose.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -330,17 +330,34 @@ std::pair<sycl::event, sycl::event>
330330
"Unable to allocate packed_chc_offsets device memory");
331331
}
332332

333-
auto src_strides = src.get_strides_vector();
334-
auto dst_strides = dst.get_strides_vector();
335-
336333
std::vector<sycl::event> host_task_events;
337334
host_task_events.reserve(2);
338335

339-
std::vector<sycl::event> pack_deps = _populate_choose_kernel_params(
340-
exec_q, host_task_events, packed_chc_ptrs.get(),
341-
packed_shapes_strides.get(), packed_chc_offsets.get(), src_shape,
342-
sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs, chc_offsets,
343-
n_chcs);
336+
std::vector<sycl::event> pack_deps;
337+
if (nd == 0) {
338+
// special case where all inputs are scalars
339+
// need to pass src, dst shape=1 and strides=0
340+
// chc_strides already initialized to 0 so ignore
341+
std::array<py::ssize_t, 1> scalar_sh{1};
342+
std::vector<py::ssize_t> src_strides{0};
343+
std::vector<py::ssize_t> dst_strides{0};
344+
345+
pack_deps = _populate_choose_kernel_params(
346+
exec_q, host_task_events, packed_chc_ptrs.get(),
347+
packed_shapes_strides.get(), packed_chc_offsets.get(),
348+
scalar_sh.data(), sh_nelems, src_strides, dst_strides, chc_strides,
349+
chc_ptrs, chc_offsets, n_chcs);
350+
}
351+
else {
352+
auto src_strides = src.get_strides_vector();
353+
auto dst_strides = dst.get_strides_vector();
354+
355+
pack_deps = _populate_choose_kernel_params(
356+
exec_q, host_task_events, packed_chc_ptrs.get(),
357+
packed_shapes_strides.get(), packed_chc_offsets.get(), src_shape,
358+
sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs,
359+
chc_offsets, n_chcs);
360+
}
344361

345362
std::vector<sycl::event> all_deps;
346363
all_deps.reserve(depends.size() + pack_deps.size());

dpnp/tests/test_indexing.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,3 +1506,10 @@ def test_choose_empty(self):
15061506
assert r.shape == sh[1:]
15071507
r = dpnp.choose(inds, [chcs])
15081508
assert r.shape == sh
1509+
1510+
def test_choose_0d_inputs(self):
1511+
sh = ()
1512+
inds = dpnp.zeros(sh, dtype="i4")
1513+
chc = dpnp.ones(sh, dtype="i4")
1514+
r = dpnp.choose(inds, [chc])
1515+
assert r == chc

0 commit comments

Comments
 (0)