@@ -120,7 +120,7 @@ def _build_choices_list(choices):
120120 """
121121 Gather queues and USM types for the input, expected to be an array or
122122 list of arrays. If a single array of dimension greater than one, the array
123- will be split along its first axis .
123+ will be unstacked .
124124
125125 Returns a list of :class:`dpctl.tensor.usm_ndarray`s, a list of
126126 :class:`dpctl.SyclQueue`s, and a list of strings representing USM types.
@@ -132,8 +132,7 @@ def _build_choices_list(choices):
132132 choices_sh = choices .shape
133133 if len (choices_sh ) > 1 :
134134 choices = [
135- dpnp .get_usm_ndarray (chc )
136- for chc in dpnp .array_split (choices , choices_sh [0 ])
135+ dpnp .get_usm_ndarray (chc ) for chc in dpnp .unstack (choices )
137136 ]
138137 else :
139138 choices = [choices ]
@@ -170,8 +169,8 @@ def choose(x, choices, out=None, mode="wrap"):
170169 tuple of usm_ndarrays, list of dpnp.ndarrays,
171170 list of usm_ndarrays}
172171 Choice arrays. `x` and choice arrays must be broadcast-compatible.
173- If `choices` is an array, the array is split along its outermost
174- (i.e., 0th) dimension into a sequence of arrays.
172+ If `choices` is an array, the array is unstacked into a sequence of
173+ arrays.
175174 out : {None, dpnp.ndarray, usm_ndarray}, optional
176175 If provided, the result will be placed in this array. It should
177176 be of the appropriate shape and dtype.
0 commit comments