|
26 | 26 | from __future__ import annotations
|
27 | 27 |
|
28 | 28 | import types
|
29 |
| -from collections import namedtuple |
30 |
| -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union |
| 29 | +from enum import Enum as EnumBase |
| 30 | +from typing import Any, Self, TypeVar, Union |
| 31 | + |
| 32 | +E = TypeVar("E", bound="Enum") |
31 | 33 |
|
32 | 34 | __all__ = (
|
33 | 35 | "Enum",
|
|
83 | 85 | )
|
84 | 86 |
|
85 | 87 |
|
86 |
| -def _create_value_cls(name, comparable): |
87 |
| - cls = namedtuple(f"_EnumValue_{name}", "name value") |
88 |
| - cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>" |
89 |
| - cls.__str__ = lambda self: f"{name}.{self.name}" |
90 |
| - if comparable: |
91 |
| - cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value |
92 |
| - cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value |
93 |
| - cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value |
94 |
| - cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value |
95 |
| - return cls |
96 |
| - |
97 |
| - |
98 |
| -def _is_descriptor(obj): |
99 |
| - return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") |
100 |
| - |
101 |
| - |
102 |
| -class EnumMeta(type): |
103 |
| - if TYPE_CHECKING: |
104 |
| - __name__: ClassVar[str] |
105 |
| - _enum_member_names_: ClassVar[list[str]] |
106 |
| - _enum_member_map_: ClassVar[dict[str, Any]] |
107 |
| - _enum_value_map_: ClassVar[dict[Any, Any]] |
108 |
| - |
109 |
| - def __new__(cls, name, bases, attrs, *, comparable: bool = False): |
110 |
| - value_mapping = {} |
111 |
| - member_mapping = {} |
112 |
| - member_names = [] |
113 |
| - |
114 |
| - value_cls = _create_value_cls(name, comparable) |
115 |
| - for key, value in list(attrs.items()): |
116 |
| - is_descriptor = _is_descriptor(value) |
117 |
| - if key[0] == "_" and not is_descriptor: |
118 |
| - continue |
119 |
| - |
120 |
| - # Special case classmethod to just pass through |
121 |
| - if isinstance(value, classmethod): |
122 |
| - continue |
123 |
| - |
124 |
| - if is_descriptor: |
125 |
| - setattr(value_cls, key, value) |
126 |
| - del attrs[key] |
127 |
| - continue |
128 |
| - |
129 |
| - try: |
130 |
| - new_value = value_mapping[value] |
131 |
| - except KeyError: |
132 |
| - new_value = value_cls(name=key, value=value) |
133 |
| - value_mapping[value] = new_value |
134 |
| - member_names.append(key) |
135 |
| - |
136 |
| - member_mapping[key] = new_value |
137 |
| - attrs[key] = new_value |
138 |
| - |
139 |
| - attrs["_enum_value_map_"] = value_mapping |
140 |
| - attrs["_enum_member_map_"] = member_mapping |
141 |
| - attrs["_enum_member_names_"] = member_names |
142 |
| - attrs["_enum_value_cls_"] = value_cls |
143 |
| - actual_cls = super().__new__(cls, name, bases, attrs) |
144 |
| - value_cls._actual_enum_cls_ = actual_cls # type: ignore |
145 |
| - return actual_cls |
146 |
| - |
147 |
| - def __iter__(cls): |
148 |
| - return (cls._enum_member_map_[name] for name in cls._enum_member_names_) |
149 |
| - |
150 |
| - def __reversed__(cls): |
151 |
| - return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) |
152 |
| - |
153 |
| - def __len__(cls): |
154 |
| - return len(cls._enum_member_names_) |
155 |
| - |
156 |
| - def __repr__(cls): |
157 |
| - return f"<enum {cls.__name__}>" |
158 |
| - |
159 |
| - @property |
160 |
| - def __members__(cls): |
161 |
| - return types.MappingProxyType(cls._enum_member_map_) |
162 |
| - |
163 |
| - def __call__(cls, value): |
164 |
| - try: |
165 |
| - return cls._enum_value_map_[value] |
166 |
| - except (KeyError, TypeError) as e: |
167 |
| - raise ValueError(f"{value!r} is not a valid {cls.__name__}") from e |
168 |
| - |
169 |
| - def __getitem__(cls, key): |
170 |
| - return cls._enum_member_map_[key] |
171 |
| - |
172 |
| - def __setattr__(cls, name, value): |
173 |
| - raise TypeError("Enums are immutable.") |
| 88 | +class Enum(EnumBase): |
| 89 | + """An :class:`enum.Enum` subclass that implements a missing value creation behavior if it is |
| 90 | + not present in any of the members of it. |
| 91 | + """ |
174 | 92 |
|
175 |
| - def __delattr__(cls, attr): |
176 |
| - raise TypeError("Enums are immutable") |
| 93 | + def __init_subclass__(cls, *, comparable: bool = False) -> None: |
| 94 | + super().__init_subclass__() |
177 | 95 |
|
178 |
| - def __instancecheck__(self, instance): |
179 |
| - # isinstance(x, Y) |
180 |
| - # -> __instancecheck__(Y, x) |
181 |
| - try: |
182 |
| - return instance._actual_enum_cls_ is self |
183 |
| - except AttributeError: |
184 |
| - return False |
| 96 | + if comparable is True: |
| 97 | + cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value |
| 98 | + cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value |
| 99 | + cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value |
| 100 | + cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value |
185 | 101 |
|
| 102 | + @classmethod |
| 103 | + def _missing_(cls, value: Any) -> Self: |
| 104 | + name = f"unknown_{value}" |
| 105 | + if name in cls.__members__: |
| 106 | + return cls.__members__[name] |
186 | 107 |
|
187 |
| -if TYPE_CHECKING: |
188 |
| - from enum import Enum |
189 |
| -else: |
| 108 | + # this creates the new unknown value member |
| 109 | + obj = object.__new__(cls) |
| 110 | + obj._name_ = name |
| 111 | + obj._value_ = value |
190 | 112 |
|
191 |
| - class Enum(metaclass=EnumMeta): |
192 |
| - @classmethod |
193 |
| - def try_value(cls, value): |
194 |
| - try: |
195 |
| - return cls._enum_value_map_[value] |
196 |
| - except (KeyError, TypeError): |
197 |
| - return value |
| 113 | + # and adds it to the member mapping of this enum so we don't |
| 114 | + # create a different enum member value each time |
| 115 | + cls._member_map_[name] = obj |
| 116 | + cls._value2member_map_[value] = obj |
| 117 | + return obj |
198 | 118 |
|
199 | 119 |
|
200 | 120 | class ChannelType(Enum):
|
@@ -1078,22 +998,9 @@ def __int__(self):
|
1078 | 998 | return self.value
|
1079 | 999 |
|
1080 | 1000 |
|
1081 |
| -T = TypeVar("T") |
1082 |
| - |
1083 |
| - |
1084 |
| -def create_unknown_value(cls: type[T], val: Any) -> T: |
1085 |
| - value_cls = cls._enum_value_cls_ # type: ignore |
1086 |
| - name = f"unknown_{val}" |
1087 |
| - return value_cls(name=name, value=val) |
1088 |
| - |
1089 |
| - |
1090 |
| -def try_enum(cls: type[T], val: Any) -> T: |
| 1001 | +def try_enum(cls: type[E], val: Any) -> E: |
1091 | 1002 | """A function that tries to turn the value into enum ``cls``.
|
1092 | 1003 |
|
1093 | 1004 | If it fails it returns a proxy invalid value instead.
|
1094 | 1005 | """
|
1095 |
| - |
1096 |
| - try: |
1097 |
| - return cls._enum_value_map_[val] # type: ignore |
1098 |
| - except (KeyError, TypeError, AttributeError): |
1099 |
| - return create_unknown_value(cls, val) |
| 1006 | + return cls(val) |
0 commit comments