diff --git a/changelog/1410.bugfix.rst b/changelog/1410.bugfix.rst new file mode 100644 index 0000000000..6c7216abd8 --- /dev/null +++ b/changelog/1410.bugfix.rst @@ -0,0 +1 @@ +:class:`AllowedMentions` now takes a Sequence rather than a List for users and roles in order to support covariance. diff --git a/disnake/mentions.py b/disnake/mentions.py index 3fcfe8506f..c79eaa7ccd 100644 --- a/disnake/mentions.py +++ b/disnake/mentions.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Union +from typing import TYPE_CHECKING, Any, List, Sequence, Union, cast from .enums import MessageType @@ -69,13 +69,22 @@ def __init__( self, *, everyone: bool = default, - users: Union[bool, List[Snowflake]] = default, - roles: Union[bool, List[Snowflake]] = default, + users: Union[bool, Sequence[Snowflake]] = default, + roles: Union[bool, Sequence[Snowflake]] = default, replied_user: bool = default, ) -> None: self.everyone = everyone - self.users = users - self.roles = roles + # TODO(3.0): annotate attributes as `Sequence` instead of copying to list + self.users: Union[bool, List[Snowflake]] + self.roles: Union[bool, List[Snowflake]] + if users is default or isinstance(users, bool): + self.users = cast("bool", users) + else: + self.users = list(users) + if roles is default or isinstance(roles, bool): + self.roles = cast("bool", roles) + else: + self.roles = list(roles) self.replied_user = replied_user @classmethod @@ -108,8 +117,8 @@ def from_message(cls, message: Message) -> Self: return cls( everyone=message.mention_everyone, - users=message.mentions.copy(), # type: ignore # mentions is a list of Snowflakes - roles=message.role_mentions.copy(), # type: ignore # mentions is a list of Snowflakes + users=list(message.mentions), + roles=list(message.role_mentions), replied_user=bool( message.type is MessageType.reply and message.reference @@ -119,12 +128,14 @@ def from_message(cls, message: Message) -> Self: ) def to_dict(self) -> AllowedMentionsPayload: + # n.b. this runs nearly every time a message is sent parse: List[AllowedMentionTypePayload] = [] data: AllowedMentionsPayload = {} # type: ignore if self.everyone: parse.append("everyone") + # n.b. not using is True/False on account of _FakeBool if self.users == True: # noqa: E712 parse.append("users") elif self.users != False: # noqa: E712 diff --git a/tests/test_mentions.py b/tests/test_mentions.py index ff5d01b452..2aed35df21 100644 --- a/tests/test_mentions.py +++ b/tests/test_mentions.py @@ -109,6 +109,7 @@ def test_from_message_replied_user() -> None: message = mock.Mock(Message) author = Object(123) message.mentions = [author] + message.role_mentions = [] assert AllowedMentions.from_message(message).replied_user is False message.type = MessageType.reply