Skip to content

Commit eff36de

Browse files
committed
unstack choices in choose when input is array
This squeezes the output, removing trivial out dimension
1 parent 5c926ee commit eff36de

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _build_choices_list(choices):
120120
"""
121121
Gather queues and USM types for the input, expected to be an array or
122122
list of arrays. If a single array of dimension greater than one, the array
123-
will be split along its first axis.
123+
will be unstacked.
124124
125125
Returns a list of :class:`dpctl.tensor.usm_ndarray`s, a list of
126126
:class:`dpctl.SyclQueue`s, and a list of strings representing USM types.
@@ -132,8 +132,7 @@ def _build_choices_list(choices):
132132
choices_sh = choices.shape
133133
if len(choices_sh) > 1:
134134
choices = [
135-
dpnp.get_usm_ndarray(chc)
136-
for chc in dpnp.array_split(choices, choices_sh[0])
135+
dpnp.get_usm_ndarray(chc) for chc in dpnp.unstack(choices)
137136
]
138137
else:
139138
choices = [choices]
@@ -170,8 +169,8 @@ def choose(x, choices, out=None, mode="wrap"):
170169
tuple of usm_ndarrays, list of dpnp.ndarrays,
171170
list of usm_ndarrays}
172171
Choice arrays. `x` and choice arrays must be broadcast-compatible.
173-
If `choices` is an array, the array is split along its outermost
174-
(i.e., 0th) dimension into a sequence of arrays.
172+
If `choices` is an array, the array is unstacked into a sequence of
173+
arrays.
175174
out : {None, dpnp.ndarray, usm_ndarray}, optional
176175
If provided, the result will be placed in this array. It should
177176
be of the appropriate shape and dtype.

0 commit comments

Comments
 (0)