|
26 | 26 | from __future__ import annotations |
27 | 27 |
|
28 | 28 | import types |
29 | | -from collections import namedtuple |
30 | | -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union |
| 29 | +from enum import Enum as EnumBase |
| 30 | +from typing import Any, Self, TypeVar, Union |
| 31 | + |
| 32 | +E = TypeVar("E", bound="Enum") |
31 | 33 |
|
32 | 34 | __all__ = ( |
33 | 35 | "Enum", |
|
83 | 85 | ) |
84 | 86 |
|
85 | 87 |
|
86 | | -def _create_value_cls(name, comparable): |
87 | | - cls = namedtuple(f"_EnumValue_{name}", "name value") |
88 | | - cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>" |
89 | | - cls.__str__ = lambda self: f"{name}.{self.name}" |
90 | | - if comparable: |
91 | | - cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value |
92 | | - cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value |
93 | | - cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value |
94 | | - cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value |
95 | | - return cls |
96 | | - |
97 | | - |
98 | | -def _is_descriptor(obj): |
99 | | - return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") |
100 | | - |
101 | | - |
102 | | -class EnumMeta(type): |
103 | | - if TYPE_CHECKING: |
104 | | - __name__: ClassVar[str] |
105 | | - _enum_member_names_: ClassVar[list[str]] |
106 | | - _enum_member_map_: ClassVar[dict[str, Any]] |
107 | | - _enum_value_map_: ClassVar[dict[Any, Any]] |
108 | | - |
109 | | - def __new__(cls, name, bases, attrs, *, comparable: bool = False): |
110 | | - value_mapping = {} |
111 | | - member_mapping = {} |
112 | | - member_names = [] |
113 | | - |
114 | | - value_cls = _create_value_cls(name, comparable) |
115 | | - for key, value in list(attrs.items()): |
116 | | - is_descriptor = _is_descriptor(value) |
117 | | - if key[0] == "_" and not is_descriptor: |
118 | | - continue |
119 | | - |
120 | | - # Special case classmethod to just pass through |
121 | | - if isinstance(value, classmethod): |
122 | | - continue |
123 | | - |
124 | | - if is_descriptor: |
125 | | - setattr(value_cls, key, value) |
126 | | - del attrs[key] |
127 | | - continue |
128 | | - |
129 | | - try: |
130 | | - new_value = value_mapping[value] |
131 | | - except KeyError: |
132 | | - new_value = value_cls(name=key, value=value) |
133 | | - value_mapping[value] = new_value |
134 | | - member_names.append(key) |
135 | | - |
136 | | - member_mapping[key] = new_value |
137 | | - attrs[key] = new_value |
138 | | - |
139 | | - attrs["_enum_value_map_"] = value_mapping |
140 | | - attrs["_enum_member_map_"] = member_mapping |
141 | | - attrs["_enum_member_names_"] = member_names |
142 | | - attrs["_enum_value_cls_"] = value_cls |
143 | | - actual_cls = super().__new__(cls, name, bases, attrs) |
144 | | - value_cls._actual_enum_cls_ = actual_cls # type: ignore |
145 | | - return actual_cls |
146 | | - |
147 | | - def __iter__(cls): |
148 | | - return (cls._enum_member_map_[name] for name in cls._enum_member_names_) |
149 | | - |
150 | | - def __reversed__(cls): |
151 | | - return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) |
152 | | - |
153 | | - def __len__(cls): |
154 | | - return len(cls._enum_member_names_) |
| 88 | +class Enum(EnumBase): |
| 89 | + """An :class:`enum.Enum` subclass that implements a missing value creation behavior if it is |
| 90 | + not present in any of the members of it. |
| 91 | + """ |
155 | 92 |
|
156 | | - def __repr__(cls): |
157 | | - return f"<enum {cls.__name__}>" |
| 93 | + def __init_subclass__(cls, *, comparable: bool = False) -> None: |
| 94 | + super().__init_subclass__() |
158 | 95 |
|
159 | | - @property |
160 | | - def __members__(cls): |
161 | | - return types.MappingProxyType(cls._enum_member_map_) |
| 96 | + if comparable: |
162 | 97 |
|
163 | | - def __call__(cls, value): |
164 | | - try: |
165 | | - return cls._enum_value_map_[value] |
166 | | - except (KeyError, TypeError) as e: |
167 | | - raise ValueError(f"{value!r} is not a valid {cls.__name__}") from e |
| 98 | + def __lt__(self: Enum, other: object) -> bool: |
| 99 | + if not isinstance(other, cls): |
| 100 | + return NotImplemented |
| 101 | + return self.value < other.value |
168 | 102 |
|
169 | | - def __getitem__(cls, key): |
170 | | - return cls._enum_member_map_[key] |
| 103 | + def __gt__(self: Enum, other: object) -> bool: |
| 104 | + if not isinstance(other, cls): |
| 105 | + return NotImplemented |
| 106 | + return self.value > other.value |
171 | 107 |
|
172 | | - def __setattr__(cls, name, value): |
173 | | - raise TypeError("Enums are immutable.") |
| 108 | + def __le__(self: Enum, other: object) -> bool: |
| 109 | + if not isinstance(other, cls): |
| 110 | + return NotImplemented |
| 111 | + return self.value <= other.value |
174 | 112 |
|
175 | | - def __delattr__(cls, attr): |
176 | | - raise TypeError("Enums are immutable") |
| 113 | + def __ge__(self: Enum, other: object) -> bool: |
| 114 | + if not isinstance(other, cls): |
| 115 | + return NotImplemented |
| 116 | + return self.value >= other.value |
177 | 117 |
|
178 | | - def __instancecheck__(self, instance): |
179 | | - # isinstance(x, Y) |
180 | | - # -> __instancecheck__(Y, x) |
181 | | - try: |
182 | | - return instance._actual_enum_cls_ is self |
183 | | - except AttributeError: |
184 | | - return False |
| 118 | + cls.__lt__ = __lt__ |
| 119 | + cls.__gt__ = __gt__ |
| 120 | + cls.__le__ = __le__ |
| 121 | + cls.__ge__ = __ge__ |
185 | 122 |
|
| 123 | + @classmethod |
| 124 | + def _missing_(cls, value: Any) -> Self: |
| 125 | + name = f"unknown_{value}" |
| 126 | + if name in cls.__members__: |
| 127 | + return cls.__members__[name] |
186 | 128 |
|
187 | | -if TYPE_CHECKING: |
188 | | - from enum import Enum |
189 | | -else: |
| 129 | + # this creates the new unknown value member |
| 130 | + obj = object.__new__(cls) |
| 131 | + obj._name_ = name |
| 132 | + obj._value_ = value |
190 | 133 |
|
191 | | - class Enum(metaclass=EnumMeta): |
192 | | - @classmethod |
193 | | - def try_value(cls, value): |
194 | | - try: |
195 | | - return cls._enum_value_map_[value] |
196 | | - except (KeyError, TypeError): |
197 | | - return value |
| 134 | + # and adds it to the member mapping of this enum so we don't |
| 135 | + # create a different enum member value each time |
| 136 | + cls._member_map_[name] = obj |
| 137 | + cls._value2member_map_[value] = obj |
| 138 | + return obj |
198 | 139 |
|
199 | 140 |
|
200 | 141 | class ChannelType(Enum): |
@@ -1078,22 +1019,9 @@ def __int__(self): |
1078 | 1019 | return self.value |
1079 | 1020 |
|
1080 | 1021 |
|
1081 | | -T = TypeVar("T") |
1082 | | - |
1083 | | - |
1084 | | -def create_unknown_value(cls: type[T], val: Any) -> T: |
1085 | | - value_cls = cls._enum_value_cls_ # type: ignore |
1086 | | - name = f"unknown_{val}" |
1087 | | - return value_cls(name=name, value=val) |
1088 | | - |
1089 | | - |
1090 | | -def try_enum(cls: type[T], val: Any) -> T: |
| 1022 | +def try_enum(cls: type[E], val: Any) -> E: |
1091 | 1023 | """A function that tries to turn the value into enum ``cls``. |
1092 | 1024 |
|
1093 | 1025 | If it fails it returns a proxy invalid value instead. |
1094 | 1026 | """ |
1095 | | - |
1096 | | - try: |
1097 | | - return cls._enum_value_map_[val] # type: ignore |
1098 | | - except (KeyError, TypeError, AttributeError): |
1099 | | - return create_unknown_value(cls, val) |
| 1027 | + return cls(val) |
0 commit comments