Skip to content

Commit 295de83

Browse files
committed
unstack choices in choose when input is array
This squeezes the output, removing trivial out dimension
1 parent 9125814 commit 295de83

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _build_choices_list(choices):
121121
"""
122122
Gather queues and USM types for the input, expected to be an array or
123123
list of arrays. If a single array of dimension greater than one, the array
124-
will be split along its first axis.
124+
will be unstacked.
125125
126126
Returns a list of :class:`dpctl.tensor.usm_ndarray`s, a list of
127127
:class:`dpctl.SyclQueue`s, and a list of strings representing USM types.
@@ -133,8 +133,7 @@ def _build_choices_list(choices):
133133
choices_sh = choices.shape
134134
if len(choices_sh) > 1:
135135
choices = [
136-
dpnp.get_usm_ndarray(chc)
137-
for chc in dpnp.array_split(choices, choices_sh[0])
136+
dpnp.get_usm_ndarray(chc) for chc in dpnp.unstack(choices)
138137
]
139138
else:
140139
choices = [choices]
@@ -171,8 +170,8 @@ def choose(x, choices, out=None, mode="wrap"):
171170
tuple of usm_ndarrays, list of dpnp.ndarrays,
172171
list of usm_ndarrays}
173172
Choice arrays. `x` and choice arrays must be broadcast-compatible.
174-
If `choices` is an array, the array is split along its outermost
175-
(i.e., 0th) dimension into a sequence of arrays.
173+
If `choices` is an array, the array is unstacked into a sequence of
174+
arrays.
176175
out : {None, dpnp.ndarray, usm_ndarray}, optional
177176
If provided, the result will be placed in this array. It should
178177
be of the appropriate shape and dtype.

0 commit comments

Comments
 (0)