Skip to content

Commit e22b968

Browse files
committed
Use get_usm_allocations in compress
1 parent e0aa410 commit e22b968

File tree

1 file changed

+3
-14
lines changed

1 file changed

+3
-14
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -267,31 +267,20 @@ def compress(condition, a, axis=None, out=None):
267267

268268
a_ary = dpnp.get_usm_ndarray(a)
269269
if not dpnp.is_supported_array_type(condition):
270-
usm_type = a_ary.usm_type
271-
q = a_ary.sycl_queue
272270
cond_ary = dpnp.as_usm_ndarray(
273271
condition,
274272
dtype=dpnp.bool,
275-
usm_type=usm_type,
276-
sycl_queue=q,
273+
usm_type=a_ary.usm_type,
274+
sycl_queue=a_ary.q,
277275
)
278-
queues_ = [q]
279-
usm_types_ = [usm_type]
280276
else:
281277
cond_ary = dpnp.get_usm_ndarray(condition)
282-
queues_ = [a_ary.sycl_queue, cond_ary.sycl_queue]
283-
usm_types_ = [a_ary.usm_type, cond_ary.usm_type]
284278
if not cond_ary.ndim == 1:
285279
raise ValueError(
286280
"`condition` must be a 1-D array or un-nested sequence"
287281
)
288282

289-
res_usm_type = dpu.get_coerced_usm_type(usm_types_)
290-
exec_q = dpu.get_execution_queue(queues_)
291-
if exec_q is None:
292-
raise dpu.ExecutionPlacementError(
293-
"arrays must be allocated on the same SYCL queue"
294-
)
283+
res_usm_type, exec_q = get_usm_allocations([a_ary, cond_ary])
295284

296285
# _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
297286
inds = _nonzero_impl(cond_ary)

0 commit comments

Comments
 (0)