Skip to content

Commit d74eb6d

Browse files
committed
address review comments
1 parent 42e9893 commit d74eb6d

File tree

2 files changed

+108
-80
lines changed

2 files changed

+108
-80
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,6 +1202,8 @@ def isin(
12021202
test_elements,
12031203
assume_unique=False, # pylint: disable=unused-argument
12041204
invert=False,
1205+
*,
1206+
kind=None, # pylint: disable=unused-argument
12051207
):
12061208
"""
12071209
Calculates ``element in test_elements``, broadcasting over `element` only.
@@ -1221,22 +1223,27 @@ def isin(
12211223
assume_unique : bool, optional
12221224
Ignored, as no performance benefit is gained by assuming the
12231225
input arrays are unique. Included for compatibility with NumPy.
1226+
12241227
Default: ``False``.
12251228
invert : bool, optional
12261229
If ``True``, the values in the returned array are inverted, as if
1227-
calculating `element not in test_elements`.
1230+
calculating ``element not in test_elements``.
12281231
``dpnp.isin(a, b, invert=True)`` is equivalent to (but faster
12291232
than) ``dpnp.invert(dpnp.isin(a, b))``.
1233+
12301234
Default: ``False``.
1235+
kind : {None, "sort"}, optional
1236+
Ignored, as the only algorithm implemented is ``"sort"``. Included for
1237+
compatibility with NumPy.
12311238
1239+
Default: ``None``.
12321240
12331241
Returns
12341242
-------
12351243
isin : dpnp.ndarray of bool dtype
12361244
Has the same shape as `element`. The values `element[isin]`
12371245
are in `test_elements`.
12381246
1239-
12401247
Examples
12411248
--------
12421249
>>> import dpnp as np
@@ -1269,14 +1276,32 @@ def isin(
12691276
"""
12701277

12711278
dpnp.check_supported_arrays_type(element, test_elements, scalar_type=True)
1272-
usm_element = dpnp.as_usm_ndarray(
1273-
element, usm_type=element.usm_type, sycl_queue=element.sycl_queue
1274-
)
1275-
usm_test = dpnp.as_usm_ndarray(
1276-
test_elements,
1277-
usm_type=test_elements.usm_type,
1278-
sycl_queue=test_elements.sycl_queue,
1279-
)
1279+
if dpnp.isscalar(element):
1280+
usm_element = dpnp.as_usm_ndarray(
1281+
element,
1282+
usm_type=test_elements.usm_type,
1283+
sycl_queue=test_elements.sycl_queue,
1284+
)
1285+
usm_test = dpnp.get_usm_ndarray(test_elements)
1286+
elif dpnp.isscalar(test_elements):
1287+
usm_test = dpnp.as_usm_ndarray(
1288+
test_elements,
1289+
usm_type=element.usm_type,
1290+
sycl_queue=element.sycl_queue,
1291+
)
1292+
usm_element = dpnp.get_usm_ndarray(element)
1293+
else:
1294+
if (
1295+
dpu.get_execution_queue(
1296+
(element.sycl_queue, test_elements.sycl_queue)
1297+
)
1298+
is None
1299+
):
1300+
raise dpu.ExecutionPlacementError(
1301+
"Input arrays have incompatible allocation queues"
1302+
)
1303+
usm_element = dpnp.get_usm_ndarray(element)
1304+
usm_test = dpnp.get_usm_ndarray(test_elements)
12801305
return dpnp.get_result_array(
12811306
dpt.isin(
12821307
usm_element,

dpnp/tests/test_logic.py

Lines changed: 73 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -797,98 +797,101 @@ def test_array_equal_nan(a):
797797
assert_equal(result, expected)
798798

799799

800-
@pytest.mark.parametrize(
801-
"a",
802-
[
803-
numpy.array([1, 2, 3, 4]),
804-
numpy.array([[1, 2], [3, 4]]),
805-
],
806-
)
807-
@pytest.mark.parametrize(
808-
"b",
809-
[
810-
numpy.array([2, 4, 6]),
811-
numpy.array([[1, 3], [5, 7]]),
812-
],
813-
)
814-
def test_isin_basic(a, b):
815-
dp_a = dpnp.array(a)
816-
dp_b = dpnp.array(b)
800+
class TestIsin:
801+
@pytest.mark.parametrize(
802+
"a",
803+
[
804+
numpy.array([1, 2, 3, 4]),
805+
numpy.array([[1, 2], [3, 4]]),
806+
],
807+
)
808+
@pytest.mark.parametrize(
809+
"b",
810+
[
811+
numpy.array([2, 4, 6]),
812+
numpy.array([[1, 3], [5, 7]]),
813+
],
814+
)
815+
def test_isin_basic(a, b):
816+
dp_a = dpnp.array(a)
817+
dp_b = dpnp.array(b)
817818

818-
expected = numpy.isin(a, b)
819-
result = dpnp.isin(dp_a, dp_b)
820-
assert_equal(result, expected)
819+
expected = numpy.isin(a, b)
820+
result = dpnp.isin(dp_a, dp_b)
821+
assert_equal(result, expected)
821822

822823

823-
@pytest.mark.parametrize("dtype", get_all_dtypes())
824-
def test_isin_dtype(dtype):
825-
a = numpy.array([1, 2, 3, 4], dtype=dtype)
826-
b = numpy.array([2, 4], dtype=dtype)
824+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
825+
def test_isin_dtype(dtype):
826+
a = numpy.array([1, 2, 3, 4], dtype=dtype)
827+
b = numpy.array([2, 4], dtype=dtype)
827828

828-
dp_a = dpnp.array(a, dtype=dtype)
829-
dp_b = dpnp.array(b, dtype=dtype)
829+
dp_a = dpnp.array(a, dtype=dtype)
830+
dp_b = dpnp.array(b, dtype=dtype)
830831

831-
expected = numpy.isin(a, b)
832-
result = dpnp.isin(dp_a, dp_b)
833-
assert_equal(result, expected)
832+
expected = numpy.isin(a, b)
833+
result = dpnp.isin(dp_a, dp_b)
834+
assert_equal(result, expected)
834835

835836

836-
@pytest.mark.parametrize("sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))])
837-
def test_isin_broadcast(sh_a, sh_b):
838-
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
839-
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
837+
@pytest.mark.parametrize(
838+
"sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))]
839+
)
840+
def test_isin_broadcast(sh_a, sh_b):
841+
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
842+
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
840843

841-
dp_a = dpnp.array(a)
842-
dp_b = dpnp.array(b)
844+
dp_a = dpnp.array(a)
845+
dp_b = dpnp.array(b)
843846

844-
expected = numpy.isin(a, b)
845-
result = dpnp.isin(dp_a, dp_b)
846-
assert_equal(result, expected)
847+
expected = numpy.isin(a, b)
848+
result = dpnp.isin(dp_a, dp_b)
849+
assert_equal(result, expected)
847850

848851

849-
def test_isin_scalar_elements():
850-
a = numpy.array([1, 2, 3])
851-
b = 2
852+
def test_isin_scalar_elements():
853+
a = numpy.array([1, 2, 3])
854+
b = 2
852855

853-
dp_a = dpnp.array(a)
854-
dp_b = dpnp.array(b)
856+
dp_a = dpnp.array(a)
857+
dp_b = dpnp.array(b)
855858

856-
expected = numpy.isin(a, b)
857-
result = dpnp.isin(dp_a, dp_b)
858-
assert_equal(result, expected)
859+
expected = numpy.isin(a, b)
860+
result = dpnp.isin(dp_a, dp_b)
861+
assert_equal(result, expected)
859862

860863

861-
def test_isin_scalar_test_elements():
862-
a = 2
863-
b = numpy.array([1, 2, 3])
864+
def test_isin_scalar_test_elements():
865+
a = 2
866+
b = numpy.array([1, 2, 3])
864867

865-
dp_a = dpnp.array(a)
866-
dp_b = dpnp.array(b)
868+
dp_a = dpnp.array(a)
869+
dp_b = dpnp.array(b)
867870

868-
expected = numpy.isin(a, b)
869-
result = dpnp.isin(dp_a, dp_b)
870-
assert_equal(result, expected)
871+
expected = numpy.isin(a, b)
872+
result = dpnp.isin(dp_a, dp_b)
873+
assert_equal(result, expected)
871874

872875

873-
def test_isin_empty():
874-
a = numpy.array([], dtype=int)
875-
b = numpy.array([1, 2, 3])
876+
def test_isin_empty():
877+
a = numpy.array([], dtype=int)
878+
b = numpy.array([1, 2, 3])
876879

877-
dp_a = dpnp.array(a)
878-
dp_b = dpnp.array(b)
880+
dp_a = dpnp.array(a)
881+
dp_b = dpnp.array(b)
879882

880-
expected = numpy.isin(a, b)
881-
result = dpnp.isin(dp_a, dp_b)
882-
assert_equal(result, expected)
883+
expected = numpy.isin(a, b)
884+
result = dpnp.isin(dp_a, dp_b)
885+
assert_equal(result, expected)
883886

884887

885-
def test_isin_errors():
886-
a = dpnp.arange(5)
887-
b = dpnp.arange(3)
888+
def test_isin_errors():
889+
a = dpnp.arange(5)
890+
b = dpnp.arange(3)
888891

889-
# unsupported type for elements or test_elements
890-
with pytest.raises(TypeError):
891-
dpnp.isin(dict(), b)
892+
# unsupported type for elements or test_elements
893+
with pytest.raises(TypeError):
894+
dpnp.isin(dict(), b)
892895

893-
with pytest.raises(TypeError):
894-
dpnp.isin(a, dict())
896+
with pytest.raises(TypeError):
897+
dpnp.isin(a, dict())

0 commit comments

Comments
 (0)