|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | from collections import OrderedDict |
2 | | -from collections.abc import Hashable, Iterable, MutableSet |
| 4 | +from collections.abc import Hashable, Iterable, Iterator, MutableSet |
3 | 5 | from typing import TypeVar |
4 | 6 |
|
5 | | -_KeyType = TypeVar("_KeyType", bound=Hashable) |
| 7 | +_T = TypeVar("_T", bound=Hashable) |
6 | 8 |
|
7 | 9 |
|
8 | | -class OrderedSet(OrderedDict[_KeyType, None], MutableSet[_KeyType]): |
| 10 | +class OrderedSet(MutableSet[_T]): |
9 | 11 | """Ordered collection of distinct elements.""" |
10 | 12 |
|
11 | | - def __init__(self, elements: Iterable[_KeyType]): |
12 | | - super().__init__([(element, None) for element in elements]) |
| 13 | + def __init__(self, elements: Iterable[_T]): |
| 14 | + super().__init__() |
| 15 | + self.ordered_dict = OrderedDict[_T, None]([(element, None) for element in elements]) |
13 | 16 |
|
14 | | - def difference_update(self, elements: set[_KeyType]) -> None: |
| 17 | + def difference_update(self, elements: set[_T]) -> None: |
15 | 18 | """Removes all specified elements from the OrderedSet.""" |
16 | 19 |
|
17 | 20 | for element in elements: |
18 | 21 | self.discard(element) |
19 | 22 |
|
20 | | - def add(self, element: _KeyType) -> None: |
| 23 | + def add(self, element: _T) -> None: |
21 | 24 | """Adds the specified element to the OrderedSet.""" |
22 | 25 |
|
23 | | - self[element] = None |
| 26 | + self.ordered_dict[element] = None |
24 | 27 |
|
25 | | - def __add__(self, other: "OrderedSet[_KeyType]") -> "OrderedSet[_KeyType]": |
| 28 | + def __add__(self, other: OrderedSet[_T]) -> OrderedSet[_T]: |
26 | 29 | """Creates a new OrderedSet with the elements of self followed by the elements of other.""" |
27 | 30 |
|
28 | 31 | return OrderedSet([*self, *other]) |
29 | 32 |
|
30 | | - def discard(self, value: _KeyType) -> None: |
| 33 | + def discard(self, value: _T) -> None: |
31 | 34 | if value in self: |
32 | | - del self[value] |
| 35 | + del self.ordered_dict[value] |
| 36 | + |
| 37 | + def __iter__(self) -> Iterator[_T]: |
| 38 | + return self.ordered_dict.__iter__() |
| 39 | + |
| 40 | + def __len__(self) -> int: |
| 41 | + return len(self.ordered_dict) |
| 42 | + |
| 43 | + def __contains__(self, element: object) -> bool: |
| 44 | + return element in self.ordered_dict |
0 commit comments