@@ -121,7 +121,7 @@ def _build_choices_list(choices):
121121 """
122122 Gather queues and USM types for the input, expected to be an array or
123123 list of arrays. If a single array of dimension greater than one, the array
124- will be split along its first axis .
124+ will be unstacked .
125125
126126 Returns a list of :class:`dpctl.tensor.usm_ndarray`s, a list of
127127 :class:`dpctl.SyclQueue`s, and a list of strings representing USM types.
@@ -133,8 +133,7 @@ def _build_choices_list(choices):
133133 choices_sh = choices .shape
134134 if len (choices_sh ) > 1 :
135135 choices = [
136- dpnp .get_usm_ndarray (chc )
137- for chc in dpnp .array_split (choices , choices_sh [0 ])
136+ dpnp .get_usm_ndarray (chc ) for chc in dpnp .unstack (choices )
138137 ]
139138 else :
140139 choices = [choices ]
@@ -171,8 +170,8 @@ def choose(x, choices, out=None, mode="wrap"):
171170 tuple of usm_ndarrays, list of dpnp.ndarrays,
172171 list of usm_ndarrays}
173172 Choice arrays. `x` and choice arrays must be broadcast-compatible.
174- If `choices` is an array, the array is split along its outermost
175- (i.e., 0th) dimension into a sequence of arrays.
173+ If `choices` is an array, the array is unstacked into a sequence of
174+ arrays.
176175 out : {None, dpnp.ndarray, usm_ndarray}, optional
177176 If provided, the result will be placed in this array. It should
178177 be of the appropriate shape and dtype.
0 commit comments