Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG-V3.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ release.

### Changed

- Removed the custom `enums.Enum` implementation in favor of a stdlib `enum.Enum` subclass.

### Deprecated

### Removed
Expand Down
155 changes: 31 additions & 124 deletions discord/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from __future__ import annotations

import types
from collections import namedtuple
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union
from enum import Enum as EnumBase
from typing import Any, Self, TypeVar, Union

E = TypeVar("E", bound="Enum")

__all__ = (
"Enum",
Expand Down Expand Up @@ -83,118 +85,36 @@
)


def _create_value_cls(name, comparable):
cls = namedtuple(f"_EnumValue_{name}", "name value")
cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>"
cls.__str__ = lambda self: f"{name}.{self.name}"
if comparable:
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
return cls


def _is_descriptor(obj):
return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")


class EnumMeta(type):
if TYPE_CHECKING:
__name__: ClassVar[str]
_enum_member_names_: ClassVar[list[str]]
_enum_member_map_: ClassVar[dict[str, Any]]
_enum_value_map_: ClassVar[dict[Any, Any]]

def __new__(cls, name, bases, attrs, *, comparable: bool = False):
value_mapping = {}
member_mapping = {}
member_names = []

value_cls = _create_value_cls(name, comparable)
for key, value in list(attrs.items()):
is_descriptor = _is_descriptor(value)
if key[0] == "_" and not is_descriptor:
continue

# Special case classmethod to just pass through
if isinstance(value, classmethod):
continue

if is_descriptor:
setattr(value_cls, key, value)
del attrs[key]
continue

try:
new_value = value_mapping[value]
except KeyError:
new_value = value_cls(name=key, value=value)
value_mapping[value] = new_value
member_names.append(key)

member_mapping[key] = new_value
attrs[key] = new_value

attrs["_enum_value_map_"] = value_mapping
attrs["_enum_member_map_"] = member_mapping
attrs["_enum_member_names_"] = member_names
attrs["_enum_value_cls_"] = value_cls
actual_cls = super().__new__(cls, name, bases, attrs)
value_cls._actual_enum_cls_ = actual_cls # type: ignore
return actual_cls

def __iter__(cls):
return (cls._enum_member_map_[name] for name in cls._enum_member_names_)

def __reversed__(cls):
return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_))

def __len__(cls):
return len(cls._enum_member_names_)

def __repr__(cls):
return f"<enum {cls.__name__}>"

@property
def __members__(cls):
return types.MappingProxyType(cls._enum_member_map_)

def __call__(cls, value):
try:
return cls._enum_value_map_[value]
except (KeyError, TypeError) as e:
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from e

def __getitem__(cls, key):
return cls._enum_member_map_[key]

def __setattr__(cls, name, value):
raise TypeError("Enums are immutable.")
class Enum(EnumBase):
"""An :class:`enum.Enum` subclass that implements a missing value creation behavior if it is
not present in any of the members of it.
"""

def __delattr__(cls, attr):
raise TypeError("Enums are immutable")
def __init_subclass__(cls, *, comparable: bool = False) -> None:
super().__init_subclass__()

def __instancecheck__(self, instance):
# isinstance(x, Y)
# -> __instancecheck__(Y, x)
try:
return instance._actual_enum_cls_ is self
except AttributeError:
return False
if comparable is True:
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value

@classmethod
def _missing_(cls, value: Any) -> Self:
name = f"unknown_{value}"
if name in cls.__members__:
return cls.__members__[name]

if TYPE_CHECKING:
from enum import Enum
else:
# this creates the new unknown value member
obj = object.__new__(cls)
obj._name_ = name
obj._value_ = value

class Enum(metaclass=EnumMeta):
@classmethod
def try_value(cls, value):
try:
return cls._enum_value_map_[value]
except (KeyError, TypeError):
return value
# and adds it to the member mapping of this enum so we don't
# create a different enum member value each time
cls._member_map_[name] = obj
cls._value2member_map_[value] = obj
return obj


class ChannelType(Enum):
Expand Down Expand Up @@ -1078,22 +998,9 @@ def __int__(self):
return self.value


T = TypeVar("T")


def create_unknown_value(cls: type[T], val: Any) -> T:
value_cls = cls._enum_value_cls_ # type: ignore
name = f"unknown_{val}"
return value_cls(name=name, value=val)


def try_enum(cls: type[T], val: Any) -> T:
def try_enum(cls: type[E], val: Any) -> E:
"""A function that tries to turn the value into enum ``cls``.

If it fails it returns a proxy invalid value instead.
"""

try:
return cls._enum_value_map_[val] # type: ignore
except (KeyError, TypeError, AttributeError):
return create_unknown_value(cls, val)
return cls(val)