diff --git a/CHANGELOG-V3.md b/CHANGELOG-V3.md index 525f1c475a..038e1e9a0f 100644 --- a/CHANGELOG-V3.md +++ b/CHANGELOG-V3.md @@ -9,6 +9,8 @@ release. ### Changed +- Removed the custom `enums.Enum` implementation in favor of a stdlib `enum.Enum` subclass. + ### Deprecated ### Removed diff --git a/discord/enums.py b/discord/enums.py index 1488be1843..3862c7149d 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -26,8 +26,10 @@ from __future__ import annotations import types -from collections import namedtuple -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union +from enum import Enum as EnumBase +from typing import Any, Self, TypeVar, Union + +E = TypeVar("E", bound="Enum") __all__ = ( "Enum", @@ -83,118 +85,36 @@ ) -def _create_value_cls(name, comparable): - cls = namedtuple(f"_EnumValue_{name}", "name value") - cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>" - cls.__str__ = lambda self: f"{name}.{self.name}" - if comparable: - cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value - cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value - cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value - cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value - return cls - - -def _is_descriptor(obj): - return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") - - -class EnumMeta(type): - if TYPE_CHECKING: - __name__: ClassVar[str] - _enum_member_names_: ClassVar[list[str]] - _enum_member_map_: ClassVar[dict[str, Any]] - _enum_value_map_: ClassVar[dict[Any, Any]] - - def __new__(cls, name, bases, attrs, *, comparable: bool = False): - value_mapping = {} - member_mapping = {} - member_names = [] - - value_cls = _create_value_cls(name, comparable) - for key, value in list(attrs.items()): - is_descriptor = _is_descriptor(value) - if key[0] == "_" and not is_descriptor: - continue - - # Special case classmethod to just pass through - if isinstance(value, classmethod): - continue - - if is_descriptor: - setattr(value_cls, key, value) - del attrs[key] - continue - - try: - new_value = value_mapping[value] - except KeyError: - new_value = value_cls(name=key, value=value) - value_mapping[value] = new_value - member_names.append(key) - - member_mapping[key] = new_value - attrs[key] = new_value - - attrs["_enum_value_map_"] = value_mapping - attrs["_enum_member_map_"] = member_mapping - attrs["_enum_member_names_"] = member_names - attrs["_enum_value_cls_"] = value_cls - actual_cls = super().__new__(cls, name, bases, attrs) - value_cls._actual_enum_cls_ = actual_cls # type: ignore - return actual_cls - - def __iter__(cls): - return (cls._enum_member_map_[name] for name in cls._enum_member_names_) - - def __reversed__(cls): - return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) - - def __len__(cls): - return len(cls._enum_member_names_) - - def __repr__(cls): - return f"" - - @property - def __members__(cls): - return types.MappingProxyType(cls._enum_member_map_) - - def __call__(cls, value): - try: - return cls._enum_value_map_[value] - except (KeyError, TypeError) as e: - raise ValueError(f"{value!r} is not a valid {cls.__name__}") from e - - def __getitem__(cls, key): - return cls._enum_member_map_[key] - - def __setattr__(cls, name, value): - raise TypeError("Enums are immutable.") +class Enum(EnumBase): + """An :class:`enum.Enum` subclass that implements a missing value creation behavior if it is + not present in any of the members of it. + """ - def __delattr__(cls, attr): - raise TypeError("Enums are immutable") + def __init_subclass__(cls, *, comparable: bool = False) -> None: + super().__init_subclass__() - def __instancecheck__(self, instance): - # isinstance(x, Y) - # -> __instancecheck__(Y, x) - try: - return instance._actual_enum_cls_ is self - except AttributeError: - return False + if comparable is True: + cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value + cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value + cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value + cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value + @classmethod + def _missing_(cls, value: Any) -> Self: + name = f"unknown_{value}" + if name in cls.__members__: + return cls.__members__[name] -if TYPE_CHECKING: - from enum import Enum -else: + # this creates the new unknown value member + obj = object.__new__(cls) + obj._name_ = name + obj._value_ = value - class Enum(metaclass=EnumMeta): - @classmethod - def try_value(cls, value): - try: - return cls._enum_value_map_[value] - except (KeyError, TypeError): - return value + # and adds it to the member mapping of this enum so we don't + # create a different enum member value each time + cls._member_map_[name] = obj + cls._value2member_map_[value] = obj + return obj class ChannelType(Enum): @@ -1078,22 +998,9 @@ def __int__(self): return self.value -T = TypeVar("T") - - -def create_unknown_value(cls: type[T], val: Any) -> T: - value_cls = cls._enum_value_cls_ # type: ignore - name = f"unknown_{val}" - return value_cls(name=name, value=val) - - -def try_enum(cls: type[T], val: Any) -> T: +def try_enum(cls: type[E], val: Any) -> E: """A function that tries to turn the value into enum ``cls``. If it fails it returns a proxy invalid value instead. """ - - try: - return cls._enum_value_map_[val] # type: ignore - except (KeyError, TypeError, AttributeError): - return create_unknown_value(cls, val) + return cls(val)