Skip to content

Commit 9e67af0

Browse files
committed
address review comments
1 parent 086f9c7 commit 9e67af0

File tree

1 file changed

+19
-26
lines changed

1 file changed

+19
-26
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,24 +1200,27 @@ def isfortran(a):
12001200
def isin(element, test_elements, assume_unique=False, invert=False):
12011201
"""
12021202
Calculates ``element in test_elements``, broadcasting over `element` only.
1203-
Returns a boolean array of the same shape as `element` that is True
1204-
where an element of `element` is in `test_elements` and False otherwise.
1203+
Returns a boolean array of the same shape as `element` that is ``True``
1204+
where an element of `element` is in `test_elements` and ``False``
1205+
otherwise.
1206+
1207+
For full documentation refer to :obj:`numpy.isin`.
12051208
12061209
Parameters
12071210
----------
1208-
element : {array_like, dpnp.ndarray, usm_ndarray}
1211+
element : {dpnp.ndarray, usm_ndarray, scalar}
12091212
Input array.
1210-
test_elements : {array_like, dpnp.ndarray, usm_ndarray}
1213+
test_elements : {dpnp.ndarray, usm_ndarray, scalar}
12111214
The values against which to test each value of `element`.
1212-
This argument is flattened if it is an array or array_like.
1213-
See notes for behavior with non-array-like parameters.
1215+
This argument is flattened if it is an array.
12141216
assume_unique : bool, optional
12151217
Ignored
12161218
invert : bool, optional
1217-
If True, the values in the returned array are inverted, as if
1218-
calculating `element not in test_elements`. Default is False.
1219+
If ``True``, the values in the returned array are inverted, as if
1220+
calculating `element not in test_elements`.
12191221
``dpnp.isin(a, b, invert=True)`` is equivalent to (but faster
12201222
than) ``dpnp.invert(dpnp.isin(a, b))``.
1223+
Default: ``False``.
12211224
12221225
12231226
Returns
@@ -1259,28 +1262,18 @@ def isin(element, test_elements, assume_unique=False, invert=False):
12591262
"""
12601263

12611264
dpnp.check_supported_arrays_type(element, test_elements, scalar_type=True)
1262-
if dpnp.isscalar(element):
1263-
usm_element = dpt.asarray(
1264-
element,
1265-
sycl_queue=test_elements.sycl_queue,
1266-
usm_type=test_elements.usm_type,
1267-
)
1268-
usm_test = dpnp.get_usm_ndarray(test_elements)
1269-
elif dpnp.isscalar(test_elements):
1270-
usm_test = dpt.asarray(
1271-
test_elements,
1272-
sycl_queue=element.sycl_queue,
1273-
usm_type=element.usm_type,
1274-
)
1275-
usm_element = dpnp.get_usm_ndarray(element)
1276-
else:
1277-
usm_element = dpnp.get_usm_ndarray(element)
1278-
usm_test = dpnp.get_usm_ndarray(test_elements)
1265+
usm_element = dpnp.as_usm_ndarray(
1266+
element, usm_type=element.usm_type, sycl_queue=element.sycl_queue
1267+
)
1268+
usm_test = dpnp.as_usm_ndarray(
1269+
test_elements,
1270+
usm_type=test_elements.usm_type,
1271+
sycl_queue=test_elements.sycl_queue,
1272+
)
12791273
return dpnp.get_result_array(
12801274
dpt.isin(
12811275
usm_element,
12821276
usm_test,
1283-
assume_unique=assume_unique,
12841277
invert=invert,
12851278
)
12861279
)

0 commit comments

Comments
 (0)