Skip to content

Commit 2381662

Browse files
committed
Remove branching when condition is an array
Also tweaks to docstring
1 parent fc85e73 commit 2381662

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,10 @@ def _take_1d_index(x, inds, axis, q, usm_type, out=None):
227227

228228
def compress(condition, a, axis=None, out=None):
229229
"""
230-
A copy of `a` without the slices along `axis` for which `condition` is
231-
``False``.
230+
Return selected slices of an array along given axis.
231+
232+
A slice of `a` is returned for each index along `axis` where `condition`
233+
is ``True``.
232234
233235
For full documentation refer to :obj:`numpy.choose`.
234236
@@ -299,15 +301,13 @@ def compress(condition, a, axis=None, out=None):
299301
axis = normalize_axis_index(operator.index(axis), a.ndim)
300302

301303
a_ary = dpnp.get_usm_ndarray(a)
302-
if not dpnp.is_supported_array_type(condition):
303-
cond_ary = dpnp.as_usm_ndarray(
304-
condition,
305-
dtype=dpnp.bool,
306-
usm_type=a_ary.usm_type,
307-
sycl_queue=a_ary.sycl_queue,
308-
)
309-
else:
310-
cond_ary = dpnp.get_usm_ndarray(condition)
304+
cond_ary = dpnp.as_usm_ndarray(
305+
condition,
306+
dtype=dpnp.bool,
307+
usm_type=a_ary.usm_type,
308+
sycl_queue=a_ary.sycl_queue,
309+
)
310+
311311
if not cond_ary.ndim == 1:
312312
raise ValueError(
313313
"`condition` must be a 1-D array or un-nested sequence"

0 commit comments

Comments
 (0)