Skip to content

Commit 7cb402a

Browse files
committed
Add __contains__() method
1 parent 81870b7 commit 7cb402a

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

dpnp/dpnp_array.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ def __bool__(self):
242242
def __complex__(self):
243243
return self._array_obj.__complex__()
244244

245-
# '__contains__',
245+
def __contains__(self, key, /):
246+
"""Return :math:`key in self`."""
247+
return (self == key).any()
246248

247249
def __copy__(self):
248250
"""

dpnp/tests/test_ndarray.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,39 @@ def test_attributes(self):
7474
assert_equal(self.two.itemsize, self.two.dtype.itemsize)
7575

7676

77+
@testing.parameterize(*testing.product({"xp": [dpnp, numpy]}))
78+
class TestContains:
79+
def test_basic(self):
80+
a = self.xp.arange(10).reshape((2, 5))
81+
assert 4 in a
82+
assert 20 not in a
83+
84+
def test_broadcast(self):
85+
xp = self.xp
86+
a = xp.arange(6).reshape((2, 3))
87+
assert 4 in a
88+
assert xp.array([0, 1, 2]) in a
89+
assert xp.array([5, 3, 4]) not in a
90+
91+
def test_broadcast_error(self):
92+
a = self.xp.arange(10).reshape((2, 5))
93+
with pytest.raises(
94+
ValueError,
95+
match="operands could not be broadcast together with shapes",
96+
):
97+
self.xp.array([1, 2]) in a
98+
99+
def test_strides(self):
100+
xp = self.xp
101+
a = xp.arange(10).reshape((2, 5))
102+
a = a[:, ::2]
103+
assert 4 in a
104+
assert 8 not in a
105+
assert xp.full(a.shape[-1], fill_value=2) in a
106+
assert xp.full_like(a, fill_value=7) in a
107+
assert xp.full_like(a, fill_value=6) not in a
108+
109+
77110
class TestView:
78111
def test_none_dtype(self):
79112
a = numpy.ones((1, 2, 4), dtype=numpy.int32)

0 commit comments

Comments
 (0)