Skip to content

Commit 49b0593

Browse files
committed
Use get_usm_allocations in choose
Removes need for accumulating a list of USM types and queues
1 parent ac252fd commit 49b0593

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -125,28 +125,21 @@ def _build_choices_list(choices):
125125
list of arrays. If a single array of dimension greater than one, the array
126126
will be unstacked.
127127
128-
Returns a list of :class:`dpctl.tensor.usm_ndarray`s, a list of
129-
:class:`dpctl.SyclQueue`s, and a list of strings representing USM types.
128+
Returns a list of :class:`dpctl.tensor.usm_ndarray`s.
130129
"""
131130

132131
if dpnp.is_supported_array_type(choices):
133-
queues = [choices.sycl_queue]
134-
usm_types = [choices.usm_type]
135132
choices = [dpnp.get_usm_ndarray(chc) for chc in dpnp.unstack(choices)]
136133
elif isinstance(choices, (tuple, list)):
137-
queues = []
138-
usm_types = []
139134
choices_ = []
140135
for chc in choices:
141136
chc_ = dpnp.get_usm_ndarray(chc)
142137
choices_.append(dpnp.get_usm_ndarray(chc_))
143-
queues.append(chc_.sycl_queue)
144-
usm_types.append(chc_.usm_type)
145138
choices = choices_
146139
else:
147140
raise TypeError("`choices` must be an array or sequence of arrays")
148141

149-
return choices, queues, usm_types
142+
return choices
150143

151144

152145
def _choose_run(inds, chcs, q, usm_type, out=None, mode=0):
@@ -242,14 +235,9 @@ def choose(a, choices, out=None, mode="wrap"):
242235
if not dpnp.issubdtype(ind_dt, dpnp.integer):
243236
raise ValueError("input index array must be of integer data type")
244237

245-
choices, queues, usm_types = _build_choices_list(choices)
238+
choices = _build_choices_list(choices)
246239

247-
res_usm_type = dpu.get_coerced_usm_type(usm_types)
248-
exec_q = dpu.get_execution_queue(queues)
249-
if exec_q is None:
250-
raise dpu.ExecutionPlacementError(
251-
"arrays must be allocated on the same SYCL queue"
252-
)
240+
res_usm_type, exec_q = get_usm_allocations(choices + [inds])
253241
# apply type promotion to input choices
254242
res_dt = dpt.result_type(*choices)
255243
if len(choices) > 1:

0 commit comments

Comments
 (0)