Skip to content

Commit 6790111

Browse files
authored
Add __contains__ to MapKeys (#99)
I ran into this when I tried to pass `.keys()` to something declared to take a `Collection`, and it failed because it lacked `__contains__`. (I didn't actually care about `__contains__`, but mypy did.) So I came to add it to the stubs and discovered that it *was* actually missing, and since the fallback is woefully slow I went and implemented it.
1 parent f797822 commit 6790111

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

immutables/_map.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2810,9 +2810,23 @@ map_new_items_view(MapObject *o)
28102810
/////////////////////////////////// _MapKeys_Type
28112811

28122812

2813+
static int
2814+
map_tp_contains(BaseMapObject *self, PyObject *key);
2815+
2816+
static int
2817+
_map_keys_tp_contains(MapView *self, PyObject *key)
2818+
{
2819+
return map_tp_contains((BaseMapObject *)self->mv_obj, key);
2820+
}
2821+
2822+
static PySequenceMethods _MapKeys_as_sequence = {
2823+
.sq_contains = (objobjproc)_map_keys_tp_contains,
2824+
};
2825+
28132826
PyTypeObject _MapKeys_Type = {
28142827
PyVarObject_HEAD_INIT(NULL, 0)
28152828
"keys",
2829+
.tp_as_sequence = &_MapKeys_as_sequence,
28162830
VIEW_TYPE_SHARED_SLOTS
28172831
};
28182832

immutables/_protocols.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
class MapKeys(Protocol[KT_co]):
3333
def __len__(self) -> int: ...
3434
def __iter__(self) -> Iterator[KT_co]: ...
35+
def __contains__(self, __key: object) -> bool: ...
3536

3637

3738
class MapValues(Protocol[VT_co]):

tests/test_map.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,10 @@ def test_kwarg_named_col(self):
13951395
self.assertEqual(dict(self.Map(a=0, col=1)), {"a": 0, "col": 1})
13961396
self.assertEqual(dict(self.Map({"a": 0}, col=1)), {"a": 0, "col": 1})
13971397

1398+
def test_map_keys_contains(self):
1399+
m = self.Map(foo="bar")
1400+
self.assertTrue("foo" in m.keys())
1401+
13981402

13991403
class PyMapTest(BaseMapTest, unittest.TestCase):
14001404

0 commit comments

Comments
 (0)