Skip to content

Commit cb4e17d

Browse files
authored
ENH: add new function isin (data-apis#485)
1 parent 03d4d28 commit cb4e17d

File tree

4 files changed

+134
-1
lines changed

4 files changed

+134
-1
lines changed

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
cov,
77
expand_dims,
88
isclose,
9+
isin,
910
nan_to_num,
1011
one_hot,
1112
pad,
@@ -39,6 +40,7 @@
3940
"default_dtype",
4041
"expand_dims",
4142
"isclose",
43+
"isin",
4244
"kron",
4345
"lazy_apply",
4446
"nan_to_num",

src/array_api_extra/_delegation.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,3 +836,62 @@ def argpartition(
836836
# kth is not small compared to x.size
837837

838838
return _funcs.argpartition(a, kth, axis=axis, xp=xp)
839+
840+
841+
def isin(
842+
a: Array,
843+
b: Array,
844+
/,
845+
*,
846+
assume_unique: bool = False,
847+
invert: bool = False,
848+
kind: str | None = None,
849+
xp: ModuleType | None = None,
850+
) -> Array:
851+
"""
852+
Determine whether each element in `a` is present in `b`.
853+
854+
Return a boolean array of the same shape as `a` that is True for elements
855+
that are in `b` and False otherwise.
856+
857+
Parameters
858+
----------
859+
a : array
860+
Input elements.
861+
b : array
862+
The elements against which to test each element of `a`.
863+
assume_unique : bool, optional
864+
If True, the input arrays are both assumed to be unique which can speed
865+
up the calculation. Default: False.
866+
invert : bool, optional
867+
If True, the values in the returned array are inverted. Default: False.
868+
kind : str | None, optional
869+
The algorithm or method to use. This will not affect the final result,
870+
but will affect the speed and memory use.
871+
For NumPy the options are {None, "sort", "table"}.
872+
For Jax the mapped parameter is instead `method` and the options are
873+
{"compare_all", "binary_search", "sort", and "auto" (default)}
874+
For CuPy, Dask, Torch and the default case this parameter is not present and
875+
thus ignored. Default: None.
876+
xp : array_namespace, optional
877+
The standard-compatible namespace for `a` and `b`. Default: infer.
878+
879+
Returns
880+
-------
881+
array
882+
An array having the same shape as that of `a` that is True for elements
883+
that are in `b` and False otherwise.
884+
"""
885+
if xp is None:
886+
xp = array_namespace(a, b)
887+
888+
if is_numpy_namespace(xp):
889+
return xp.isin(a, b, assume_unique=assume_unique, invert=invert, kind=kind)
890+
if is_jax_namespace(xp):
891+
if kind is None:
892+
kind = "auto"
893+
return xp.isin(a, b, assume_unique=assume_unique, invert=invert, method=kind)
894+
if is_cupy_namespace(xp) or is_torch_namespace(xp) or is_dask_namespace(xp):
895+
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)
896+
897+
return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,3 +801,22 @@ def argpartition( # numpydoc ignore=PR01,RT01
801801
) -> Array:
802802
"""See docstring in `array_api_extra._delegation.py`."""
803803
return xp.argsort(x, axis=axis, stable=False)
804+
805+
806+
def isin( # numpydoc ignore=PR01,RT01
807+
a: Array,
808+
b: Array,
809+
/,
810+
*,
811+
assume_unique: bool = False,
812+
invert: bool = False,
813+
xp: ModuleType,
814+
) -> Array:
815+
"""See docstring in `array_api_extra._delegation.py`."""
816+
original_a_shape = a.shape
817+
a = xp.reshape(a, (-1,))
818+
b = xp.reshape(b, (-1,))
819+
return xp.reshape(
820+
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
821+
original_a_shape,
822+
)

tests/test_funcs.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
default_dtype,
2323
expand_dims,
2424
isclose,
25+
isin,
2526
kron,
2627
nan_to_num,
2728
nunique,
@@ -888,7 +889,7 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
888889
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
889890
res = isclose(a, b, equal_nan=equal_nan)
890891
assert get_device(res) == device
891-
892+
892893
def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device):
893894
a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device)
894895
b = 1
@@ -1476,3 +1477,55 @@ def test_nd(self, xp: ModuleType, ndim: int):
14761477
@override
14771478
def test_input_validation(self, xp: ModuleType):
14781479
self._test_input_validation(xp)
1480+
1481+
1482+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse")
1483+
class TestIsIn:
1484+
def test_simple(self, xp: ModuleType, library: Backend):
1485+
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
1486+
pytest.xfail("NumPy <1.24 has no kind kwarg in isin")
1487+
1488+
b = xp.asarray([1, 2, 3, 4])
1489+
1490+
# `a` with 1 dimension
1491+
a = xp.asarray([1, 3, 6, 10])
1492+
expected = xp.asarray([True, True, False, False])
1493+
res = isin(a, b)
1494+
xp_assert_equal(res, expected)
1495+
1496+
# `a` with 2 dimensions
1497+
a = xp.asarray([[0, 2], [4, 6]])
1498+
expected = xp.asarray([[False, True], [True, False]])
1499+
res = isin(a, b)
1500+
xp_assert_equal(res, expected)
1501+
1502+
def test_device(self, xp: ModuleType, device: Device, library: Backend):
1503+
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
1504+
pytest.xfail("NumPy <1.24 has no kind kwarg in isin")
1505+
1506+
a = xp.asarray([1, 3, 6], device=device)
1507+
b = xp.asarray([1, 2, 3], device=device)
1508+
assert get_device(isin(a, b)) == device
1509+
1510+
def test_assume_unique_and_invert(
1511+
self, xp: ModuleType, device: Device, library: Backend
1512+
):
1513+
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
1514+
pytest.xfail("NumPy <1.24 has no kind kwarg in isin")
1515+
1516+
a = xp.asarray([0, 3, 6, 10], device=device)
1517+
b = xp.asarray([1, 2, 3, 10], device=device)
1518+
expected = xp.asarray([True, False, True, False])
1519+
res = isin(a, b, assume_unique=True, invert=True)
1520+
assert get_device(res) == device
1521+
xp_assert_equal(res, expected)
1522+
1523+
def test_kind(self, xp: ModuleType, library: Backend):
1524+
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
1525+
pytest.xfail("NumPy <1.24 has no kind kwarg in isin")
1526+
1527+
a = xp.asarray([0, 3, 6, 10])
1528+
b = xp.asarray([1, 2, 3, 10])
1529+
expected = xp.asarray([False, True, False, True])
1530+
res = isin(a, b, kind="sort")
1531+
xp_assert_equal(res, expected)

0 commit comments

Comments
 (0)