Skip to content

Commit 661eff7

Browse files
committed
address review comments
1 parent 42e9893 commit 661eff7

File tree

2 files changed

+108
-86
lines changed

2 files changed

+108
-86
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,6 +1202,8 @@ def isin(
12021202
test_elements,
12031203
assume_unique=False, # pylint: disable=unused-argument
12041204
invert=False,
1205+
*,
1206+
kind=None, # pylint: disable=unused-argument
12051207
):
12061208
"""
12071209
Calculates ``element in test_elements``, broadcasting over `element` only.
@@ -1221,22 +1223,27 @@ def isin(
12211223
assume_unique : bool, optional
12221224
Ignored, as no performance benefit is gained by assuming the
12231225
input arrays are unique. Included for compatibility with NumPy.
1226+
12241227
Default: ``False``.
12251228
invert : bool, optional
12261229
If ``True``, the values in the returned array are inverted, as if
1227-
calculating `element not in test_elements`.
1230+
calculating ``element not in test_elements``.
12281231
``dpnp.isin(a, b, invert=True)`` is equivalent to (but faster
12291232
than) ``dpnp.invert(dpnp.isin(a, b))``.
1233+
12301234
Default: ``False``.
1235+
kind : {None, "sort"}, optional
1236+
Ignored, as the only algorithm implemented is ``"sort"``. Included for
1237+
compatibility with NumPy.
12311238
1239+
Default: ``None``.
12321240
12331241
Returns
12341242
-------
12351243
isin : dpnp.ndarray of bool dtype
12361244
Has the same shape as `element`. The values `element[isin]`
12371245
are in `test_elements`.
12381246
1239-
12401247
Examples
12411248
--------
12421249
>>> import dpnp as np
@@ -1269,14 +1276,32 @@ def isin(
12691276
"""
12701277

12711278
dpnp.check_supported_arrays_type(element, test_elements, scalar_type=True)
1272-
usm_element = dpnp.as_usm_ndarray(
1273-
element, usm_type=element.usm_type, sycl_queue=element.sycl_queue
1274-
)
1275-
usm_test = dpnp.as_usm_ndarray(
1276-
test_elements,
1277-
usm_type=test_elements.usm_type,
1278-
sycl_queue=test_elements.sycl_queue,
1279-
)
1279+
if dpnp.isscalar(element):
1280+
usm_element = dpnp.as_usm_ndarray(
1281+
element,
1282+
usm_type=test_elements.usm_type,
1283+
sycl_queue=test_elements.sycl_queue,
1284+
)
1285+
usm_test = dpnp.get_usm_ndarray(test_elements)
1286+
elif dpnp.isscalar(test_elements):
1287+
usm_test = dpnp.as_usm_ndarray(
1288+
test_elements,
1289+
usm_type=element.usm_type,
1290+
sycl_queue=element.sycl_queue,
1291+
)
1292+
usm_element = dpnp.get_usm_ndarray(element)
1293+
else:
1294+
if (
1295+
dpu.get_execution_queue(
1296+
(element.sycl_queue, test_elements.sycl_queue)
1297+
)
1298+
is None
1299+
):
1300+
raise dpu.ExecutionPlacementError(
1301+
"Input arrays have incompatible allocation queues"
1302+
)
1303+
usm_element = dpnp.get_usm_ndarray(element)
1304+
usm_test = dpnp.get_usm_ndarray(test_elements)
12801305
return dpnp.get_result_array(
12811306
dpt.isin(
12821307
usm_element,

dpnp/tests/test_logic.py

Lines changed: 73 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -797,98 +797,95 @@ def test_array_equal_nan(a):
797797
assert_equal(result, expected)
798798

799799

800-
@pytest.mark.parametrize(
801-
"a",
802-
[
803-
numpy.array([1, 2, 3, 4]),
804-
numpy.array([[1, 2], [3, 4]]),
805-
],
806-
)
807-
@pytest.mark.parametrize(
808-
"b",
809-
[
810-
numpy.array([2, 4, 6]),
811-
numpy.array([[1, 3], [5, 7]]),
812-
],
813-
)
814-
def test_isin_basic(a, b):
815-
dp_a = dpnp.array(a)
816-
dp_b = dpnp.array(b)
817-
818-
expected = numpy.isin(a, b)
819-
result = dpnp.isin(dp_a, dp_b)
820-
assert_equal(result, expected)
821-
822-
823-
@pytest.mark.parametrize("dtype", get_all_dtypes())
824-
def test_isin_dtype(dtype):
825-
a = numpy.array([1, 2, 3, 4], dtype=dtype)
826-
b = numpy.array([2, 4], dtype=dtype)
827-
828-
dp_a = dpnp.array(a, dtype=dtype)
829-
dp_b = dpnp.array(b, dtype=dtype)
830-
831-
expected = numpy.isin(a, b)
832-
result = dpnp.isin(dp_a, dp_b)
833-
assert_equal(result, expected)
834-
800+
class TestIsin:
801+
@pytest.mark.parametrize(
802+
"a",
803+
[
804+
numpy.array([1, 2, 3, 4]),
805+
numpy.array([[1, 2], [3, 4]]),
806+
],
807+
)
808+
@pytest.mark.parametrize(
809+
"b",
810+
[
811+
numpy.array([2, 4, 6]),
812+
numpy.array([[1, 3], [5, 7]]),
813+
],
814+
)
815+
def test_isin_basic(a, b):
816+
dp_a = dpnp.array(a)
817+
dp_b = dpnp.array(b)
835818

836-
@pytest.mark.parametrize("sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))])
837-
def test_isin_broadcast(sh_a, sh_b):
838-
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
839-
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
819+
expected = numpy.isin(a, b)
820+
result = dpnp.isin(dp_a, dp_b)
821+
assert_equal(result, expected)
840822

841-
dp_a = dpnp.array(a)
842-
dp_b = dpnp.array(b)
823+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
824+
def test_isin_dtype(dtype):
825+
a = numpy.array([1, 2, 3, 4], dtype=dtype)
826+
b = numpy.array([2, 4], dtype=dtype)
843827

844-
expected = numpy.isin(a, b)
845-
result = dpnp.isin(dp_a, dp_b)
846-
assert_equal(result, expected)
828+
dp_a = dpnp.array(a, dtype=dtype)
829+
dp_b = dpnp.array(b, dtype=dtype)
847830

831+
expected = numpy.isin(a, b)
832+
result = dpnp.isin(dp_a, dp_b)
833+
assert_equal(result, expected)
848834

849-
def test_isin_scalar_elements():
850-
a = numpy.array([1, 2, 3])
851-
b = 2
835+
@pytest.mark.parametrize(
836+
"sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))]
837+
)
838+
def test_isin_broadcast(sh_a, sh_b):
839+
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
840+
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
852841

853-
dp_a = dpnp.array(a)
854-
dp_b = dpnp.array(b)
842+
dp_a = dpnp.array(a)
843+
dp_b = dpnp.array(b)
855844

856-
expected = numpy.isin(a, b)
857-
result = dpnp.isin(dp_a, dp_b)
858-
assert_equal(result, expected)
845+
expected = numpy.isin(a, b)
846+
result = dpnp.isin(dp_a, dp_b)
847+
assert_equal(result, expected)
859848

849+
def test_isin_scalar_elements():
850+
a = numpy.array([1, 2, 3])
851+
b = 2
860852

861-
def test_isin_scalar_test_elements():
862-
a = 2
863-
b = numpy.array([1, 2, 3])
853+
dp_a = dpnp.array(a)
854+
dp_b = dpnp.array(b)
864855

865-
dp_a = dpnp.array(a)
866-
dp_b = dpnp.array(b)
856+
expected = numpy.isin(a, b)
857+
result = dpnp.isin(dp_a, dp_b)
858+
assert_equal(result, expected)
867859

868-
expected = numpy.isin(a, b)
869-
result = dpnp.isin(dp_a, dp_b)
870-
assert_equal(result, expected)
860+
def test_isin_scalar_test_elements():
861+
a = 2
862+
b = numpy.array([1, 2, 3])
871863

864+
dp_a = dpnp.array(a)
865+
dp_b = dpnp.array(b)
872866

873-
def test_isin_empty():
874-
a = numpy.array([], dtype=int)
875-
b = numpy.array([1, 2, 3])
867+
expected = numpy.isin(a, b)
868+
result = dpnp.isin(dp_a, dp_b)
869+
assert_equal(result, expected)
876870

877-
dp_a = dpnp.array(a)
878-
dp_b = dpnp.array(b)
871+
def test_isin_empty():
872+
a = numpy.array([], dtype=int)
873+
b = numpy.array([1, 2, 3])
879874

880-
expected = numpy.isin(a, b)
881-
result = dpnp.isin(dp_a, dp_b)
882-
assert_equal(result, expected)
875+
dp_a = dpnp.array(a)
876+
dp_b = dpnp.array(b)
883877

878+
expected = numpy.isin(a, b)
879+
result = dpnp.isin(dp_a, dp_b)
880+
assert_equal(result, expected)
884881

885-
def test_isin_errors():
886-
a = dpnp.arange(5)
887-
b = dpnp.arange(3)
882+
def test_isin_errors():
883+
a = dpnp.arange(5)
884+
b = dpnp.arange(3)
888885

889-
# unsupported type for elements or test_elements
890-
with pytest.raises(TypeError):
891-
dpnp.isin(dict(), b)
886+
# unsupported type for elements or test_elements
887+
with pytest.raises(TypeError):
888+
dpnp.isin(dict(), b)
892889

893-
with pytest.raises(TypeError):
894-
dpnp.isin(a, dict())
890+
with pytest.raises(TypeError):
891+
dpnp.isin(a, dict())

0 commit comments

Comments
 (0)