|
26 | 26 | from __future__ import annotations |
27 | 27 |
|
28 | 28 | import types |
29 | | -from collections import namedtuple |
30 | | -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union |
| 29 | +from enum import Enum as EnumBase |
| 30 | +from typing import Any, Self, TypeVar, Union |
| 31 | + |
| 32 | +E = TypeVar("E", bound="Enum") |
31 | 33 |
|
32 | 34 | __all__ = ( |
33 | 35 | "Enum", |
|
84 | 86 | ) |
85 | 87 |
|
86 | 88 |
|
87 | | -def _create_value_cls(name, comparable): |
88 | | - cls = namedtuple(f"_EnumValue_{name}", "name value") |
89 | | - cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>" |
90 | | - cls.__str__ = lambda self: f"{name}.{self.name}" |
91 | | - if comparable: |
92 | | - cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value |
93 | | - cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value |
94 | | - cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value |
95 | | - cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value |
96 | | - return cls |
97 | | - |
98 | | - |
99 | | -def _is_descriptor(obj): |
100 | | - return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") |
101 | | - |
102 | | - |
103 | | -class EnumMeta(type): |
104 | | - if TYPE_CHECKING: |
105 | | - __name__: ClassVar[str] |
106 | | - _enum_member_names_: ClassVar[list[str]] |
107 | | - _enum_member_map_: ClassVar[dict[str, Any]] |
108 | | - _enum_value_map_: ClassVar[dict[Any, Any]] |
109 | | - |
110 | | - def __new__(cls, name, bases, attrs, *, comparable: bool = False): |
111 | | - value_mapping = {} |
112 | | - member_mapping = {} |
113 | | - member_names = [] |
114 | | - |
115 | | - value_cls = _create_value_cls(name, comparable) |
116 | | - for key, value in list(attrs.items()): |
117 | | - is_descriptor = _is_descriptor(value) |
118 | | - if key[0] == "_" and not is_descriptor: |
119 | | - continue |
120 | | - |
121 | | - # Special case classmethod to just pass through |
122 | | - if isinstance(value, classmethod): |
123 | | - continue |
124 | | - |
125 | | - if is_descriptor: |
126 | | - setattr(value_cls, key, value) |
127 | | - del attrs[key] |
128 | | - continue |
129 | | - |
130 | | - try: |
131 | | - new_value = value_mapping[value] |
132 | | - except KeyError: |
133 | | - new_value = value_cls(name=key, value=value) |
134 | | - value_mapping[value] = new_value |
135 | | - member_names.append(key) |
136 | | - |
137 | | - member_mapping[key] = new_value |
138 | | - attrs[key] = new_value |
139 | | - |
140 | | - attrs["_enum_value_map_"] = value_mapping |
141 | | - attrs["_enum_member_map_"] = member_mapping |
142 | | - attrs["_enum_member_names_"] = member_names |
143 | | - attrs["_enum_value_cls_"] = value_cls |
144 | | - actual_cls = super().__new__(cls, name, bases, attrs) |
145 | | - value_cls._actual_enum_cls_ = actual_cls # type: ignore |
146 | | - return actual_cls |
147 | | - |
148 | | - def __iter__(cls): |
149 | | - return (cls._enum_member_map_[name] for name in cls._enum_member_names_) |
150 | | - |
151 | | - def __reversed__(cls): |
152 | | - return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) |
153 | | - |
154 | | - def __len__(cls): |
155 | | - return len(cls._enum_member_names_) |
| 89 | +class Enum(EnumBase): |
| 90 | + """An :class:`enum.Enum` subclass that implements a missing value creation behavior if it is |
| 91 | + not present in any of the members of it. |
| 92 | + """ |
156 | 93 |
|
157 | | - def __repr__(cls): |
158 | | - return f"<enum {cls.__name__}>" |
| 94 | + def __init_subclass__(cls, *, comparable: bool = False) -> None: |
| 95 | + super().__init_subclass__() |
159 | 96 |
|
160 | | - @property |
161 | | - def __members__(cls): |
162 | | - return types.MappingProxyType(cls._enum_member_map_) |
| 97 | + if comparable: |
163 | 98 |
|
164 | | - def __call__(cls, value): |
165 | | - try: |
166 | | - return cls._enum_value_map_[value] |
167 | | - except (KeyError, TypeError) as e: |
168 | | - raise ValueError(f"{value!r} is not a valid {cls.__name__}") from e |
| 99 | + def __lt__(self: Enum, other: object) -> bool: |
| 100 | + if not isinstance(other, cls): |
| 101 | + return NotImplemented |
| 102 | + return self.value < other.value |
169 | 103 |
|
170 | | - def __getitem__(cls, key): |
171 | | - return cls._enum_member_map_[key] |
| 104 | + def __gt__(self: Enum, other: object) -> bool: |
| 105 | + if not isinstance(other, cls): |
| 106 | + return NotImplemented |
| 107 | + return self.value > other.value |
172 | 108 |
|
173 | | - def __setattr__(cls, name, value): |
174 | | - raise TypeError("Enums are immutable.") |
| 109 | + def __le__(self: Enum, other: object) -> bool: |
| 110 | + if not isinstance(other, cls): |
| 111 | + return NotImplemented |
| 112 | + return self.value <= other.value |
175 | 113 |
|
176 | | - def __delattr__(cls, attr): |
177 | | - raise TypeError("Enums are immutable") |
| 114 | + def __ge__(self: Enum, other: object) -> bool: |
| 115 | + if not isinstance(other, cls): |
| 116 | + return NotImplemented |
| 117 | + return self.value >= other.value |
178 | 118 |
|
179 | | - def __instancecheck__(self, instance): |
180 | | - # isinstance(x, Y) |
181 | | - # -> __instancecheck__(Y, x) |
182 | | - try: |
183 | | - return instance._actual_enum_cls_ is self |
184 | | - except AttributeError: |
185 | | - return False |
| 119 | + cls.__lt__ = __lt__ |
| 120 | + cls.__gt__ = __gt__ |
| 121 | + cls.__le__ = __le__ |
| 122 | + cls.__ge__ = __ge__ |
186 | 123 |
|
| 124 | + @classmethod |
| 125 | + def _missing_(cls, value: Any) -> Self: |
| 126 | + name = f"unknown_{value}" |
| 127 | + if name in cls.__members__: |
| 128 | + return cls.__members__[name] |
187 | 129 |
|
188 | | -if TYPE_CHECKING: |
189 | | - from enum import Enum |
190 | | -else: |
| 130 | + # this creates the new unknown value member |
| 131 | + obj = object.__new__(cls) |
| 132 | + obj._name_ = name |
| 133 | + obj._value_ = value |
191 | 134 |
|
192 | | - class Enum(metaclass=EnumMeta): |
193 | | - @classmethod |
194 | | - def try_value(cls, value): |
195 | | - try: |
196 | | - return cls._enum_value_map_[value] |
197 | | - except (KeyError, TypeError): |
198 | | - return value |
| 135 | + # and adds it to the member mapping of this enum so we don't |
| 136 | + # create a different enum member value each time |
| 137 | + cls._member_map_[name] = obj |
| 138 | + cls._value2member_map_[value] = obj |
| 139 | + return obj |
199 | 140 |
|
200 | 141 |
|
201 | 142 | class ChannelType(Enum): |
@@ -1086,23 +1027,15 @@ class ApplicationCommandPermissionType(Enum): |
1086 | 1027 | user = 2 |
1087 | 1028 | channel = 3 |
1088 | 1029 |
|
1089 | | - |
1090 | | -T = TypeVar("T") |
1091 | | - |
1092 | | - |
1093 | 1030 | def create_unknown_value(cls: type[T], val: Any) -> T: |
1094 | 1031 | value_cls = cls._enum_value_cls_ # type: ignore |
1095 | 1032 | name = f"unknown_{val}" |
1096 | 1033 | return value_cls(name=name, value=val) |
1097 | 1034 |
|
1098 | 1035 |
|
1099 | | -def try_enum(cls: type[T], val: Any) -> T: |
| 1036 | +def try_enum(cls: type[E], val: Any) -> E: |
1100 | 1037 | """A function that tries to turn the value into enum ``cls``. |
1101 | 1038 |
|
1102 | 1039 | If it fails it returns a proxy invalid value instead. |
1103 | 1040 | """ |
1104 | | - |
1105 | | - try: |
1106 | | - return cls._enum_value_map_[val] # type: ignore |
1107 | | - except (KeyError, TypeError, AttributeError): |
1108 | | - return create_unknown_value(cls, val) |
| 1041 | + return cls(val) |
0 commit comments