diff --git a/.github/workflows/docs-localization-download.yml b/.github/workflows/docs-localization-download.yml index a2201b5071..5bf56c1341 100644 --- a/.github/workflows/docs-localization-download.yml +++ b/.github/workflows/docs-localization-download.yml @@ -41,7 +41,7 @@ jobs: working-directory: ./docs - name: "Crowdin" id: crowdin - uses: crowdin/github-action@v2.7.1 + uses: crowdin/github-action@v2.9.0 with: upload_sources: false upload_translations: false diff --git a/.github/workflows/docs-localization-upload.yml b/.github/workflows/docs-localization-upload.yml index 837fd1e8de..300937f3a5 100644 --- a/.github/workflows/docs-localization-upload.yml +++ b/.github/workflows/docs-localization-upload.yml @@ -45,7 +45,7 @@ jobs: sphinx-intl update -p ./build/locales ${{ vars.SPHINX_LANGUAGES }} working-directory: ./docs - name: "Crowdin" - uses: crowdin/github-action@v2.7.1 + uses: crowdin/github-action@v2.9.0 with: upload_sources: true upload_translations: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e43197fe9e..fa31ed1af3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: end-of-file-fixer exclude: \.(po|pot|yml|yaml)$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.9 + rev: v0.12.0 hooks: - id: ruff args: [ --fix ] diff --git a/CHANGELOG-V3.md b/CHANGELOG-V3.md index eb94407d2e..525f1c475a 100644 --- a/CHANGELOG-V3.md +++ b/CHANGELOG-V3.md @@ -12,3 +12,17 @@ release. ### Deprecated ### Removed + +- `utils.filter_params` +- `utils.sleep_until` use `asyncio.sleep` combined with `datetime.datetime` instead +- `utils.compute_timedelta` use the `datetime` module instead +- `utils.resolve_invite` +- `utils.resolve_template` +- `utils.parse_time` use `datetime.datetime.fromisoformat` instead +- `utils.time_snowflake` use `utils.generate_snowflake` instead +- `utils.warn_deprecated` +- `utils.deprecated` +- `utils.get` use `utils.find` with `lambda i: i.attr == val` instead +- `AsyncIterator.get` use `AsyncIterator.find` with `lambda i: i.attr == val` instead +- `utils.as_chunks` use `itertools.batched` on Python 3.12+ or your own implementation + instead diff --git a/CHANGELOG.md b/CHANGELOG.md index fdfdc6d896..97c6a99cbe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -119,6 +119,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2781](https://github.com/Pycord-Development/pycord/pull/2781)) - Fixed `VoiceClient` crashing randomly while receiving audio ([#2800](https://github.com/Pycord-Development/pycord/pull/2800)) +- Fixed `VoiceClient.connect` failing to do initial connection. + ([#2812](https://github.com/Pycord-Development/pycord/pull/2812)) ### Changed @@ -138,6 +140,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2564](https://github.com/Pycord-Development/pycord/pull/2564)) - Changed the default value of `ApplicationCommand.nsfw` to `False`. ([#2797](https://github.com/Pycord-Development/pycord/pull/2797)) +- Upgraded voice websocket version to v8. + ([#2812](https://github.com/Pycord-Development/pycord/pull/2812)) ### Deprecated @@ -150,8 +154,8 @@ These changes are available on the `master` branch, but have not yet been releas ### Removed -- Removed deprecated support for `Option` in `BridgeCommand`. Use `BridgeOption` - instead. ([#2731])(https://github.com/Pycord-Development/pycord/pull/2731)) +- Removed deprecated support for `Option` in `BridgeCommand`, use `BridgeOption` + instead. ([#2731](https://github.com/Pycord-Development/pycord/pull/2731)) ## [2.6.1] - 2024-09-15 diff --git a/discord/__version.py b/discord/__version.py index 88b3df1b5c..f7020d1ad8 100644 --- a/discord/__version.py +++ b/discord/__version.py @@ -35,7 +35,7 @@ from typing import Literal, NamedTuple -from .utils import deprecated +from .utils.private import deprecated from ._version import __version__, __version_tuple__ diff --git a/discord/abc.py b/discord/abc.py index c59a6276be..acebd83e39 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -41,6 +41,7 @@ runtime_checkable, ) +from .utils.private import warn_deprecated from . import utils from .context_managers import Typing from .enums import ChannelType @@ -724,7 +725,7 @@ def permissions_for(self, obj: Member | Role, /) -> Permissions: if obj.is_default(): return base - overwrite = utils.get(self._overwrites, type=_Overwrites.ROLE, id=obj.id) + overwrite = utils.find(lambda o: o.type == _Overwrites.ROLE and o.id == obj.id, self._overwrites) if overwrite is not None: base.handle_overwrite(overwrite.allow, overwrite.deny) @@ -1529,7 +1530,7 @@ async def send( from .message import MessageReference # noqa: PLC0415 if not isinstance(reference, MessageReference): - utils.warn_deprecated( + warn_deprecated( f"Passing {type(reference).__name__} to reference", "MessageReference", "2.7", @@ -1545,6 +1546,10 @@ async def send( raise InvalidArgument(f"view parameter must be View not {view.__class__!r}") components = view.to_components() + if view.is_components_v2(): + if embeds or content: + raise TypeError("cannot send embeds or content with a view using v2 component logic") + flags.is_components_v2 = True else: components = None @@ -1605,8 +1610,10 @@ async def send( ret = state.create_message(channel=channel, data=data) if view: - state.store_view(view, ret.id) + if view.is_dispatchable(): + state.store_view(view, ret.id) view.message = ret + view.refresh(ret.components) if delete_after is not None: await ret.delete(delay=delete_after) diff --git a/discord/activity.py b/discord/activity.py index 9cda6bc3f2..3377892aa8 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -32,7 +32,7 @@ from .colour import Colour from .enums import ActivityType, try_enum from .partial_emoji import PartialEmoji -from .utils import _get_as_snowflake +from .utils.private import get_as_snowflake __all__ = ( "BaseActivity", @@ -226,7 +226,7 @@ def __init__(self, **kwargs): self.timestamps: ActivityTimestamps = kwargs.pop("timestamps", {}) self.assets: ActivityAssets = kwargs.pop("assets", {}) self.party: ActivityParty = kwargs.pop("party", {}) - self.application_id: int | None = _get_as_snowflake(kwargs, "application_id") + self.application_id: int | None = get_as_snowflake(kwargs, "application_id") self.url: str | None = kwargs.pop("url", None) self.flags: int = kwargs.pop("flags", 0) self.sync_id: str | None = kwargs.pop("sync_id", None) diff --git a/discord/appinfo.py b/discord/appinfo.py index bfe7bc77c5..ab697c25a3 100644 --- a/discord/appinfo.py +++ b/discord/appinfo.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING +from .utils.private import warn_deprecated, get_as_snowflake from . import utils from .asset import Asset from .permissions import Permissions @@ -200,9 +201,9 @@ def __init__(self, state: ConnectionState, data: AppInfoPayload): self._summary: str = data["summary"] self.verify_key: str = data["verify_key"] - self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") + self.guild_id: int | None = get_as_snowflake(data, "guild_id") - self.primary_sku_id: int | None = utils._get_as_snowflake(data, "primary_sku_id") + self.primary_sku_id: int | None = get_as_snowflake(data, "primary_sku_id") self.slug: str | None = data.get("slug") self._cover_image: str | None = data.get("cover_image") self.terms_of_service_url: str | None = data.get("terms_of_service_url") @@ -261,7 +262,7 @@ def summary(self) -> str | None: .. versionadded:: 1.3 .. deprecated:: 2.7 """ - utils.warn_deprecated( + warn_deprecated( "summary", "description", reference="https://discord.com/developers/docs/resources/application#application-object-application-structure", diff --git a/discord/asset.py b/discord/asset.py index 3bc04faa7c..931e7bbcaf 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -47,6 +47,11 @@ MISSING = utils.MISSING +def _valid_icon_size(size: int) -> bool: + """Icons must be power of 2 within [16, 4096].""" + return not size & (size - 1) and 4096 >= size >= 16 + + class AssetMixin: url: str _state: Any | None @@ -371,7 +376,7 @@ def replace( url = url.with_path(f"{path}.{static_format}") if size is not MISSING: - if not utils.valid_icon_size(size): + if not _valid_icon_size(size): raise InvalidArgument("size must be a power of 2 between 16 and 4096") url = url.with_query(size=size) else: @@ -398,7 +403,7 @@ def with_size(self, size: int, /) -> Asset: InvalidArgument The asset had an invalid size. """ - if not utils.valid_icon_size(size): + if not _valid_icon_size(size): raise InvalidArgument("size must be a power of 2 between 16 and 4096") url = str(yarl.URL(self._url).with_query(size=size)) diff --git a/discord/audit_logs.py b/discord/audit_logs.py index 65cbbfe6ec..068323f770 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -26,7 +26,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generator, TypeVar +from functools import cached_property +from .utils.private import get_as_snowflake from . import enums, utils from .asset import Asset from .automod import AutoModAction, AutoModTriggerMetadata @@ -559,8 +561,8 @@ def _from_data(self, data: AuditLogEntryPayload) -> None: # into meaningful data when requested self._changes = data.get("changes", []) - self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore - self._target_id = utils._get_as_snowflake(data, "target_id") + self.user = self._get_member(get_as_snowflake(data, "user_id")) # type: ignore + self._target_id = get_as_snowflake(data, "target_id") def _get_member(self, user_id: int) -> Member | User | None: return self.guild.get_member(user_id) or self._users.get(user_id) @@ -568,12 +570,12 @@ def _get_member(self, user_id: int) -> Member | User | None: def __repr__(self) -> str: return f"" - @utils.cached_property + @cached_property def created_at(self) -> datetime.datetime: """Returns the entry's creation time in UTC.""" return utils.snowflake_time(self.id) - @utils.cached_property + @cached_property def target( self, ) -> ( @@ -597,24 +599,24 @@ def target( else: return converter(self._target_id) - @utils.cached_property + @property def category(self) -> enums.AuditLogActionCategory: """The category of the action, if applicable.""" return self.action.category - @utils.cached_property + @cached_property def changes(self) -> AuditLogChanges: """The list of changes this entry has.""" obj = AuditLogChanges(self, self._changes, state=self._state) del self._changes return obj - @utils.cached_property + @property def before(self) -> AuditLogDiff: """The target's prior state.""" return self.changes.before - @utils.cached_property + @property def after(self) -> AuditLogDiff: """The target's subsequent state.""" return self.changes.after diff --git a/discord/bot.py b/discord/bot.py index f22362350a..1399420544 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -63,7 +63,8 @@ from .shard import AutoShardedClient from .types import interactions from .user import User -from .utils import MISSING, async_all, find, get +from .utils import MISSING, find +from .utils.private import async_all if TYPE_CHECKING: from .member import Member @@ -216,13 +217,13 @@ def get_application_command( return command elif (names := name.split())[0] == command.name and isinstance(command, SlashCommandGroup): while len(names) > 1: - command = get(commands, name=names.pop(0)) + command = find(lambda c: c.name == names.pop(0), commands) if not isinstance(command, SlashCommandGroup) or ( guild_ids is not None and command.guild_ids != guild_ids ): return commands = command.subcommands - command = get(commands, name=names.pop()) + command = find(lambda c: c.name == names.pop(), commands) if not isinstance(command, type) or (guild_ids is not None and command.guild_ids != guild_ids): return return command @@ -357,7 +358,7 @@ def _check_command(cmd: ApplicationCommand, match: Mapping[str, Any]) -> bool: # Now let's see if there are any commands on discord that we need to delete for cmd, value_ in registered_commands_dict.items(): - match = get(pending, name=value_["name"]) + match = find(lambda c: c.name == value_["name"], pending) if match is None: # We have this command registered but not in our list return_value.append( @@ -516,7 +517,7 @@ def register( ) continue # We can assume the command item is a command, since it's only a string if action is delete - match = get(pending, name=cmd["command"].name, type=cmd["command"].type) + match = find(lambda c: c.name == cmd["command"].name and c.type == cmd["command"].type, pending) if match is None: continue if cmd["action"] == "edit": @@ -605,10 +606,9 @@ def register( registered = await register("bulk", data, guild_id=guild_id) for i in registered: - cmd = get( + cmd = find( + lambda c: c.name == i["name"] and c.type == i.get("type"), self.pending_application_commands, - name=i["name"], - type=i.get("type"), ) if not cmd: raise ValueError(f"Registered command {i['name']}, type {i.get('type')} not found in pending commands") @@ -712,11 +712,9 @@ async def on_connect(): registered_guild_commands[guild_id] = app_cmds for i in registered_commands: - cmd = get( + cmd = find( + lambda c: c.name == i["name"] and c.guild_ids is None and c.type == i.get("type"), self.pending_application_commands, - name=i["name"], - guild_ids=None, - type=i.get("type"), ) if cmd: cmd.id = i["id"] @@ -803,13 +801,13 @@ async def process_application_commands(self, interaction: Interaction, auto_sync ctx = await self.get_application_context(interaction) if command: - ctx.command = command + interaction.command = command await self.invoke_application_command(ctx) async def on_application_command_auto_complete(self, interaction: Interaction, command: ApplicationCommand) -> None: async def callback() -> None: ctx = await self.get_autocomplete_context(interaction) - ctx.command = command + interaction.command = command return await command.invoke_autocomplete_callback(ctx) autocomplete_task = self._bot.loop.create_task(callback()) diff --git a/discord/channel.py b/discord/channel.py index cff03c1614..1846c7c479 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -26,10 +26,20 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Mapping, + Sequence, + TypeVar, + overload, +) import discord.abc +from .utils.private import bytes_to_base64_data, get_as_snowflake, copy_doc from . import utils from .asset import Asset from .emoji import GuildEmoji @@ -45,7 +55,7 @@ ) from .errors import ClientException, InvalidArgument from .file import File -from .flags import ChannelFlags +from .flags import ChannelFlags, MessageFlags from .invite import Invite from .iterators import ArchivedThreadIterator from .mixins import Hashable @@ -71,12 +81,15 @@ if TYPE_CHECKING: from .abc import Snowflake, SnowflakeTime + from .embeds import Embed from .guild import Guild from .guild import GuildChannel as GuildChannelType from .member import Member, VoiceState + from .mentions import AllowedMentions from .message import EmojiInputType, Message, PartialMessage from .role import Role from .state import ConnectionState + from .sticker import GuildSticker, StickerItem from .types.channel import CategoryChannel as CategoryChannelPayload from .types.channel import DMChannel as DMChannelPayload from .types.channel import ForumChannel as ForumChannelPayload @@ -87,6 +100,7 @@ from .types.channel import VoiceChannel as VoiceChannelPayload from .types.snowflake import SnowflakeList from .types.threads import ThreadArchiveDuration + from .ui.view import View from .user import BaseUser, ClientUser, User from .webhook import Webhook @@ -157,7 +171,7 @@ def from_data(cls, *, state: ConnectionState, data: ForumTagPayload) -> ForumTag self.moderated = data.get("moderated", False) emoji_name = data["emoji_name"] or "" - emoji_id = utils._get_as_snowflake(data, "emoji_id") or None + emoji_id = get_as_snowflake(data, "emoji_id") or None self.emoji = PartialEmoji.with_state(state=state, name=emoji_name, id=emoji_id) return self @@ -219,7 +233,7 @@ def _update(self, guild: Guild, data: TextChannelPayload | ForumChannelPayload) # This data will always exist self.guild: Guild = guild self.name: str = data["name"] - self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") + self.category_id: int | None = get_as_snowflake(data, "parent_id") self._type: int = data["type"] # This data may be missing depending on how this object is being created/updated if not data.pop("_invoke_flag", False): @@ -230,7 +244,7 @@ def _update(self, guild: Guild, data: TextChannelPayload | ForumChannelPayload) self.slowmode_delay: int = data.get("rate_limit_per_user", 0) self.default_auto_archive_duration: ThreadArchiveDuration = data.get("default_auto_archive_duration", 1440) self.default_thread_slowmode_delay: int | None = data.get("default_thread_rate_limit_per_user") - self.last_message_id: int | None = utils._get_as_snowflake(data, "last_message_id") + self.last_message_id: int | None = get_as_snowflake(data, "last_message_id") self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) self._fill_overwrites(data) @@ -243,7 +257,7 @@ def type(self) -> ChannelType: def _sorting_bucket(self) -> int: return ChannelType.text.value - @utils.copy_doc(discord.abc.GuildChannel.permissions_for) + @copy_doc(discord.abc.GuildChannel.permissions_for) def permissions_for(self, obj: Member | Role, /) -> Permissions: base = super().permissions_for(obj) @@ -294,7 +308,7 @@ async def edit(self, **options) -> _TextChannel: """Edits the channel.""" raise NotImplementedError - @utils.copy_doc(discord.abc.GuildChannel.clone) + @copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: str | None = None, reason: str | None = None) -> TextChannel: return await self._clone_impl( { @@ -498,7 +512,7 @@ async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason from .webhook import Webhook # noqa: PLC0415 if avatar is not None: - avatar = utils._bytes_to_base64_data(avatar) # type: ignore + avatar = bytes_to_base64_data(avatar) # type: ignore data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) @@ -1001,9 +1015,7 @@ def _update(self, guild: Guild, data: ForumChannelPayload) -> None: if emoji_name is not None: self.default_reaction_emoji = reaction_emoji_ctx["emoji_name"] else: - self.default_reaction_emoji = self._state.get_emoji( - utils._get_as_snowflake(reaction_emoji_ctx, "emoji_id") - ) + self.default_reaction_emoji = self._state.get_emoji(get_as_snowflake(reaction_emoji_ctx, "emoji_id")) @property def guidelines(self) -> str | None: @@ -1026,7 +1038,7 @@ def get_tag(self, id: int, /) -> ForumTag | None: .. versionadded:: 2.3 """ - return utils.get(self.available_tags, id=id) + return utils.find(lambda t: t.id == id, self.available_tags) @overload async def edit( @@ -1137,19 +1149,21 @@ async def edit(self, *, reason=None, **options): async def create_thread( self, name: str, - content=None, + content: str | None = None, *, - embed=None, - embeds=None, - file=None, - files=None, - stickers=None, - delete_message_after=None, - nonce=None, - allowed_mentions=None, - view=None, - applied_tags=None, - auto_archive_duration: ThreadArchiveDuration | utils.Undefined = MISSING, + embed: Embed | None = None, + embeds: list[Embed] | None = None, + file: File | None = None, + files: list[File] | None = None, + stickers: Sequence[GuildSticker | StickerItem] | None = None, + delete_message_after: float | None = None, + nonce: int | str | None = None, + allowed_mentions: AllowedMentions | None = None, + view: View | None = None, + applied_tags: list[ForumTag] | None = None, + suppress: bool = False, + silent: bool = False, + auto_archive_duration: ThreadArchiveDuration = MISSING, slowmode_delay: int | utils.Undefined = MISSING, reason: str | None = None, ) -> Thread: @@ -1242,11 +1256,20 @@ async def create_thread( else: allowed_mentions = allowed_mentions.to_dict() + flags = MessageFlags( + suppress_embeds=bool(suppress), + suppress_notifications=bool(silent), + ) + if view: if not hasattr(view, "__discord_ui_view__"): raise InvalidArgument(f"view parameter must be View not {view.__class__!r}") components = view.to_components() + if view.is_components_v2(): + if embeds or content: + raise TypeError("cannot send embeds or content with a view using v2 component logic") + flags.is_components_v2 = True else: components = None @@ -1282,6 +1305,7 @@ async def create_thread( auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, rate_limit_per_user=slowmode_delay or self.slowmode_delay, applied_tags=applied_tags, + flags=flags.value, reason=reason, ) finally: @@ -1291,7 +1315,7 @@ async def create_thread( ret = Thread(guild=self.guild, state=self._state, data=data) msg = ret.get_partial_message(int(data["last_message_id"])) - if view: + if view and view.is_dispatchable(): state.store_view(view, msg.id) if delete_message_after is not None: @@ -1529,14 +1553,14 @@ def _update(self, guild: Guild, data: VoiceChannelPayload | StageChannelPayload) # This data will always exist self.guild = guild self.name: str = data["name"] - self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") + self.category_id: int | None = get_as_snowflake(data, "parent_id") # This data may be missing depending on how this object is being created/updated if not data.pop("_invoke_flag", False): rtc = data.get("rtc_region") self.rtc_region: VoiceRegion | None = try_enum(VoiceRegion, rtc) if rtc is not None else None self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1)) - self.last_message_id: int | None = utils._get_as_snowflake(data, "last_message_id") + self.last_message_id: int | None = get_as_snowflake(data, "last_message_id") self.position: int = data.get("position") self.slowmode_delay = data.get("rate_limit_per_user", 0) self.bitrate: int = data.get("bitrate") @@ -1582,7 +1606,7 @@ def voice_states(self) -> dict[int, VoiceState]: if value.channel and value.channel.id == self.id } - @utils.copy_doc(discord.abc.GuildChannel.permissions_for) + @copy_doc(discord.abc.GuildChannel.permissions_for) def permissions_for(self, obj: Member | Role, /) -> Permissions: base = super().permissions_for(obj) @@ -1936,7 +1960,7 @@ async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason from .webhook import Webhook # noqa: PLC0415 if avatar is not None: - avatar = utils._bytes_to_base64_data(avatar) # type: ignore + avatar = bytes_to_base64_data(avatar) # type: ignore data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) @@ -1946,7 +1970,7 @@ def type(self) -> ChannelType: """The channel's Discord type.""" return ChannelType.voice - @utils.copy_doc(discord.abc.GuildChannel.clone) + @copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: str | None = None, reason: str | None = None) -> VoiceChannel: return await self._clone_impl( {"bitrate": self.bitrate, "user_limit": self.user_limit}, @@ -2463,7 +2487,7 @@ async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason from .webhook import Webhook # noqa: PLC0415 if avatar is not None: - avatar = utils._bytes_to_base64_data(avatar) # type: ignore + avatar = bytes_to_base64_data(avatar) # type: ignore data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) @@ -2482,7 +2506,7 @@ def type(self) -> ChannelType: """The channel's Discord type.""" return ChannelType.stage_voice - @utils.copy_doc(discord.abc.GuildChannel.clone) + @copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: str | None = None, reason: str | None = None) -> StageChannel: return await self._clone_impl({}, name=name, reason=reason) @@ -2492,7 +2516,7 @@ def instance(self) -> StageInstance | None: .. versionadded:: 2.0 """ - return utils.get(self.guild.stage_instances, channel_id=self.id) + return utils.find(lambda s: s.channel_id == self.id, self.guild.stage_instances) async def create_instance( self, @@ -2723,7 +2747,7 @@ def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: # This data will always exist self.guild: Guild = guild self.name: str = data["name"] - self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") + self.category_id: int | None = get_as_snowflake(data, "parent_id") # This data may be missing depending on how this object is being created/updated if not data.pop("_invoke_flag", False): @@ -2745,7 +2769,7 @@ def is_nsfw(self) -> bool: """Checks if the category is NSFW.""" return self.nsfw - @utils.copy_doc(discord.abc.GuildChannel.clone) + @copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: str | None = None, reason: str | None = None) -> CategoryChannel: return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason) @@ -2811,7 +2835,7 @@ async def edit(self, *, reason=None, **options): # the payload will always be the proper channel payload return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore - @utils.copy_doc(discord.abc.GuildChannel.move) + @copy_doc(discord.abc.GuildChannel.move) async def move(self, **kwargs): kwargs.pop("category", None) await super().move(**kwargs) @@ -3113,7 +3137,7 @@ def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannel self._update_group(data) def _update_group(self, data: GroupChannelPayload) -> None: - self.owner_id: int | None = utils._get_as_snowflake(data, "owner_id") + self.owner_id: int | None = get_as_snowflake(data, "owner_id") self._icon: str | None = data.get("icon") self.name: str | None = data.get("name") self.recipients: list[User] = [self._state.store_user(u) for u in data.get("recipients", [])] diff --git a/discord/client.py b/discord/client.py index 4180182ace..63d0b55eb7 100644 --- a/discord/client.py +++ b/discord/client.py @@ -38,6 +38,7 @@ from discord.banners import print_banner, start_logging from . import utils +from .utils.private import resolve_invite, resolve_template, bytes_to_base64_data, SequenceProxy from .activity import ActivityTypes, BaseActivity, create_activity from .appinfo import AppInfo, PartialAppInfo from .application_role_connection import ApplicationRoleConnectionMetadata @@ -70,9 +71,11 @@ if TYPE_CHECKING: from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime from .channel import DMChannel + from .interaction import Interaction from .member import Member from .message import Message from .poll import Poll + from .ui.item import Item from .voice_client import VoiceProtocol __all__ = ("Client",) @@ -384,7 +387,7 @@ def cached_messages(self) -> Sequence[Message]: .. versionadded:: 1.1 """ - return utils.SequenceProxy(self._connection._messages or []) + return SequenceProxy(self._connection._messages or []) @property def private_channels(self) -> list[PrivateChannel]: @@ -538,6 +541,32 @@ async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None: print(f"Ignoring exception in {event_method}", file=sys.stderr) traceback.print_exc() + async def on_view_error(self, error: Exception, item: Item, interaction: Interaction) -> None: + """|coro| + + The default view error handler provided by the client. + + This only fires for a view if you did not define its :func:`~discord.ui.View.on_error`. + """ + + print( + f"Ignoring exception in view {interaction.view} for item {item}:", + file=sys.stderr, + ) + traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) + + async def on_modal_error(self, error: Exception, interaction: Interaction) -> None: + """|coro| + + The default modal error handler provided by the client. + The default implementation prints the traceback to stderr. + + This only fires for a modal if you did not define its :func:`~discord.ui.Modal.on_error`. + """ + + print(f"Ignoring exception in modal {interaction.modal}:", file=sys.stderr) + traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) + # hooks async def _call_before_identify_hook(self, shard_id: int | None, *, initial: bool = False) -> None: @@ -1542,7 +1571,7 @@ async def fetch_template(self, code: Template | str) -> Template: :exc:`HTTPException` Getting the template failed. """ - code = utils.resolve_template(code) + code = resolve_template(code) data = await self.http.get_template(code) return Template(data=data, state=self._connection) # type: ignore @@ -1626,7 +1655,7 @@ async def create_guild( Invalid icon image format given. Must be PNG or JPG. """ if icon is not MISSING: - icon_base64 = utils._bytes_to_base64_data(icon) + icon_base64 = bytes_to_base64_data(icon) else: icon_base64 = None @@ -1718,7 +1747,7 @@ async def fetch_invite( Getting the invite failed. """ - invite_id = utils.resolve_invite(url) + invite_id = resolve_invite(url) data = await self.http.get_invite( invite_id, with_counts=with_counts, @@ -1750,7 +1779,7 @@ async def delete_invite(self, invite: Invite | str) -> None: Revoking the invite failed. """ - invite_id = utils.resolve_invite(invite) + invite_id = resolve_invite(invite) await self.http.delete_invite(invite_id) # Miscellaneous stuff @@ -2226,7 +2255,7 @@ async def create_emoji( The created emoji. """ - img = utils._bytes_to_base64_data(image) + img = bytes_to_base64_data(image) data = await self._connection.http.create_application_emoji(self.application_id, name, img) return self._connection.maybe_store_app_emoji(self.application_id, data) diff --git a/discord/commands/context.py b/discord/commands/context.py index a046afdde6..fd4ecbe809 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -53,7 +53,7 @@ from typing import Callable, Awaitable -from ..utils import cached_property +from ..utils.private import copy_doc T = TypeVar("T") CogT = TypeVar("CogT", bound="Cog") @@ -80,8 +80,6 @@ class ApplicationContext(discord.abc.Messageable): The bot that the command belongs to. interaction: :class:`.Interaction` The interaction object that invoked the command. - command: :class:`.ApplicationCommand` - The command that this context belongs to. """ def __init__(self, bot: Bot, interaction: Interaction): @@ -89,7 +87,6 @@ def __init__(self, bot: Bot, interaction: Interaction): self.interaction = interaction # below attributes will be set after initialization - self.command: ApplicationCommand = None # type: ignore self.focused: Option = None # type: ignore self.value: str = None # type: ignore self.options: dict = None # type: ignore @@ -136,53 +133,62 @@ async def invoke( """ return await command(self, *args, **kwargs) - @cached_property + @property + def command(self) -> ApplicationCommand | None: + """The command that this context belongs to.""" + return self.interaction.command + + @command.setter + def command(self, value: ApplicationCommand | None) -> None: + self.interaction.command = value + + @property def channel(self) -> InteractionChannel | None: """Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]: Returns the channel associated with this context's command. Shorthand for :attr:`.Interaction.channel`. """ return self.interaction.channel - @cached_property + @property def channel_id(self) -> int | None: """Returns the ID of the channel associated with this context's command. Shorthand for :attr:`.Interaction.channel_id`. """ return self.interaction.channel_id - @cached_property + @property def guild(self) -> Guild | None: """Returns the guild associated with this context's command. Shorthand for :attr:`.Interaction.guild`. """ return self.interaction.guild - @cached_property + @property def guild_id(self) -> int | None: """Returns the ID of the guild associated with this context's command. Shorthand for :attr:`.Interaction.guild_id`. """ return self.interaction.guild_id - @cached_property + @property def locale(self) -> str | None: """Returns the locale of the guild associated with this context's command. Shorthand for :attr:`.Interaction.locale`. """ return self.interaction.locale - @cached_property + @property def guild_locale(self) -> str | None: """Returns the locale of the guild associated with this context's command. Shorthand for :attr:`.Interaction.guild_locale`. """ return self.interaction.guild_locale - @cached_property + @property def app_permissions(self) -> Permissions: return self.interaction.app_permissions - @cached_property + @property def me(self) -> Member | ClientUser | None: """Union[:class:`.Member`, :class:`.ClientUser`]: Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message @@ -190,14 +196,14 @@ def me(self) -> Member | ClientUser | None: """ return self.interaction.guild.me if self.interaction.guild is not None else self.bot.user - @cached_property + @property def message(self) -> Message | None: """Returns the message sent with this context's command. Shorthand for :attr:`.Interaction.message`, if applicable. """ return self.interaction.message - @cached_property + @property def user(self) -> Member | User: """Returns the user that sent this context's command. Shorthand for :attr:`.Interaction.user`. @@ -216,7 +222,7 @@ def voice_client(self) -> VoiceClient | None: return self.interaction.guild.voice_client - @cached_property + @property def response(self) -> InteractionResponse: """Returns the response object associated with this context's command. Shorthand for :attr:`.Interaction.response`. @@ -258,17 +264,17 @@ def unselected_options(self) -> list[Option] | None: return None @property - @discord.utils.copy_doc(InteractionResponse.send_modal) + @copy_doc(InteractionResponse.send_modal) def send_modal(self) -> Callable[..., Awaitable[Interaction]]: return self.interaction.response.send_modal @property - @discord.utils.copy_doc(Interaction.respond) + @copy_doc(Interaction.respond) def respond(self, *args, **kwargs) -> Callable[..., Awaitable[Interaction | WebhookMessage]]: return self.interaction.respond @property - @discord.utils.copy_doc(InteractionResponse.send_message) + @copy_doc(InteractionResponse.send_message) def send_response(self) -> Callable[..., Awaitable[Interaction]]: if not self.interaction.response.is_done(): return self.interaction.response.send_message @@ -278,7 +284,7 @@ def send_response(self) -> Callable[..., Awaitable[Interaction]]: ) @property - @discord.utils.copy_doc(Webhook.send) + @copy_doc(Webhook.send) def send_followup(self) -> Callable[..., Awaitable[WebhookMessage]]: if self.interaction.response.is_done(): return self.followup.send @@ -288,7 +294,7 @@ def send_followup(self) -> Callable[..., Awaitable[WebhookMessage]]: ) @property - @discord.utils.copy_doc(InteractionResponse.defer) + @copy_doc(InteractionResponse.defer) def defer(self) -> Callable[..., Awaitable[None]]: return self.interaction.response.defer @@ -322,7 +328,7 @@ async def delete(self, *, delay: float | None = None) -> None: return await self.interaction.delete_original_response(delay=delay) @property - @discord.utils.copy_doc(Interaction.edit_original_response) + @copy_doc(Interaction.edit_original_response) def edit(self) -> Callable[..., Awaitable[InteractionMessage]]: return self.interaction.edit_original_response @@ -384,8 +390,6 @@ class AutocompleteContext: The bot that the command belongs to. interaction: :class:`.Interaction` The interaction object that invoked the autocomplete. - command: :class:`.ApplicationCommand` - The command that this context belongs to. focused: :class:`.Option` The option the user is currently typing. value: :class:`.str` @@ -394,13 +398,12 @@ class AutocompleteContext: A name to value mapping of the options that the user has selected before this option. """ - __slots__ = ("bot", "interaction", "command", "focused", "value", "options") + __slots__ = ("bot", "interaction", "focused", "value", "options") def __init__(self, bot: Bot, interaction: Interaction): self.bot = bot self.interaction = interaction - self.command: ApplicationCommand = None # type: ignore self.focused: Option = None # type: ignore self.value: str = None # type: ignore self.options: dict = None # type: ignore @@ -414,3 +417,12 @@ def cog(self) -> Cog | None: return None return self.command.cog + + @property + def command(self) -> ApplicationCommand | None: + """The command that this context belongs to.""" + return self.interaction.command + + @command.setter + def command(self, value: ApplicationCommand | None) -> None: + self.interaction.command = value diff --git a/discord/commands/core.py b/discord/commands/core.py index 1920610826..af405bcecd 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -68,7 +68,8 @@ from ..role import Role from ..threads import Thread from ..user import User -from ..utils import MISSING, async_all, find, maybe_coroutine, utcnow, warn_deprecated +from ..utils import MISSING, find, utcnow +from ..utils.private import warn_deprecated, async_all, maybe_awaitable from .context import ApplicationContext, AutocompleteContext from .options import Option, OptionChoice @@ -432,7 +433,7 @@ async def can_run(self, ctx: ApplicationContext) -> bool: if cog is not None: local_check = cog._get_overridden_method(cog.cog_check) if local_check is not None: - ret = await maybe_coroutine(local_check, ctx) + ret = await maybe_awaitable(local_check, ctx) if not ret: return False diff --git a/discord/components.py b/discord/components.py index 0d9a927a66..999efa7a8c 100644 --- a/discord/components.py +++ b/discord/components.py @@ -25,20 +25,58 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar - -from .enums import ButtonStyle, ChannelType, ComponentType, InputTextStyle, try_enum -from .partial_emoji import PartialEmoji, _EmojiTag -from .utils import MISSING, Undefined, get_slots +from typing import ( + TYPE_CHECKING, + cast, + ClassVar, + TypeVar, + Generic, + TypeAlias, + Literal, + overload, +) +from collections.abc import Iterator, Sequence +from typing_extensions import override +from abc import ABC, abstractmethod + + +from .asset import AssetMixin +from .colour import Colour +from .enums import ( + ButtonStyle, + ComponentType, + InputTextStyle, + SeparatorSpacingSize, + try_enum, +) +from .flags import AttachmentFlags +from .partial_emoji import PartialEmoji, _EmojiTag # pyright: ignore[reportPrivateUsage] +from .utils import MISSING, Undefined +from .state import ConnectionState if TYPE_CHECKING: + from typing_extensions import Self from .emoji import AppEmoji, GuildEmoji from .types.components import ActionRow as ActionRowPayload from .types.components import ButtonComponent as ButtonComponentPayload from .types.components import Component as ComponentPayload + from .types.components import ContainerComponent as ContainerComponentPayload + from .types.components import FileComponent as FileComponentPayload from .types.components import InputText as InputTextComponentPayload - from .types.components import SelectMenu as SelectMenuPayload + from .types.components import MediaGalleryComponent as MediaGalleryComponentPayload + from .types.components import MediaGalleryItem as MediaGalleryItemPayload + from .types.components import SectionComponent as SectionComponentPayload + from .types.components import StringSelect as StringSelectPayload + from .types.components import ChannelSelect as ChannelSelectPayload + from .types.components import RoleSelect as RoleSelectPayload + from .types.components import MentionableSelect as MentionableSelectPayload + from .types.components import UserSelect as UserSelectPayload from .types.components import SelectOption as SelectOptionPayload + from .types.components import SeparatorComponent as SeparatorComponentPayload + from .types.components import TextDisplayComponent as TextDisplayComponentPayload + from .types.components import ThumbnailComponent as ThumbnailComponentPayload + from .types.components import UnfurledMediaItem as UnfurledMediaItemPayload + from .types.components import SelectDefaultValue __all__ = ( "Component", @@ -47,19 +85,38 @@ "SelectMenu", "SelectOption", "InputText", + "Section", + "TextDisplay", + "Thumbnail", + "MediaGallery", + "MediaGalleryItem", + "UnfurledMediaItem", + "FileComponent", + "Separator", + "Container", ) -C = TypeVar("C", bound="Component") + +AnyEmoji = GuildEmoji | AppEmoji | PartialEmoji +P = TypeVar("P", bound="ComponentPayload", covariant=True) +C = TypeVar("C", bound="Component[ComponentPayload]", covariant=True) -class Component: +class Component(ABC, Generic[P]): """Represents a Discord Bot UI Kit Component. - Currently, the only components supported by Discord are: + The components supported by Discord in messages are as follows: - :class:`ActionRow` - :class:`Button` - :class:`SelectMenu` + - :class:`Section` + - :class:`TextDisplay` + - :class:`Thumbnail` + - :class:`MediaGallery` + - :class:`FileComponent` + - :class:`Separator` + - :class:`Container` This class is abstract and cannot be instantiated. @@ -69,66 +126,86 @@ class Component: ---------- type: :class:`ComponentType` The type of component. + id: :class:`int` + The component's ID. If not provided by the user, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. """ - __slots__: tuple[str, ...] = ("type",) + __slots__: tuple[str, ...] = ("type", "id") # pyright: ignore[reportIncompatibleUnannotatedOverride] __repr_info__: ClassVar[tuple[str, ...]] type: ComponentType + versions: tuple[int, ...] + + def __init__(self, id: int | None = None) -> None: + self.id: int | None = id + @override def __repr__(self) -> str: attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__repr_info__) return f"<{self.__class__.__name__} {attrs}>" + @abstractmethod + def to_dict(self) -> P: ... + + @abstractmethod @classmethod - def _raw_construct(cls: type[C], **kwargs) -> C: - self: C = cls.__new__(cls) - for slot in get_slots(cls): - try: - value = kwargs[slot] - except KeyError: - pass - else: - setattr(self, slot, value) - return self + def from_payload(cls, payload: P) -> Self: ... # pyright: ignore[reportGeneralTypeIssues] - def to_dict(self) -> dict[str, Any]: - raise NotImplementedError + def is_v2(self) -> bool: + """Whether this component was introduced in Components V2.""" + return bool(self.versions and 1 not in self.versions) + def any_is_v2(self) -> bool: + """Whether this component or any of its children were introduced in Components V2.""" + return self.is_v2() -class ActionRow(Component): - """Represents a Discord Bot UI Kit Action Row. + def is_dispatchable(self) -> bool: + """Wether this component can be interacted with and lead to a :class:`Interaction`""" + return False - This is a component that holds up to 5 children components in a row. + def any_is_dispatchable(self) -> bool: + """Whether this component or any of its children can be interacted with and lead to a :class:`Interaction`""" + return self.is_dispatchable() - This inherits from :class:`Component`. +class StateComponent(Component[P], ABC): + @abstractmethod + @classmethod + @override + def from_payload(cls, payload: P, state: ConnectionState | None = None) -> Self: # pyright: ignore[reportGeneralTypeIssues] + ... - .. versionadded:: 2.0 - Attributes - ---------- - type: :class:`ComponentType` - The type of component. - children: List[:class:`Component`] - The children components that this holds, if any. +class WalkableComponent(Component[P], ABC, Generic[P, C]): + """A component that can be walked through. + + This is an abstract class and cannot be instantiated directly. + It is used to represent components that can be walked through, such as :class:`ActionRow`, :class:`Container` and :class:`Section`. """ - __slots__: tuple[str, ...] = ("children",) + __slots__: tuple[str, ...] = ("components",) # pyright: ignore[reportIncompatibleUnannotatedOverride] + components: list[C] - __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + def walk_components(self) -> Iterator[C]: + """Walks through the components in this component.""" + for component in self.components: + if isinstance(component, WalkableComponent): + yield from component.walk_components() + else: + yield component - def __init__(self, data: ComponentPayload): - self.type: ComponentType = try_enum(ComponentType, data["type"]) - self.children: list[Component] = [_component_factory(d) for d in data.get("components", [])] + @override + def any_is_v2(self) -> bool: + """Whether this component or any of its children were introduced in Components V2.""" + return self.is_v2() or any(c.any_is_v2() for c in self.walk_components()) - def to_dict(self) -> ActionRowPayload: - return { - "type": int(self.type), - "components": [child.to_dict() for child in self.children], - } # type: ignore + @override + def any_is_dispatchable(self) -> bool: + """Whether this component or any of its children can be interacted with and lead to a :class:`Interaction`""" + return self.is_dispatchable() or any(c.any_is_dispatchable() for c in self.walk_components()) -class InputText(Component): +class InputText(Component[InputTextComponentPayload]): """Represents an Input Text field from the Discord Bot UI Kit. This inherits from :class:`Component`. @@ -137,7 +214,7 @@ class InputText(Component): style: :class:`.InputTextStyle` The style of the input text field. custom_id: Optional[:class:`str`] - The ID of the input text field that gets received during an interaction. + The custom ID of the input text field that gets received during an interaction. label: :class:`str` The label for the input text field. placeholder: Optional[:class:`str`] @@ -151,10 +228,11 @@ class InputText(Component): Whether the input text field is required or not. Defaults to `True`. value: Optional[:class:`str`] The value that has been entered in the input text field. + id: Optional[:class:`int`] + The input text's ID. """ __slots__: tuple[str, ...] = ( - "type", "style", "custom_id", "label", @@ -166,21 +244,60 @@ class InputText(Component): ) __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (1, 2) + type: Literal[ComponentType.input_text] = ComponentType.input_text # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, + style: int | InputTextStyle, + custom_id: str, + label: str, + min_lenght: int | None = None, + max_length: int | None = None, + placeholder: str | None = None, + required: bool = True, + value: str | None = None, + id: int | None = None, + ) -> None: + self.style: InputTextStyle = style # pyright: ignore[reportAttributeAccessIssue] + self.custom_id: str = custom_id + self.label: str = label + self.min_length: int | None = min_lenght + self.max_length: int | None = max_length + self.placeholder: str | None = placeholder + self.required: bool = required + self.value: str | None = value + super().__init__(id=id) + + @classmethod + @override + def from_payload(cls, payload: InputTextComponentPayload) -> Self: + style = try_enum(InputTextStyle, payload["style"]) + custom_id = payload["custom_id"] + label = payload["label"] + min_length = payload.get("min_length") + max_length = payload.get("max_length") + placeholder = payload.get("placeholder") + required = payload.get("required", True) + value = payload.get("value") - def __init__(self, data: InputTextComponentPayload): - self.type = ComponentType.input_text - self.style: InputTextStyle = try_enum(InputTextStyle, data["style"]) - self.custom_id = data["custom_id"] - self.label: str = data.get("label", None) - self.placeholder: str | None = data.get("placeholder", None) - self.min_length: int | None = data.get("min_length", None) - self.max_length: int | None = data.get("max_length", None) - self.required: bool = data.get("required", True) - self.value: str | None = data.get("value", None) + return cls( + style=style, + custom_id=custom_id, + label=label, + min_lenght=min_length, + max_length=max_length, + placeholder=placeholder, + required=required, + value=value, + id=payload.get("id"), + ) + @override def to_dict(self) -> InputTextComponentPayload: - payload = { - "type": 4, + payload: InputTextComponentPayload = { # pyright: ignore[reportAssignmentType] + "type": int(self.type), + "id": self.id, "style": self.style.value, "label": self.label, } @@ -205,16 +322,11 @@ def to_dict(self) -> InputTextComponentPayload: return payload # type: ignore -class Button(Component): +class Button(Component[ButtonComponentPayload]): """Represents a button from the Discord Bot UI Kit. This inherits from :class:`Component`. - .. note:: - - The user constructible and usable type to create a button is :class:`discord.ui.Button` - not this one. - .. versionadded:: 2.0 Attributes @@ -234,6 +346,34 @@ class Button(Component): The emoji of the button, if available. sku_id: Optional[:class:`int`] The ID of the SKU this button refers to. + id: Optional[:class:`int`] + The button's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + + Parameters + ---------- + style: :class:`.ButtonStyle` + The style of the button. + custom_id: Optional[:class:`str`] + The ID of the button that gets received during an interaction. + Cannot be used with :class:`ButtonStyle.url` or :class:`ButtonStyle.premium`. + label: Optional[:class:`str`] + The label of the button, if any. + Cannot be used with :class:`ButtonStyle.premium`. + emoji: Optional[:class:`str` | :class:`PartialEmoji`] + The emoji of the button, if available. + Cannot be used with :class:`ButtonStyle.premium`. + disabled: :class:`bool` + Whether the button is disabled or not. + url: Optional[:class:`str`] + The URL this button sends you to. + Can only be used with :class:`ButtonStyle.url`. + id: Optional[:class:`int`] + The button's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + sku_id: Optional[:class:`int`] + The ID of the SKU this button refers to. + Can only be used with :class:`ButtonStyle.premium`. """ __slots__: tuple[str, ...] = ( @@ -247,24 +387,129 @@ class Button(Component): ) __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (1, 2) + type: Literal[ComponentType.button] = ComponentType.button # pyright: ignore[reportIncompatibleVariableOverride] + width: Literal[1] = 1 + + # Premium button + @overload + def __init__( + self, + style: Literal[ButtonStyle.premium], + *, + sku_id: int, + disabled: bool = False, + id: int | None = None, + ) -> None: ... + + # URL button with label + @overload + def __init__( + self, + style: Literal[ButtonStyle.url], + *, + label: str, + emoji: str | AnyEmoji | None = None, + disabled: bool = False, + url: str, + id: int | None = None, + ) -> None: ... + + # URL button with emoji + @overload + def __init__( + self, + style: Literal[ButtonStyle.url], + *, + emoji: str | AnyEmoji, + label: str | None = None, + disabled: bool = False, + url: str, + id: int | None = None, + ) -> None: ... + + # Interactive button with label + @overload + def __init__( + self, + style: Literal[ButtonStyle.primary, ButtonStyle.secondary, ButtonStyle.success, ButtonStyle.danger], + *, + custom_id: str, + label: str, + emoji: str | AnyEmoji | None = None, + disabled: bool = False, + id: int | None = None, + ) -> None: ... + + # Interactive button with emoji + @overload + def __init__( + self, + style: Literal[ButtonStyle.primary, ButtonStyle.secondary, ButtonStyle.success, ButtonStyle.danger], + *, + custom_id: str, + emoji: str | AnyEmoji, + label: str | None = None, + disabled: bool = False, + id: int | None = None, + ) -> None: ... - def __init__(self, data: ButtonComponentPayload): - self.type: ComponentType = try_enum(ComponentType, data["type"]) - self.style: ButtonStyle = try_enum(ButtonStyle, data["style"]) - self.custom_id: str | None = data.get("custom_id") - self.url: str | None = data.get("url") - self.disabled: bool = data.get("disabled", False) - self.label: str | None = data.get("label") + def __init__( + self, + style: int | ButtonStyle, + custom_id: str | None = None, + label: str | None = None, + emoji: str | AnyEmoji | None = None, + disabled: bool = False, + url: str | None = None, + id: int | None = None, + sku_id: int | None = None, + ) -> None: + self.style: ButtonStyle = try_enum(ButtonStyle, style) + self.custom_id: str | None = custom_id + self.url: str | None = url + self.disabled: bool = disabled + self.label: str | None = label self.emoji: PartialEmoji | None - try: - self.emoji = PartialEmoji.from_dict(data["emoji"]) - except KeyError: - self.emoji = None - self.sku_id: str | None = data.get("sku_id") + if isinstance(emoji, _EmojiTag): + self.emoji = emoji._to_partial() # pyright: ignore[reportPrivateUsage] + elif isinstance(emoji, str): + self.emoji = PartialEmoji.from_str(emoji) + else: + self.emoji = emoji + self.sku_id: int | None = sku_id + super().__init__(id=id) + + @classmethod + @override + def from_payload(cls, payload: ButtonComponentPayload) -> Self: + style = try_enum(ButtonStyle, payload["style"]) + custom_id = payload.get("custom_id") + label = payload.get("label") + emoji = payload.get("emoji") + disabled = payload.get("disabled", False) + url = payload.get("url") + sku_id = payload.get("sku_id") + + if emoji is not None: + emoji = PartialEmoji.from_dict(emoji) + + return cls( # pyright: ignore[reportCallIssue] + style=style, + custom_id=custom_id, + label=label, + emoji=emoji, + disabled=disabled, + url=url, + id=payload.get("id"), + sku_id=int(sku_id) if sku_id is not None else None, + ) + @override def to_dict(self) -> ButtonComponentPayload: - payload = { + payload: ButtonComponentPayload = { # pyright: ignore[reportAssignmentType] "type": 2, + "id": self.id, "style": int(self.style), "label": self.label, "disabled": self.disabled, @@ -276,7 +521,7 @@ def to_dict(self) -> ButtonComponentPayload: payload["url"] = self.url if self.emoji: - payload["emoji"] = self.emoji.to_dict() + payload["emoji"] = self.emoji.to_dict() # pyright: ignore[reportGeneralTypeIssues] if self.sku_id: payload["sku_id"] = self.sku_id @@ -284,91 +529,6 @@ def to_dict(self) -> ButtonComponentPayload: return payload # type: ignore -class SelectMenu(Component): - """Represents a select menu from the Discord Bot UI Kit. - - A select menu is functionally the same as a dropdown, however - on mobile it renders a bit differently. - - .. note:: - - The user constructible and usable type to create a select menu is - :class:`discord.ui.Select` not this one. - - .. versionadded:: 2.0 - - .. versionchanged:: 2.3 - - Added support for :attr:`ComponentType.user_select`, :attr:`ComponentType.role_select`, - :attr:`ComponentType.mentionable_select`, and :attr:`ComponentType.channel_select`. - - Attributes - ---------- - type: :class:`ComponentType` - The select menu's type. - custom_id: Optional[:class:`str`] - The ID of the select menu that gets received during an interaction. - placeholder: Optional[:class:`str`] - The placeholder text that is shown if nothing is selected, if any. - min_values: :class:`int` - The minimum number of items that must be chosen for this select menu. - Defaults to 1 and must be between 0 and 25. - max_values: :class:`int` - The maximum number of items that must be chosen for this select menu. - Defaults to 1 and must be between 1 and 25. - options: List[:class:`SelectOption`] - A list of options that can be selected in this menu. - Will be an empty list for all component types - except for :attr:`ComponentType.string_select`. - channel_types: List[:class:`ChannelType`] - A list of channel types that can be selected. - Will be an empty list for all component types - except for :attr:`ComponentType.channel_select`. - disabled: :class:`bool` - Whether the select is disabled or not. - """ - - __slots__: tuple[str, ...] = ( - "custom_id", - "placeholder", - "min_values", - "max_values", - "options", - "channel_types", - "disabled", - ) - - __repr_info__: ClassVar[tuple[str, ...]] = __slots__ - - def __init__(self, data: SelectMenuPayload): - self.type = try_enum(ComponentType, data["type"]) - self.custom_id: str = data["custom_id"] - self.placeholder: str | None = data.get("placeholder") - self.min_values: int = data.get("min_values", 1) - self.max_values: int = data.get("max_values", 1) - self.disabled: bool = data.get("disabled", False) - self.options: list[SelectOption] = [SelectOption.from_dict(option) for option in data.get("options", [])] - self.channel_types: list[ChannelType] = [try_enum(ChannelType, ct) for ct in data.get("channel_types", [])] - - def to_dict(self) -> SelectMenuPayload: - payload: SelectMenuPayload = { - "type": self.type.value, - "custom_id": self.custom_id, - "min_values": self.min_values, - "max_values": self.max_values, - "disabled": self.disabled, - } - - if self.type is ComponentType.string_select: - payload["options"] = [op.to_dict() for op in self.options] - if self.type is ComponentType.channel_select and self.channel_types: - payload["channel_types"] = [ct.value for ct in self.channel_types] - if self.placeholder: - payload["placeholder"] = self.placeholder - - return payload - - class SelectOption: """Represents a :class:`discord.SelectMenu`'s option. @@ -406,7 +566,7 @@ def __init__( label: str, value: str | Undefined = MISSING, description: str | None = None, - emoji: str | GuildEmoji | AppEmoji | PartialEmoji | None = None, + emoji: str | AnyEmoji | None = None, default: bool = False, ) -> None: if len(label) > 100: @@ -418,12 +578,13 @@ def __init__( if description is not None and len(description) > 100: raise ValueError("description must be 100 characters or fewer") - self.label = label - self.value = label if value is MISSING else value - self.description = description + self.label: str = label + self.value: str = label if value is MISSING else value + self.description: str | None = description self.emoji = emoji - self.default = default + self.default: bool = default + @override def __repr__(self) -> str: return ( " str: f"emoji={self.emoji!r} default={self.default!r}>" ) + @override def __str__(self) -> str: base = f"{self.emoji} {self.label}" if self.emoji else self.label if self.description: @@ -438,29 +600,29 @@ def __str__(self) -> str: return base @property - def emoji(self) -> str | GuildEmoji | AppEmoji | PartialEmoji | None: + def emoji(self) -> PartialEmoji | None: """The emoji of the option, if available.""" return self._emoji @emoji.setter - def emoji(self, value) -> None: + def emoji(self, value: str | AnyEmoji | None) -> None: # pyright: ignore[reportPropertyTypeMismatch] if value is not None: if isinstance(value, str): value = PartialEmoji.from_str(value) - elif isinstance(value, _EmojiTag): - value = value._to_partial() + elif isinstance(value, _EmojiTag): # pyright: ignore[reportUnnecessaryIsInstance] + value = value._to_partial() # pyright: ignore[reportPrivateUsage] else: - raise TypeError( - f"expected emoji to be str, GuildEmoji, AppEmoji, or PartialEmoji, not {value.__class__}" + raise TypeError( # pyright: ignore[reportUnreachable] + f"expected emoji to be None, str, GuildEmoji, AppEmoji, or PartialEmoji, not {value.__class__}" ) - self._emoji = value + self._emoji: PartialEmoji | None = value @classmethod def from_dict(cls, data: SelectOptionPayload) -> SelectOption: - try: - emoji = PartialEmoji.from_dict(data["emoji"]) - except KeyError: + if e := data.get("emoji"): + emoji = PartialEmoji.from_dict(e) + else: emoji = None return cls( @@ -479,7 +641,7 @@ def to_dict(self) -> SelectOptionPayload: } if self.emoji: - payload["emoji"] = self.emoji.to_dict() # type: ignore + payload["emoji"] = self.emoji.to_dict() # type: ignore # pyright: ignore[reportGeneralTypeIssues] if self.description: payload["description"] = self.description @@ -487,16 +649,1369 @@ def to_dict(self) -> SelectOptionPayload: return payload -def _component_factory(data: ComponentPayload) -> Component: - component_type = data["type"] - if component_type == 1: - return ActionRow(data) - elif component_type == 2: - return Button(data) # type: ignore - elif component_type == 4: - return InputText(data) # type: ignore - elif component_type in (3, 5, 6, 7, 8): - return SelectMenu(data) # type: ignore - else: - as_enum = try_enum(ComponentType, component_type) - return Component._raw_construct(type=as_enum) +DT = TypeVar("DT", bound='Literal["user", "role", "channel"]') + + +class DefaultSelectOption(Generic[DT]): + """ + Represents a default select menu option. + Can only be used :class:`UserSelectMenu`, :class:`RoleSelectMenu`, and :class:`MentionableSelectMenu`. + + .. versionadded:: 3.0 + + Attributes + ---------- + id: :class:`int` + The ID of the default option. + type: :class:`str` + The type of the default option. This can be either "user", "role", or "channel". + This is used to determine which type of select menu this option belongs to. + """ + + __slots__: tuple[str, ...] = ("id", "type") + + def __init__( + self, + id: int, + type: DT, + ) -> None: + self.id: int = id + self.type: DT = type + + @override + def __repr__(self) -> str: + return f"" + + @classmethod + def from_payload(cls, payload: SelectDefaultValue[DT]) -> DefaultSelectOption[DT]: + """Creates a DefaultSelectOption from a dictionary.""" + return cls( + id=payload["id"], + type=payload["type"], + ) + + def to_dict(self) -> SelectDefaultValue[DT]: + """Converts the DefaultSelectOption to a dictionary.""" + return { + "id": self.id, + "type": self.type, + } + + +SelectMenuTypes = ( + StringSelectPayload | ChannelSelectPayload | RoleSelectPayload | MentionableSelectPayload | UserSelectPayload +) + +T = TypeVar( + "T", + bound=SelectMenuTypes, +) + + +class SelectMenu(Component[T], ABC, Generic[T]): + """Represents a select menu from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + This is an abstract class and cannot be instantiated directly. + + .. versionadded:: 3.0 + + """ + + __slots__: tuple[str, ...] = ( # pyright: ignore[reportIncompatibleUnannotatedOverride] + "custom_id", + "placeholder", + "min_values", + "max_values", + "disabled", + ) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (1, 2) + type: Literal[ # pyright: ignore[reportIncompatibleVariableOverride] + ComponentType.string_select, + ComponentType.channel_select, + ComponentType.role_select, + ComponentType.mentionable_select, + ComponentType.user_select, + ] + width: Literal[5] = 5 + + def __init__( + self, + custom_id: str, + *, + placeholder: str | None = None, + min_values: int = 1, + max_values: int = 1, + disabled: bool = False, + id: int | None = None, + ): + self.custom_id: str = custom_id + self.placeholder: str | None = placeholder + self.min_values: int = min_values + self.max_values: int = max_values + self.disabled: bool = disabled + super().__init__(id=id) + + +class StringSelectMenu(SelectMenu[StringSelectPayload]): + """Represents a string select menu from the Discord Bot UI Kit. + + This inherits from :class:`SelectMenu`. + + .. versionadded:: 3.0 + + Attributes + ---------- + options: List[:class:`SelectOption`] + The options available in this select menu. + custom_id: :class:`str` + The custom ID of the select menu that gets received during an interaction. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of values that must be selected. + Defaults to 1. + max_values: :class:`int` + The maximum number of values that can be selected. + Defaults to 1. + disabled: :class:`bool` + Whether the select menu is disabled or not. + Defaults to ``False``. + id: Optional[:class:`int`] + The select menu's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + + Parameters + ---------- + custom_id: :class:`str` + The custom ID of the select menu that gets received during an interaction. + options: Sequence[:class:`SelectOption`] + The options available in this select menu. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of values that must be selected. + Defaults to 1. + max_values: :class:`int` + The maximum number of values that can be selected. + Defaults to 1. + disabled: :class:`bool` + Whether the select menu is disabled or not. Defaults to ``False``. + id: Optional[:class:`int`] + The select menu's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + """ + + __slots__: tuple[str, ...] = ("options",) + type: Literal[ComponentType.string_select] = ComponentType.string_select # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, + custom_id: str, + options: Sequence[SelectOption], + *, + placeholder: str | None = None, + min_values: int = 1, + max_values: int = 1, + disabled: bool = False, + id: int | None = None, + ): + super().__init__( + custom_id=custom_id, + placeholder=placeholder, + min_values=min_values, + max_values=max_values, + disabled=disabled, + id=id, + ) + self.options: list[SelectOption] = list(options) + + @classmethod + @override + def from_payload(cls, payload: StringSelectPayload) -> Self: + options = [SelectOption.from_dict(option) for option in payload["options"]] + return cls( + custom_id=payload["custom_id"], + options=options, + placeholder=payload.get("placeholder"), + min_values=payload.get("min_values", 1), + max_values=payload.get("max_values", 1), + disabled=payload.get("disabled", False), + id=payload.get("id"), + ) + + @override + def to_dict(self) -> StringSelectPayload: + payload: StringSelectPayload = { # pyright: ignore[reportAssignmentType] + "type": int(self.type), + "id": self.id, + "custom_id": self.custom_id, + "options": [option.to_dict() for option in self.options], + "min_values": self.min_values, + "max_values": self.max_values, + } + if self.placeholder: + payload["placeholder"] = self.placeholder + + if self.disabled: + payload["disabled"] = self.disabled + + return payload + + +class UserSelectMenu(SelectMenu[UserSelectPayload]): + """Represents a user select menu from the Discord Bot UI Kit. + + This inherits from :class:`SelectMenu`. + + .. versionadded:: 3.0 + + Attributes + ---------- + default_values: List[:class:`DefaultSelectOption[Literal["user"]]`] + The default selected values of the select menu. + custom_id: :class:`str` + The custom ID of the select menu that gets received during an interaction. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of values that must be selected. + Defaults to 1. + max_values: :class:`int` + The maximum number of values that can be selected. + Defaults to 1. + disabled: :class:`bool` + Whether the select menu is disabled or not. + Defaults to ``False``. + id: Optional[:class:`int`] + The select menu's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + + Parameters + ---------- + default_values: Sequence[:class:`DefaultSelectOption[Literal["user"]]`] + The default selected values of the select menu. + custom_id: :class:`str` + The custom ID of the select menu that gets received during an interaction. + options: Sequence[:class:`SelectOption`] + The options available in this select menu. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of values that must be selected. + Defaults to 1. + max_values: :class:`int` + The maximum number of values that can be selected. + Defaults to 1. + disabled: :class:`bool` + Whether the select menu is disabled or not. Defaults to ``False``. + id: Optional[:class:`int`] + The select menu's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + """ + + __slots__: tuple[str, ...] = ("default_values",) + type: Literal[ComponentType.user_select] = ComponentType.user_select # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, + *, + default_values: Sequence[DefaultSelectOption[Literal["user"]]] | None = None, + custom_id: str, + placeholder: str | None = None, + min_values: int = 1, + max_values: int = 1, + disabled: bool = False, + id: int | None = None, + ): + super().__init__( + custom_id=custom_id, + placeholder=placeholder, + min_values=min_values, + max_values=max_values, + disabled=disabled, + id=id, + ) + self.default_values: list[DefaultSelectOption[Literal["user"]]] = ( + list(default_values) if default_values is not None else [] + ) + + @classmethod + @override + def from_payload(cls, payload: UserSelectPayload) -> Self: + default_values: list[DefaultSelectOption[Literal["user"]]] = [ + DefaultSelectOption.from_payload(value) for value in payload.get("default_values", []) + ] + return cls( + custom_id=payload["custom_id"], + placeholder=payload.get("placeholder"), + min_values=payload.get("min_values", 1), + max_values=payload.get("max_values", 1), + disabled=payload.get("disabled", False), + id=payload.get("id"), + default_values=default_values, + ) + + @override + def to_dict(self) -> UserSelectPayload: + payload: UserSelectPayload = { # pyright: ignore[reportAssignmentType] + "type": int(self.type), + "id": self.id, + "custom_id": self.custom_id, + "min_values": self.min_values, + "max_values": self.max_values, + } + if self.placeholder: + payload["placeholder"] = self.placeholder + + if self.disabled: + payload["disabled"] = self.disabled + + if self.default_values: + payload["default_values"] = [value.to_dict() for value in self.default_values] + + return payload + + +class RoleSelectMenu(SelectMenu[RoleSelectPayload]): + """Represents a role select menu from the Discord Bot UI Kit. + + This inherits from :class:`SelectMenu`. + + .. versionadded:: 3.0 + + Attributes + ---------- + default_values: List[:class:`DefaultSelectOption[Literal["role"]]`] + The default selected values of the select menu. + custom_id: :class:`str` + The custom ID of the select menu that gets received during an interaction. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of values that must be selected. + Defaults to 1. + max_values: :class:`int` + The maximum number of values that can be selected. + Defaults to 1. + disabled: :class:`bool` + Whether the select menu is disabled or not. + Defaults to ``False``. + id: Optional[:class:`int`] + The select menu's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + + Parameters + ---------- + default_values: Sequence[:class:`DefaultSelectOption[Literal["role"]]`] + The default selected values of the select menu. + custom_id: :class:`str` + The custom ID of the select menu that gets received during an interaction. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of values that must be selected. + Defaults to 1. + max_values: :class:`int` + The maximum number of values that can be selected. + Defaults to 1. + disabled: :class:`bool` + Whether the select menu is disabled or not. Defaults to ``False``. + id: Optional[:class:`int`] + The select menu's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + """ + + __slots__: tuple[str, ...] = ("default_values",) + type: Literal[ComponentType.role_select] = ComponentType.role_select # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, + *, + default_values: Sequence[DefaultSelectOption[Literal["role"]]] | None = None, + custom_id: str, + placeholder: str | None = None, + min_values: int = 1, + max_values: int = 1, + disabled: bool = False, + id: int | None = None, + ): + super().__init__( + custom_id=custom_id, + placeholder=placeholder, + min_values=min_values, + max_values=max_values, + disabled=disabled, + id=id, + ) + self.default_values: list[DefaultSelectOption[Literal["role"]]] = ( + list(default_values) if default_values is not None else [] + ) + + @classmethod + @override + def from_payload(cls, payload: RoleSelectPayload) -> Self: + default_values: list[DefaultSelectOption[Literal["role"]]] = [ + DefaultSelectOption.from_payload(value) for value in payload.get("default_values", []) + ] + return cls( + custom_id=payload["custom_id"], + placeholder=payload.get("placeholder"), + min_values=payload.get("min_values", 1), + max_values=payload.get("max_values", 1), + disabled=payload.get("disabled", False), + id=payload.get("id"), + default_values=default_values, + ) + + @override + def to_dict(self) -> RoleSelectPayload: + payload: RoleSelectPayload = { # pyright: ignore[reportAssignmentType] + "type": int(self.type), + "id": self.id, + "custom_id": self.custom_id, + "min_values": self.min_values, + "max_values": self.max_values, + } + if self.placeholder: + payload["placeholder"] = self.placeholder + + if self.disabled: + payload["disabled"] = self.disabled + + if self.default_values: + payload["default_values"] = [value.to_dict() for value in self.default_values] + + return payload + + +class MentionableSelectMenu(SelectMenu[MentionableSelectPayload]): + """Represents a mentionable select menu from the Discord Bot UI Kit. + + This inherits from :class:`SelectMenu`. + + .. versionadded:: 3.0 + + Attributes + ---------- + default_values: List[:class:`DefaultSelectOption[Literal["role", "user"]]`] + The default selected values of the select menu. + custom_id: :class:`str` + The custom ID of the select menu that gets received during an interaction. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of values that must be selected. + Defaults to 1. + max_values: :class:`int` + The maximum number of values that can be selected. + Defaults to 1. + disabled: :class:`bool` + Whether the select menu is disabled or not. + Defaults to ``False``. + id: Optional[:class:`int`] + The select menu's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + + Parameters + ---------- + default_values: Sequence[:class:`DefaultSelectOption[Literal["role", "user"]]`] + The default selected values of the select menu. + custom_id: :class:`str` + The custom ID of the select menu that gets received during an interaction. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of values that must be selected. + Defaults to 1. + max_values: :class:`int` + The maximum number of values that can be selected. + Defaults to 1. + disabled: :class:`bool` + Whether the select menu is disabled or not. Defaults to ``False``. + id: Optional[:class:`int`] + The select menu's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + """ + + __slots__: tuple[str, ...] = ("default_values",) + type: Literal[ComponentType.mentionable_select] = ComponentType.mentionable_select # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, + *, + default_values: Sequence[DefaultSelectOption[Literal["role", "user"]]] | None = None, + custom_id: str, + placeholder: str | None = None, + min_values: int = 1, + max_values: int = 1, + disabled: bool = False, + id: int | None = None, + ): + super().__init__( + custom_id=custom_id, + placeholder=placeholder, + min_values=min_values, + max_values=max_values, + disabled=disabled, + id=id, + ) + self.default_values: list[DefaultSelectOption[Literal["role", "user"]]] = ( + list(default_values) if default_values is not None else [] + ) + + @classmethod + @override + def from_payload(cls, payload: MentionableSelectPayload) -> Self: + default_values: list[DefaultSelectOption[Literal["role", "user"]]] = [ + DefaultSelectOption.from_payload(value) for value in payload.get("default_values", []) + ] + return cls( + custom_id=payload["custom_id"], + placeholder=payload.get("placeholder"), + min_values=payload.get("min_values", 1), + max_values=payload.get("max_values", 1), + disabled=payload.get("disabled", False), + id=payload.get("id"), + default_values=default_values, + ) + + @override + def to_dict(self) -> MentionableSelectPayload: + payload: MentionableSelectPayload = { # pyright: ignore[reportAssignmentType] + "type": int(self.type), + "id": self.id, + "custom_id": self.custom_id, + "min_values": self.min_values, + "max_values": self.max_values, + } + if self.placeholder: + payload["placeholder"] = self.placeholder + + if self.disabled: + payload["disabled"] = self.disabled + + if self.default_values: + payload["default_values"] = [value.to_dict() for value in self.default_values] + + return payload + + +class ChannelSelectMenu(SelectMenu[ChannelSelectPayload]): + """Represents a channel select menu from the Discord Bot UI Kit. + + This inherits from :class:`SelectMenu`. + + .. versionadded:: 3.0 + + Attributes + ---------- + default_values: List[:class:`DefaultSelectOption[Literal["channel"]]`] + The default selected values of the select menu. + custom_id: :class:`str` + The custom ID of the select menu that gets received during an interaction. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of values that must be selected. + Defaults to 1. + max_values: :class:`int` + The maximum number of values that can be selected. + Defaults to 1. + disabled: :class:`bool` + Whether the select menu is disabled or not. + Defaults to ``False``. + id: Optional[:class:`int`] + The select menu's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + + Parameters + ---------- + default_values: Sequence[:class:`DefaultSelectOption[Literal["channel"]]`] + The default selected values of the select menu. + custom_id: :class:`str` + The custom ID of the select menu that gets received during an interaction. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of values that must be selected. + Defaults to 1. + max_values: :class:`int` + The maximum number of values that can be selected. + Defaults to 1. + disabled: :class:`bool` + Whether the select menu is disabled or not. Defaults to ``False``. + id: Optional[:class:`int`] + The select menu's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + """ + + __slots__: tuple[str, ...] = ("default_values",) + type: Literal[ComponentType.channel_select] = ComponentType.channel_select # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, + *, + default_values: Sequence[DefaultSelectOption[Literal["channel"]]] | None = None, + custom_id: str, + placeholder: str | None = None, + min_values: int = 1, + max_values: int = 1, + disabled: bool = False, + id: int | None = None, + ): + super().__init__( + custom_id=custom_id, + placeholder=placeholder, + min_values=min_values, + max_values=max_values, + disabled=disabled, + id=id, + ) + self.default_values: list[DefaultSelectOption[Literal["channel"]]] = ( + list(default_values) if default_values is not None else [] + ) + + @classmethod + @override + def from_payload(cls, payload: ChannelSelectPayload) -> Self: + default_values: list[DefaultSelectOption[Literal["channel"]]] = [ + DefaultSelectOption.from_payload(value) for value in payload.get("default_values", []) + ] + return cls( + custom_id=payload["custom_id"], + placeholder=payload.get("placeholder"), + min_values=payload.get("min_values", 1), + max_values=payload.get("max_values", 1), + disabled=payload.get("disabled", False), + id=payload.get("id"), + default_values=default_values, + ) + + @override + def to_dict(self) -> ChannelSelectPayload: + payload: ChannelSelectPayload = { # pyright: ignore[reportAssignmentType] + "type": int(self.type), + "id": self.id, + "custom_id": self.custom_id, + "min_values": self.min_values, + "max_values": self.max_values, + } + if self.placeholder: + payload["placeholder"] = self.placeholder + + if self.disabled: + payload["disabled"] = self.disabled + + if self.default_values: + payload["default_values"] = [value.to_dict() for value in self.default_values] + + return payload + + +class TextDisplay(Component[TextDisplayComponentPayload]): + """Represents a Text Display from Components V2. + + This is a component that displays text. + + This inherits from :class:`Component`. + + .. versionadded:: 2.7 + .. versionchanged:: 3.0 + + Attributes + ---------- + content: :class:`str` + The component's text content. + id: Optional[:class:`int`] + The component's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + + Parameters + ---------- + content: :class:`str` + The text content of the component. + id: Optional[:class:`int`] + The component's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + """ + + __slots__: tuple[str, ...] = ("content",) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (2,) + type: Literal[ComponentType.text_display] = ComponentType.text_display # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__(self, content: str, id: int | None = None): + self.content: str = content + super().__init__(id=id) + + @classmethod + @override + def from_payload(cls, payload: TextDisplayComponentPayload) -> Self: + return cls( + content=payload["content"], + id=payload.get("id"), + ) + + @override + def to_dict(self) -> TextDisplayComponentPayload: + return {"type": int(self.type), "id": self.id, "content": self.content} # pyright: ignore[reportReturnType] + + +class UnfurledMediaItem(AssetMixin): + """Represents an Unfurled Media Item used in Components V2. + + This is used as an underlying component for other media-based components such as :class:`Thumbnail`, :class:`FileComponent`, and :class:`MediaGalleryItem`. + + .. versionadded:: 2.7 + + Attributes + ---------- + url: :class:`str` + The URL of this media item. This can either be an arbitrary URL or an ``attachment://`` URL to work with local files. + """ + + def __init__(self, url: str): + self._state: ConnectionState | None = None + self._url: str = url + self.proxy_url: str | None = None + self.height: int | None = None + self.width: int | None = None + self.content_type: str | None = None + self.flags: AttachmentFlags | None = None + self.attachment_id: int | None = None + + @property + @override + def url(self) -> str: # pyright: ignore[reportIncompatibleVariableOverride] + """Returns this media item's url.""" + return self._url + + @classmethod + def from_dict(cls, data: UnfurledMediaItemPayload, state: ConnectionState | None = None) -> UnfurledMediaItem: + r = cls(data.get("url")) + r.proxy_url = data.get("proxy_url") + r.height = data.get("height") + r.width = data.get("width") + r.content_type = data.get("content_type") + r.flags = AttachmentFlags._from_value(data.get("flags", 0)) # pyright: ignore[reportPrivateUsage] + r.attachment_id = data.get("attachment_id") # pyright: ignore[reportAttributeAccessIssue] + r._state = state + return r + + def to_dict(self) -> UnfurledMediaItemPayload: + return {"url": self.url} # pyright: ignore[reportReturnType] + + +class Thumbnail(StateComponent[ThumbnailComponentPayload]): + """Represents a Thumbnail from Components V2. + + This is a component that displays media, such as images and videos. + + This inherits from :class:`Component`. + + .. versionadded:: 2.7 + .. versionchanged:: 3.0 + + Attributes + ---------- + media: :class:`UnfurledMediaItem` + The component's underlying media object. + description: Optional[:class:`str`] + The thumbnail's description, up to 1024 characters. + spoiler: Optional[:class:`bool`] + Whether the thumbnail has the spoiler overlay. + + Parameters + ---------- + url: :class:`str` | :class:`UnfurledMediaItem` + The URL of the thumbnail. This can either be an arbitrary URL or an ``attachment://`` URL to work with local files. + id: Optional[:class:`int`] + The thumbnail's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + description: Optional[:class:`str`] + The thumbnail's description, up to 1024 characters. + spoiler: Optional[:class:`bool`] + Whether the thumbnail has the spoiler overlay. Defaults to ``False``. + """ + + __slots__: tuple[str, ...] = ( + "file", + "description", + "spoiler", + ) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (2,) + type: Literal[ComponentType.thumbnail] = ComponentType.thumbnail # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, + url: str | UnfurledMediaItem, + *, + id: int | None = None, + description: str | None = None, + spoiler: bool | None = False, + ): + self.file: UnfurledMediaItem = url if isinstance(url, UnfurledMediaItem) else UnfurledMediaItem(url) + self.description: str | None = description + self.spoiler: bool | None = spoiler + super().__init__(id=id) + + @property + def url(self) -> str: + """Returns the URL of this thumbnail's underlying media item.""" + return self.file.url + + @classmethod + @override + def from_payload(cls, payload: ThumbnailComponentPayload, state: ConnectionState | None = None) -> Self: + file = UnfurledMediaItem.from_dict(payload.get("file", {}), state=state) + return cls( + url=file, + id=payload.get("id"), + description=payload.get("description"), + spoiler=payload.get("spoiler", False), + ) + + @override + def to_dict(self) -> ThumbnailComponentPayload: + payload: ThumbnailComponentPayload = {"type": self.type, "id": self.id, "media": self.file.to_dict()} # pyright: ignore[reportAssignmentType] + if self.description: + payload["description"] = self.description + if self.spoiler is not None: + payload["spoiler"] = self.spoiler + return payload + + +AllowedSectionComponents: TypeAlias = TextDisplay +AllowedSectionAccessoryComponents = Button | Thumbnail + + +class Section( + WalkableComponent[SectionComponentPayload, AllowedSectionComponents | AllowedSectionAccessoryComponents], +): + """Represents a Section from Components V2. + + This is a component that groups other components together with an additional component to the right as the accessory. + + This inherits from :class:`Component`. + + .. versionadded:: 2.7 + + Attributes + ---------- + components: List[:class:`Component`] + The components contained in this section. Currently supports :class:`TextDisplay`. + accessory: Optional[:class:`Component`] + The accessory attached to this Section. Currently supports :class:`Button` and :class:`Thumbnail`. + + Parameters + ---------- + components: Sequence[:class:`AllowedSectionComponents`] + The components contained in this section. Currently supports :class:`TextDisplay`. + accessory: Optional[:class:`AllowedSectionAccessoryComponents`] + The accessory attached to this Section. Currently supports :class:`Button` and :class:`Thumbnail`. + id: Optional[:class:`int`] + The section's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + """ + + __slots__: tuple[str, ...] = ("components", "accessory") + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (2,) + type: Literal[ComponentType.section] = ComponentType.section # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, + components: Sequence[AllowedSectionComponents], + accessory: AllowedSectionAccessoryComponents | None = None, + id: int | None = None, + ): + self.components: list[AllowedSectionComponents] = list(components) # pyright: ignore[reportIncompatibleVariableOverride] + self.accessory: AllowedSectionAccessoryComponents | None = accessory + super().__init__(id=id) + + @classmethod + @override + def from_payload(cls, payload: SectionComponentPayload, state: ConnectionState | None = None) -> Self: + # self.id: int = data.get("id") + components: list[AllowedSectionComponents] = cast( + "list[AllowedSectionComponents]", + [_component_factory(d, state=state) for d in payload.get("components", [])], + ) + accessory: AllowedSectionAccessoryComponents | None = None + if _accessory := payload.get("accessory"): + accessory = cast("AllowedSectionAccessoryComponents", _component_factory(_accessory, state=state)) + return cls( + components=components, + accessory=accessory, + id=payload.get("id"), + ) + + @override + def to_dict(self) -> SectionComponentPayload: + payload: SectionComponentPayload = { # pyright: ignore[reportAssignmentType] + "type": int(self.type), + "id": self.id, + "components": [c.to_dict() for c in self.components], + } + if self.accessory: + payload["accessory"] = self.accessory.to_dict() + return payload + + +class MediaGalleryItem: + """Represents an item used in the :class:`MediaGallery` component. + + This is used as an underlying component for other media-based components such as :class:`Thumbnail`, :class:`FileComponent`, and :class:`MediaGalleryItem`. + + .. versionadded:: 2.7 + .. versionchanged:: 3.0 + + Attributes + ---------- + url: :class:`str` + The URL of this gallery item. This can either be an arbitrary URL or an ``attachment://`` URL to work with local files. + description: Optional[:class:`str`] + The gallery item's description, up to 1024 characters. + spoiler: Optional[:class:`bool`] + Whether the gallery item is a spoiler. + """ + + def __init__(self, url: str, *, description: str | None = None, spoiler: bool = False): + self._state: ConnectionState | None = None + self.media: UnfurledMediaItem = UnfurledMediaItem(url) + self.description: str | None = description + self.spoiler: bool = spoiler + + @property + def url(self) -> str: + """Returns the URL of this gallery's underlying media item.""" + return self.media.url + + def is_dispatchable(self) -> bool: + return False + + @classmethod + def from_payload(cls, data: MediaGalleryItemPayload, state: ConnectionState | None = None) -> MediaGalleryItem: + media = (umi := data.get("media")) and UnfurledMediaItem.from_dict(umi, state=state) + description = data.get("description") + spoiler = data.get("spoiler", False) + + r = cls( + url=media.url, + description=description, + spoiler=spoiler, + ) + r._state = state + r.media = media + return r + + def to_dict(self) -> MediaGalleryItemPayload: + payload: MediaGalleryItemPayload = {"media": self.media.to_dict()} + if self.description: + payload["description"] = self.description + payload["spoiler"] = self.spoiler + return payload + + +class MediaGallery(StateComponent[MediaGalleryComponentPayload]): + """Represents a Media Gallery from Components V2. + + This is a component that displays up to 10 different :class:`MediaGalleryItem` objects. + + This inherits from :class:`Component`. + + .. versionadded:: 2.7 + .. versionchanged:: 3.0 + + Attributes + ---------- + items: List[:class:`MediaGalleryItem`] + The media this gallery contains. + """ + + __slots__: tuple[str, ...] = ("items",) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (2,) + type: Literal[ComponentType.media_gallery] = ComponentType.media_gallery # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__(self, items: Sequence[MediaGalleryItem], id: int | None = None): + self.items: list[MediaGalleryItem] = list(items) + super().__init__(id=id) + + @classmethod + @override + def from_payload(cls, payload: MediaGalleryComponentPayload, state: ConnectionState | None = None) -> Self: + items = [MediaGalleryItem.from_payload(d, state=state) for d in payload.get("items", [])] + return cls(items, id=payload.get("id")) + + @override + def to_dict(self) -> MediaGalleryComponentPayload: + return { # pyright: ignore[reportReturnType] + "type": int(self.type), + "id": self.id, + "items": [i.to_dict() for i in self.items], + } + + +class FileComponent(StateComponent[FileComponentPayload]): + """Represents a File from Components V2. + + This component displays a downloadable file in a message. + + This inherits from :class:`Component`. + + .. versionadded:: 2.7 + .. versionchanged:: 3.0 + + Attributes + ---------- + file: :class:`UnfurledMediaItem` + The file's media item. + name: :class:`str` + The file's name. + size: :class:`int` + The file's size in bytes. + spoiler: Optional[:class:`bool`] + Whether the file has the spoiler overlay. + """ + + __slots__: tuple[str, ...] = ( + "file", + "spoiler", + "name", + "size", + ) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (2,) + type: Literal[ComponentType.file] = ComponentType.file # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, + url: str | UnfurledMediaItem, + *, + spoiler: bool | None = False, + id: int | None = None, + size: int | None = None, + name: str | None = None, + ) -> None: + self.file: UnfurledMediaItem = url if isinstance(url, UnfurledMediaItem) else UnfurledMediaItem(url) + self.spoiler: bool | None = bool(spoiler) if spoiler is not None else None + self.size: int | None = size + self.name: str | None = name + super().__init__(id=id) + + @classmethod + @override + def from_payload(cls, payload: FileComponentPayload, state: ConnectionState | None = None) -> Self: + file = UnfurledMediaItem.from_dict(payload.get("file", {}), state=state) + return cls( + file, spoiler=payload.get("spoiler"), id=payload.get("id"), size=payload["size"], name=payload["name"] + ) + + @override + def to_dict(self) -> FileComponentPayload: + payload = {"type": int(self.type), "id": self.id, "file": self.file.to_dict()} + if self.spoiler is not None: + payload["spoiler"] = self.spoiler + return payload # type: ignore # pyright: ignore[reportReturnType] + + @property + def url(self) -> str: + return self.file.url + + @url.setter + def url(self, url: str) -> None: + self.file = UnfurledMediaItem(url) + + +class Separator(Component[SeparatorComponentPayload]): + """Represents a Separator from Components V2. + + This is a component that visually separates components. + + This inherits from :class:`Component`. + + .. versionadded:: 2.7 + .. versionchanged:: 3.0 + + Attributes + ---------- + divider: :class:`bool` + Whether the separator will show a horizontal line in addition to vertical spacing. + spacing: Optional[:class:`SeparatorSpacingSize`] + The separator's spacing size. + """ + + __slots__: tuple[str, ...] = ( + "divider", + "spacing", + ) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (2,) + type: Literal[ComponentType.separator] = ComponentType.separator # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, divider: bool = True, spacing: SeparatorSpacingSize = SeparatorSpacingSize.small, id: int | None = None + ) -> None: + self.divider: bool = divider + self.spacing: SeparatorSpacingSize = spacing + super().__init__(id=id) + + @classmethod + @override + def from_payload(cls, payload: SeparatorComponentPayload) -> Self: + self = cls( + divider=payload.get("divider", False), spacing=try_enum(SeparatorSpacingSize, payload.get("spacing", 1)) + ) + self.id = payload.get("id") + return self + + @override + def to_dict(self) -> SeparatorComponentPayload: + return { # pyright: ignore[reportReturnType] + "type": int(self.type), + "id": self.id, + "divider": self.divider, + "spacing": int(self.spacing), + } # type: ignore + + +AllowedActionRowComponents = Button | InputText | SelectMenu[SelectMenuTypes] + + +class ActionRow(WalkableComponent[ActionRowPayload, AllowedActionRowComponents]): + """Represents a Discord Bot UI Kit Action Row. + + This is a component that holds up to 5 children components in a row. + + This inherits from :class:`Component`. + + .. versionadded:: 2.0 + .. versionchanged:: 3.0 + + Attributes + ---------- + type: :class:`ComponentType` + The type of component. + components: List[:class:`AllowedActionRowComponents`] + The components that this ActionRow holds, if any. + id: Optional[:class:`int`] + The action row's ID. If not provided, it is set sequentially by Discord. + The ID `0` is treated as if no ID was provided. + + Parameters + ---------- + components: Sequence[:class:`AllowedActionRowComponents`] + + """ + + __slots__: tuple[str, ...] = ("components",) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (1, 2) + type: Literal[ComponentType.action_row] = ComponentType.action_row # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__(self, components: Sequence[AllowedActionRowComponents], id: int | None = None) -> None: + self.components: list[AllowedActionRowComponents] = list(components) + super().__init__(id=id) + + @classmethod + @override + def from_payload(cls, payload: ActionRowPayload) -> Self: + components: list[AllowedActionRowComponents] = cast( + "list[AllowedActionRowComponents]", [_component_factory(d) for d in payload.get("components", [])] + ) + return cls(components, id=payload.get("id")) + + @property + def width(self): + """Return the sum of the components' widths.""" + return sum(getattr(c, "width", 0) for c in self.components) + + @override + def to_dict(self) -> ActionRowPayload: + return { # pyright: ignore[reportReturnType] + "type": int(self.type), + "id": self.id, + "components": [component.to_dict() for component in self.components], + } # type: ignore + + +AllowedContainerComponents = ActionRow | TextDisplay | Section | MediaGallery | Separator | FileComponent + + +class Container(WalkableComponent[ContainerComponentPayload, AllowedContainerComponents]): + """Represents a Container from Components V2. + + This is a component that contains different :class:`Component` objects. + It may only contain: + + - :class:`ActionRow` + - :class:`TextDisplay` + - :class:`Section` + - :class:`MediaGallery` + - :class:`Separator` + - :class:`FileComponent` + + This inherits from :class:`Component`. + + .. versionadded:: 2.7 + .. versionchanged:: 3.0 + + Attributes + ---------- + components: List[:class:`Component`] + The components contained in this container. + accent_color: Optional[:class:`Colour`] + The accent color of the container. + spoiler: Optional[:class:`bool`] + Whether the entire container has the spoiler overlay. + """ + + __slots__: tuple[str, ...] = ( + "accent_color", + "spoiler", + "components", + ) + + __repr_info__: ClassVar[tuple[str, ...]] = __slots__ + versions: tuple[int, ...] = (2,) + type: Literal[ComponentType.container] = ComponentType.container # pyright: ignore[reportIncompatibleVariableOverride] + + def __init__( + self, + accent_color: Colour | None = None, + spoiler: bool | None = False, + id: int | None = None, + *, + components: Sequence[AllowedContainerComponents] = (), + ) -> None: + self.accent_color: Colour | None = accent_color + self.spoiler: bool | None = spoiler + self.components: list[AllowedContainerComponents] = list(components) + super().__init__(id=id) + + @override + def to_dict(self) -> ContainerComponentPayload: + payload: ContainerComponentPayload = { + "type": int(self.type), # pyright: ignore[reportAssignmentType] + "id": self.id, + "components": [c.to_dict() for c in self.components], + } + if self.accent_color: + payload["accent_color"] = self.accent_color.value + if self.spoiler is not None: + payload["spoiler"] = self.spoiler + return payload + + @classmethod + @override + def from_payload(cls, payload: ContainerComponentPayload, state: ConnectionState | None = None) -> Self: + components: list[AllowedContainerComponents] = cast( + "list[AllowedContainerComponents]", + [_component_factory(d, state=state) for d in payload.get("components", [])], + ) + accent_color = Colour(c) if (c := payload.get("accent_color") is not None) else None + return cls( + accent_color=accent_color, + spoiler=payload.get("spoiler"), + id=payload.get("id"), + components=components, + ) + + +class UnknownComponent(Component[ComponentPayload]): + """Represents an unknown component. + + This is used when the component type is not recognized by the library, + for example if a new component is introduced by Discord. + + .. versionadded:: 3.0 + + Attributes + ---------- + type: :class:`ComponentType` + The type of the unknown component. + + """ + + __slots__: tuple[str, ...] = ("type",) + + def __init__(self, type: ComponentType, id: int | None = None) -> None: + self.type: ComponentType = type + super().__init__(id=id) + + @override + def to_dict(self) -> ComponentPayload: + return {"type": int(self.type)} # pyright: ignore[reportReturnType] + + @classmethod + @override + def from_payload(cls, payload: ComponentPayload) -> Self: + type_ = try_enum(ComponentType, payload.pop("type", 0)) + self = cls(type_, id=payload.pop("id", None)) + for key, value in payload.items(): + setattr(self, key, value) + return self + + +COMPONENT_MAPPINGS = { + 1: ActionRow, + 2: Button, + 3: StringSelectMenu, + 4: InputText, + 5: UserSelectMenu, + 6: RoleSelectMenu, + 7: MentionableSelectMenu, + 8: ChannelSelectMenu, + 9: Section, + 10: TextDisplay, + 11: Thumbnail, + 12: MediaGallery, + 13: FileComponent, + 14: Separator, + 17: Container, +} + +STATE_COMPONENTS = (Section, Container, Thumbnail, MediaGallery, FileComponent) + +def _component_factory(data: P, state: ConnectionState | None = None) -> Component[P]: + component_type = data["type"] + if cls := COMPONENT_MAPPINGS.get(component_type): + if issubclass(cls, StateComponent): + return cls(data, state=state) # pyright: ignore[reportCallIssue, reportReturnType] + else: + return cls(data) # pyright: ignore[reportArgumentType, reportCallIssue, reportReturnType] + else: + return UnknownComponent.from_payload(data) # pyright: ignore[reportReturnType] + + +AnyComponent = ( + ActionRow + | Button + | StringSelectMenu + | InputText + | UserSelectMenu + | RoleSelectMenu + | MentionableSelectMenu + | ChannelSelectMenu + | Section + | TextDisplay + | Thumbnail + | MediaGallery + | FileComponent + | Separator + | Container + | UnknownComponent +) diff --git a/discord/embeds.py b/discord/embeds.py index cab91a5176..b3194447d5 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -28,6 +28,7 @@ import datetime from typing import TYPE_CHECKING, Any, Mapping, TypeVar +from .utils.private import parse_time from . import utils from .colour import Colour @@ -437,7 +438,7 @@ def from_dict(cls: type[E], data: Mapping[str, Any]) -> E: pass try: - self._timestamp = utils.parse_time(data["timestamp"]) + self._timestamp = parse_time(data["timestamp"]) except KeyError: pass diff --git a/discord/emoji.py b/discord/emoji.py index 5dd1eafd3d..40df9ca252 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -30,7 +30,8 @@ from .asset import Asset, AssetMixin from .partial_emoji import PartialEmoji, _EmojiTag from .user import User -from .utils import MISSING, SnowflakeList, Undefined, snowflake_time +from .utils import MISSING, Undefined, snowflake_time +from .utils.private import SnowflakeList __all__ = ( "Emoji", diff --git a/discord/enums.py b/discord/enums.py index 4de87b9ff3..8c3533af53 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -78,6 +78,8 @@ "InteractionContextType", "PollLayoutType", "MessageReferenceType", + "SubscriptionStatus", + "SeparatorSpacingSize", ) @@ -702,6 +704,14 @@ class ComponentType(Enum): role_select = 6 mentionable_select = 7 channel_select = 8 + section = 9 + text_display = 10 + thumbnail = 11 + media_gallery = 12 + file = 13 + separator = 14 + content_inventory_entry = 16 + container = 17 def __int__(self): return self.value @@ -1058,6 +1068,16 @@ class SubscriptionStatus(Enum): inactive = 2 +class SeparatorSpacingSize(Enum): + """A separator component's spacing size.""" + + small = 1 + large = 2 + + def __int__(self): + return self.value + + T = TypeVar("T") diff --git a/discord/ext/bridge/core.py b/discord/ext/bridge/core.py index 1c3c5d3d1b..6f319f07a7 100644 --- a/discord/ext/bridge/core.py +++ b/discord/ext/bridge/core.py @@ -40,7 +40,8 @@ SlashCommandOptionType, ) -from ...utils import MISSING, find, get, warn_deprecated +from discord.utils import MISSING, find +from discord.utils.private import warn_deprecated from ..commands import ( BadArgument, ) @@ -608,7 +609,7 @@ async def convert(self, ctx, argument: str) -> Any: if self.choices: choices_names: list[str | int | float] = [choice.name for choice in self.choices] - if converted in choices_names and (choice := get(self.choices, name=converted)): + if converted in choices_names and (choice := find(lambda c: c.name == converted, self.choices)): converted = choice.value else: choices = [choice.value for choice in self.choices] diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 2e038fbc92..09d7f7a22d 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -33,6 +33,7 @@ import discord from discord.utils import Undefined +from discord.utils.private import copy_doc, maybe_awaitable, async_all from . import errors from .context import Context @@ -132,7 +133,7 @@ def __init__( self.help_command = DefaultHelpCommand() if help_command is MISSING else help_command self.strip_after_prefix = options.get("strip_after_prefix", False) - @discord.utils.copy_doc(discord.Client.close) + @copy_doc(discord.Client.close) async def close(self) -> None: for extension in tuple(self.__extensions): try: @@ -179,7 +180,7 @@ async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: return True # type-checker doesn't distinguish between functions and methods - return await discord.utils.async_all(f(ctx) for f in data) # type: ignore + return await async_all(f(ctx) for f in data) # type: ignore # help command stuff @@ -223,7 +224,7 @@ async def get_prefix(self, message: Message) -> list[str] | str: """ prefix = ret = self.command_prefix if callable(prefix): - ret = await discord.utils.maybe_coroutine(prefix, self, message) + ret = await maybe_awaitable(prefix, self, message) if not isinstance(ret, str): try: diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 79e2cfba6e..55e376c9c8 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -30,8 +30,9 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union import discord.abc -import discord.utils from discord.message import Message +from discord.utils.private import copy_doc +from discord.utils import Undefined, MISSING if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -50,8 +51,6 @@ __all__ = ("Context",) -MISSING: Any = discord.utils.MISSING - T = TypeVar("T") BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") @@ -125,12 +124,12 @@ def __init__( message: Message, bot: BotT, view: StringView, - args: list[Any] | discord.utils.Undefined = MISSING, - kwargs: dict[str, Any] | discord.utils.Undefined = MISSING, + args: list[Any] | Undefined = MISSING, + kwargs: dict[str, Any] | Undefined = MISSING, prefix: str | None = None, command: Command | None = None, invoked_with: str | None = None, - invoked_parents: list[str] | discord.utils.Undefined = MISSING, + invoked_parents: list[str] | Undefined = MISSING, invoked_subcommand: Command | None = None, subcommand_passed: str | None = None, command_failed: bool = False, @@ -281,28 +280,28 @@ def cog(self) -> Cog | None: return None return self.command.cog - @discord.utils.cached_property + @property def guild(self) -> Guild | None: """Returns the guild associated with this context's command. None if not available. """ return self.message.guild - @discord.utils.cached_property + @property def channel(self) -> MessageableChannel: """Returns the channel associated with this context's command. Shorthand for :attr:`.Message.channel`. """ return self.message.channel - @discord.utils.cached_property + @property def author(self) -> User | Member: """Union[:class:`~discord.User`, :class:`.Member`]: Returns the author associated with this context's command. Shorthand for :attr:`.Message.author` """ return self.message.author - @discord.utils.cached_property + @property def me(self) -> Member | ClientUser: """Union[:class:`.Member`, :class:`.ClientUser`]: Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message @@ -398,10 +397,10 @@ async def send_help(self, *args: Any) -> Any: except CommandError as e: await cmd.on_help_command_error(self, e) - @discord.utils.copy_doc(Message.reply) + @copy_doc(Message.reply) async def reply(self, content: str | None = None, **kwargs: Any) -> Message: return await self.message.reply(content, **kwargs) - @discord.utils.copy_doc(Message.forward_to) + @copy_doc(Message.forward_to) async def forward_to(self, channel: discord.abc.Messageable, **kwargs: Any) -> Message: return await self.message.forward_to(channel, **kwargs) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index e9352e4fce..d34815001d 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -89,7 +89,6 @@ def _get_from_guilds(bot, getter, argument): return result -_utils_get = discord.utils.get T = TypeVar("T") T_co = TypeVar("T_co", covariant=True) CT = TypeVar("CT", bound=discord.abc.GuildChannel) @@ -194,7 +193,7 @@ async def query_member_named(self, guild, argument): if len(argument) > 5 and argument[-5] == "#": username, _, discriminator = argument.rpartition("#") members = await guild.query_members(username, limit=100, cache=cache) - return discord.utils.get(members, name=username, discriminator=discriminator) + return discord.utils.find(lambda m: m.name == username and m.discriminator == discriminator, members) members = await guild.query_members(argument, limit=100, cache=cache) return discord.utils.find( lambda m: argument in (m.nick, m.name, m.global_name), @@ -239,7 +238,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.Member: if guild: result = guild.get_member(user_id) if ctx.message is not None and result is None: - result = _utils_get(ctx.message.mentions, id=user_id) + result = discord.utils.find(lambda e: e.id == user_id, ctx.message.mentions) else: result = _get_from_guilds(bot, "get_member", user_id) @@ -287,7 +286,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.User: user_id = int(match.group(1)) result = ctx.bot.get_user(user_id) if ctx.message is not None and result is None: - result = _utils_get(ctx.message.mentions, id=user_id) + result = discord.utils.find(lambda e: e.id == user_id, ctx.message.mentions) if result is None: try: result = await ctx.bot.fetch_user(user_id) @@ -441,7 +440,7 @@ def _resolve_channel(ctx: Context, argument: str, attribute: str, type: type[CT] # not a mention if guild: iterable: Iterable[CT] = getattr(guild, attribute) - result: CT | None = discord.utils.get(iterable, name=argument) + result: CT | None = discord.utils.find(lambda e: e.name == argument, iterable) else: def check(c): @@ -470,7 +469,7 @@ def _resolve_thread(ctx: Context, argument: str, attribute: str, type: type[TT]) # not a mention if guild: iterable: Iterable[TT] = getattr(guild, attribute) - result: TT | None = discord.utils.get(iterable, name=argument) + result: TT | None = discord.utils.find(lambda e: e.name == argument, iterable) else: thread_id = int(match.group(1)) if guild: @@ -709,7 +708,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.Role: if match: result = guild.get_role(int(match.group(1))) else: - result = discord.utils.get(guild._roles.values(), name=argument) + result = discord.utils.find(lambda e: e.name == argument, guild._roles.values()) if result is None: raise RoleNotFound(argument) @@ -760,7 +759,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.Guild: result = ctx.bot.get_guild(guild_id) if result is None: - result = discord.utils.get(ctx.bot.guilds, name=argument) + result = discord.utils.find(lambda e: e.name == argument, ctx.bot.guilds) if result is None: raise GuildNotFound(argument) @@ -792,10 +791,10 @@ async def convert(self, ctx: Context, argument: str) -> discord.GuildEmoji: if match is None: # Try to get the emoji by name. Try local guild first. if guild: - result = discord.utils.get(guild.emojis, name=argument) + result = discord.utils.find(lambda e: e.name == argument, guild.emojis) if result is None: - result = discord.utils.get(bot.emojis, name=argument) + result = discord.utils.find(lambda e: e.name == argument, bot.emojis) else: emoji_id = int(match.group(1)) @@ -858,10 +857,10 @@ async def convert(self, ctx: Context, argument: str) -> discord.GuildSticker: if match is None: # Try to get the sticker by name. Try local guild first. if guild: - result = discord.utils.get(guild.stickers, name=argument) + result = discord.utils.find(lambda s: s.name == argument, guild.stickers) if result is None: - result = discord.utils.get(bot.stickers, name=argument) + result = discord.utils.find(lambda s: s.name == argument, bot.stickers) else: sticker_id = int(match.group(1)) @@ -913,17 +912,23 @@ async def convert(self, ctx: Context, argument: str) -> str: if ctx.guild: def resolve_member(id: int) -> str: - m = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.guild.get_member(id) + m = ( + None if msg is None else discord.utils.find(lambda e: e.id == id, msg.mentions) + ) or ctx.guild.get_member(id) return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: - r = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.guild.get_role(id) + r = ( + None if msg is None else discord.utils.find(lambda e: e.id == id, msg.mentions) + ) or ctx.guild.get_role(id) return f"@{r.name}" if r else "@deleted-role" else: def resolve_member(id: int) -> str: - m = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.bot.get_user(id) + m = ( + None if msg is None else discord.utils.find(lambda e: e.id == id, msg.mentions) + ) or ctx.bot.get_user(id) return f"@{m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index ac2c7a1073..fd90615f5b 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -43,6 +43,7 @@ ) import discord +from discord.utils.private import evaluate_annotation, async_all, maybe_awaitable from discord import utils from discord.utils import Undefined @@ -138,7 +139,6 @@ def get_signature_parameters(function: Callable[..., Any], globalns: dict[str, A signature = inspect.signature(function) params = {} cache: dict[str, Any] = {} - eval_annotation = discord.utils.evaluate_annotation for name, parameter in signature.parameters.items(): annotation = parameter.annotation if annotation is parameter.empty: @@ -148,7 +148,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: dict[str, A params[name] = parameter.replace(annotation=type(None)) continue - annotation = eval_annotation(annotation, globalns, globalns, cache) + annotation = evaluate_annotation(annotation, globalns, globalns, cache) if annotation is Greedy: raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") @@ -1146,7 +1146,7 @@ async def can_run(self, ctx: Context) -> bool: if cog is not None: local_check = Cog._get_overridden_method(cog.cog_check) if local_check is not None: - ret = await discord.utils.maybe_coroutine(local_check, ctx) + ret = await maybe_awaitable(local_check, ctx) if not ret: return False @@ -1155,7 +1155,7 @@ async def can_run(self, ctx: Context) -> bool: # since we have no checks, then we just return True. return True - return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore + return await async_all(predicate(ctx) for predicate in predicates) # type: ignore finally: ctx.command = original @@ -1865,9 +1865,9 @@ def predicate(ctx: Context) -> bool: # ctx.guild is None doesn't narrow ctx.author to Member if isinstance(item, int): - role = discord.utils.get(ctx.author.roles, id=item) # type: ignore + role = discord.utils.find(lambda r: r.id == item, ctx.author.roles) # type: ignore else: - role = discord.utils.get(ctx.author.roles, name=item) # type: ignore + role = discord.utils.find(lambda r: r.name == item, ctx.author.roles) # type: ignore if role is None: raise MissingRole(item) return True @@ -1912,9 +1912,14 @@ def predicate(ctx): raise NoPrivateMessage() # ctx.guild is None doesn't narrow ctx.author to Member - getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore + getter = functools.partial(discord.utils.find, seq=ctx.author.roles) # type: ignore if any( - (getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None) for item in items + ( + getter(lambda e: e.id == item) is not None + if isinstance(item, int) + else getter(lambda e: e.name == item) is not None + ) + for item in items ): return True raise MissingAnyRole(list(items)) @@ -1942,9 +1947,9 @@ def predicate(ctx): me = ctx.me if isinstance(item, int): - role = discord.utils.get(me.roles, id=item) + role = discord.utils.find(lambda r: r.id == item, me.roles) else: - role = discord.utils.get(me.roles, name=item) + role = discord.utils.find(lambda r: r.name == item, me.roles) if role is None: raise BotMissingRole(item) return True @@ -1971,9 +1976,14 @@ def predicate(ctx): raise NoPrivateMessage() me = ctx.me - getter = functools.partial(discord.utils.get, me.roles) + getter = functools.partial(discord.utils.find, seq=me.roles) if any( - (getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None) for item in items + ( + getter(lambda e: e.id == item) is not None + if isinstance(item, int) + else getter(lambda e: e.name == item) is not None + ) + for item in items ): return True raise BotMissingAnyRole(list(items)) diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index d1d8b23fce..8c023a9e67 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -32,7 +32,8 @@ from typing import TYPE_CHECKING, Any, Iterator, Literal, Pattern, TypeVar, Union from discord import utils -from discord.utils import MISSING, Undefined, maybe_coroutine, resolve_annotation +from discord.utils import MISSING, Undefined +from discord.utils.private import resolve_annotation, maybe_awaitable from .converter import run_converters from .errors import ( @@ -56,7 +57,7 @@ def _missing_field_factory() -> field: - return field(default_factory=lambda: MISSING) + return field(default_factory=lambda: utils.MISSING) @dataclass @@ -86,13 +87,13 @@ class Flag: Whether multiple given values overrides the previous value. """ - name: str | Undefined = _missing_field_factory() # noqa: RUF009 + name: str | utils.Undefined = _missing_field_factory() # noqa: RUF009 aliases: list[str] = field(default_factory=list) - attribute: str | Undefined = _missing_field_factory() # noqa: RUF009 - annotation: Any | Undefined = _missing_field_factory() # noqa: RUF009 - default: Any | Undefined = _missing_field_factory() # noqa: RUF009 - max_args: int | Undefined = _missing_field_factory() # noqa: RUF009 - override: bool | Undefined = _missing_field_factory() # noqa: RUF009 + attribute: str | utils.Undefined = _missing_field_factory() # noqa: RUF009 + annotation: Any | utils.Undefined = _missing_field_factory() # noqa: RUF009 + default: Any | utils.Undefined = _missing_field_factory() # noqa: RUF009 + max_args: int | utils.Undefined = _missing_field_factory() # noqa: RUF009 + override: bool | utils.Undefined = _missing_field_factory() # noqa: RUF009 cast_to_dict: bool = False @property @@ -489,7 +490,7 @@ async def _construct_default(cls: type[F], ctx: Context) -> F: flags = cls.__commands_flags__ for flag in flags.values(): if callable(flag.default): - default = await maybe_coroutine(flag.default, ctx) + default = await maybe_awaitable(flag.default, ctx) setattr(self, flag.attribute, default) else: setattr(self, flag.attribute, flag.default) @@ -588,7 +589,7 @@ async def convert(cls: type[F], ctx: Context, argument: str) -> F: raise MissingRequiredFlag(flag) else: if callable(flag.default): - default = await maybe_coroutine(flag.default, ctx) + default = await maybe_awaitable(flag.default, ctx) setattr(self, flag.attribute, default) else: setattr(self, flag.attribute, flag.default) diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index be6221fe80..94e5c96755 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -31,7 +31,8 @@ import re from typing import TYPE_CHECKING, Any -import discord.utils +from discord import utils +from discord.utils.private import string_width, maybe_awaitable from .core import Command, Group from .errors import CommandError @@ -335,7 +336,7 @@ def __init__(self, **options): self.command_attrs = attrs = options.pop("command_attrs", {}) attrs.setdefault("name", "help") attrs.setdefault("help", "Shows this message") - self.context: Context = discord.utils.MISSING + self.context: Context = utils.MISSING self._command_impl = _HelpCommandImpl(self, **self.command_attrs) def copy(self): @@ -570,12 +571,14 @@ async def filter_commands(self, commands, *, sort=False, key=None, exclude: tupl key = lambda c: c.name # Ignore Application Commands because they don't have hidden/docs + from discord import ApplicationCommand # noqa: PLC0415 + new_commands = [ command for command in commands if not isinstance( command, - (discord.commands.ApplicationCommand, *(exclude if exclude else ())), + (ApplicationCommand, *(exclude if exclude else ())), ) ] iterator = new_commands if self.show_hidden else filter(lambda c: not c.hidden, new_commands) @@ -620,7 +623,7 @@ def get_max_size(self, commands): The maximum width of the commands. """ - as_lengths = (discord.utils._string_width(c.name) for c in commands) + as_lengths = (string_width(c.name) for c in commands) return max(as_lengths, default=0) def get_destination(self): @@ -858,8 +861,6 @@ async def command_callback(self, ctx, *, command=None): if cog is not None: return await self.send_cog_help(cog) - maybe_coro = discord.utils.maybe_coroutine - # If it's not a cog then it's a command. # Since we want to have detailed errors when someone # passes an invalid subcommand, we need to walk through @@ -867,18 +868,18 @@ async def command_callback(self, ctx, *, command=None): keys = command.split(" ") cmd = bot.all_commands.get(keys[0]) if cmd is None: - string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0])) + string = await maybe_awaitable(self.command_not_found, self.remove_mentions(keys[0])) return await self.send_error_message(string) for key in keys[1:]: try: found = cmd.all_commands.get(key) except AttributeError: - string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) + string = await maybe_awaitable(self.subcommand_not_found, cmd, self.remove_mentions(key)) return await self.send_error_message(string) else: if found is None: - string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) + string = await maybe_awaitable(self.subcommand_not_found, cmd, self.remove_mentions(key)) return await self.send_error_message(string) cmd = found @@ -984,10 +985,9 @@ def add_indented_commands(self, commands, *, heading, max_size=None): self.paginator.add_line(heading) max_size = max_size or self.get_max_size(commands) - get_width = discord.utils._string_width for command in commands: name = command.name - width = max_size - (get_width(name) - len(name)) + width = max_size - (string_width(name) - len(name)) entry = f"{self.indent * ' '}{name:<{width}} {command.short_doc}" self.paginator.add_line(self.shorten_text(entry)) diff --git a/discord/ext/pages/pagination.py b/discord/ext/pages/pagination.py index 082af40375..f3c1a6866f 100644 --- a/discord/ext/pages/pagination.py +++ b/discord/ext/pages/pagination.py @@ -148,8 +148,8 @@ def __init__( files: list[discord.File] | None = None, **kwargs, ): - if content is None and embeds is None: - raise discord.InvalidArgument("A page cannot have both content and embeds equal to None.") + if content is None and embeds is None and custom_view is None: + raise discord.InvalidArgument("A page must at least have content, embeds, or custom_view set.") self._content = content self._embeds = embeds or [] self._custom_view = custom_view @@ -532,8 +532,9 @@ async def update( async def on_timeout(self) -> None: """Disables all buttons when the view times out.""" if self.disable_on_timeout: - for item in self.children: - item.disabled = True + for item in self.walk_children(): + if hasattr(item, "disabled"): + item.disabled = True page = self.pages[self.current_page] page = self.get_page_content(page) files = page.update_files() @@ -558,8 +559,10 @@ async def disable( The page content to show after disabling the paginator. """ page = self.get_page_content(page) - for item in self.children: - if include_custom or not self.custom_view or item not in self.custom_view.children: + for item in self.walk_children(): + if (include_custom or not self.custom_view or item not in self.custom_view.children) and hasattr( + item, "disabled" + ): item.disabled = True if page: await self.message.edit( @@ -841,6 +844,8 @@ def get_page_content( return Page(content=None, embeds=[page], files=[]) elif isinstance(page, discord.File): return Page(content=None, embeds=[], files=[page]) + elif isinstance(page, discord.ui.View): + return Page(content=None, embeds=[], files=[], custom_view=page) elif isinstance(page, List): if all(isinstance(x, discord.Embed) for x in page): return Page(content=None, embeds=page, files=[]) @@ -850,7 +855,8 @@ def get_page_content( raise TypeError("All list items must be embeds or files.") else: raise TypeError( - "Page content must be a Page object, string, an embed, a list of embeds, a file, or a list of files." + "Page content must be a Page object, string, an embed, a view, a list of" + " embeds, a file, or a list of files." ) async def page_action(self, interaction: discord.Interaction | None = None) -> None: diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 8fecd9a85f..bbff6f3750 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -49,6 +49,13 @@ ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]]) +def compute_timedelta(dt: datetime.datetime): + if dt.tzinfo is None: + dt = dt.astimezone() + now = datetime.datetime.now(datetime.timezone.utc) + return max((dt - now).total_seconds(), 0) + + class SleepHandle: __slots__ = ("future", "loop", "handle") diff --git a/discord/flags.py b/discord/flags.py index 6055ad3ec3..8952c52d76 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -321,22 +321,22 @@ class MessageFlags(BaseFlags): @flag_value def crossposted(self): """:class:`bool`: Returns ``True`` if the message is the original crossposted message.""" - return 1 + return 1 << 0 @flag_value def is_crossposted(self): """:class:`bool`: Returns ``True`` if the message was crossposted from another channel.""" - return 2 + return 1 << 1 @flag_value def suppress_embeds(self): """:class:`bool`: Returns ``True`` if the message's embeds have been suppressed.""" - return 4 + return 1 << 2 @flag_value def source_message_deleted(self): """:class:`bool`: Returns ``True`` if the source message for this crosspost has been deleted.""" - return 8 + return 1 << 3 @flag_value def urgent(self): @@ -344,7 +344,7 @@ def urgent(self): An urgent message is one sent by Discord Trust and Safety. """ - return 16 + return 1 << 4 @flag_value def has_thread(self): @@ -352,7 +352,7 @@ def has_thread(self): .. versionadded:: 2.0 """ - return 32 + return 1 << 5 @flag_value def ephemeral(self): @@ -360,7 +360,7 @@ def ephemeral(self): .. versionadded:: 2.0 """ - return 64 + return 1 << 6 @flag_value def loading(self): @@ -370,7 +370,7 @@ def loading(self): .. versionadded:: 2.0 """ - return 128 + return 1 << 7 @flag_value def failed_to_mention_some_roles_in_thread(self): @@ -378,7 +378,7 @@ def failed_to_mention_some_roles_in_thread(self): .. versionadded:: 2.0 """ - return 256 + return 1 << 8 @flag_value def suppress_notifications(self): @@ -389,7 +389,7 @@ def suppress_notifications(self): .. versionadded:: 2.4 """ - return 4096 + return 1 << 12 @flag_value def is_voice_message(self): @@ -397,7 +397,15 @@ def is_voice_message(self): .. versionadded:: 2.5 """ - return 8192 + return 1 << 13 + + @flag_value + def is_components_v2(self): + """:class:`bool`: Returns ``True`` if this message has v2 components. This flag disables sending `content`, `embed`, and `embeds`. + + .. versionadded:: 2.7 + """ + return 1 << 15 @flag_value def has_snapshot(self): diff --git a/discord/gateway.py b/discord/gateway.py index 05f1277f7d..67a974c162 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -35,6 +35,7 @@ import traceback import zlib from collections import deque, namedtuple +from typing import TYPE_CHECKING import aiohttp @@ -42,6 +43,7 @@ from .activity import BaseActivity from .enums import SpeakingState from .errors import ConnectionClosed, InvalidArgument +from .utils.private import from_json, to_json _log = logging.getLogger(__name__) @@ -200,6 +202,9 @@ def ack(self): class VoiceKeepAliveHandler(KeepAliveHandler): + if TYPE_CHECKING: + ws: DiscordVoiceWebSocket + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.recent_ack_latencies = deque(maxlen=20) @@ -208,7 +213,10 @@ def __init__(self, *args, **kwargs): self.behind_msg = "High socket latency, shard ID %s heartbeat is %.1fs behind" def get_payload(self): - return {"op": self.ws.HEARTBEAT, "d": int(time.time() * 1000)} + return { + "op": self.ws.HEARTBEAT, + "d": {"t": int(time.time() * 1000), "seq_ack": self.ws.seq_ack}, + } def ack(self): ack_time = time.perf_counter() @@ -450,7 +458,7 @@ async def received_message(self, msg, /): self._buffer = bytearray() self.log_receive(msg) - msg = utils._from_json(msg) + msg = from_json(msg) _log.debug("For Shard ID %s: WebSocket Event: %s", self.shard_id, msg) event = msg.get("t") @@ -637,7 +645,7 @@ async def send(self, data, /): async def send_as_json(self, data): try: - await self.send(utils._to_json(data)) + await self.send(to_json(data)) except RuntimeError as exc: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc @@ -645,7 +653,7 @@ async def send_as_json(self, data): async def send_heartbeat(self, data): # This bypasses the rate limit handling code since it has a higher priority try: - await self.socket.send_str(utils._to_json(data)) + await self.socket.send_str(to_json(data)) except RuntimeError as exc: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc @@ -671,7 +679,7 @@ async def change_presence(self, *, activity=None, status=None, since=0.0): }, } - sent = utils._to_json(payload) + sent = to_json(payload) _log.debug('Sending "%s" to change status', sent) await self.send(sent) @@ -766,6 +774,7 @@ def __init__(self, socket, loop, *, hook=None): self._close_code = None self.secret_key = None self.ssrc_map = {} + self.seq_ack: int = -1 if hook: self._hook = hook @@ -774,7 +783,7 @@ async def _hook(self, *args): async def send_as_json(self, data): _log.debug("Sending voice websocket frame: %s.", data) - await self.ws.send_str(utils._to_json(data)) + await self.ws.send_str(to_json(data)) send_heartbeat = send_as_json @@ -786,6 +795,9 @@ async def resume(self): "token": state.token, "server_id": str(state.server_id), "session_id": state.session_id, + # this seq_ack will allow for us to do buffered resume, which is, receive the + # lost voice packets while trying to resume the reconnection + "seq_ack": self.seq_ack, }, } await self.send_as_json(payload) @@ -806,7 +818,7 @@ async def identify(self): @classmethod async def from_client(cls, client, *, resume=False, hook=None): """Creates a voice websocket for the :class:`VoiceClient`.""" - gateway = f"wss://{client.endpoint}/?v=4" + gateway = f"wss://{client.endpoint}/?v=8" http = client._state.http socket = await http.ws_connect(gateway, compress=15) ws = cls(socket, loop=client.loop, hook=hook) @@ -842,7 +854,13 @@ async def client_connect(self): await self.send_as_json(payload) async def speak(self, state=SpeakingState.voice): - payload = {"op": self.SPEAKING, "d": {"speaking": int(state), "delay": 0}} + payload = { + "op": self.SPEAKING, + "d": { + "speaking": int(state), + "delay": 0, + }, + } await self.send_as_json(payload) @@ -850,6 +868,7 @@ async def received_message(self, msg): _log.debug("Voice websocket frame received: %s", msg) op = msg["op"] data = msg.get("d") + self.seq_ack = data.get("seq", self.seq_ack) if op == self.READY: await self.initial_connection(data) @@ -931,7 +950,7 @@ async def poll_event(self): # This exception is handled up the chain msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) if msg.type is aiohttp.WSMsgType.TEXT: - await self.received_message(utils._from_json(msg.data)) + await self.received_message(from_json(msg.data)) elif msg.type is aiohttp.WSMsgType.ERROR: _log.debug("Received %s", msg) raise ConnectionClosed(self.ws, shard_id=None) from msg.data diff --git a/discord/guild.py b/discord/guild.py index 3b8ffcad37..96db69454d 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -40,6 +40,7 @@ overload, ) +from .utils.private import get_as_snowflake, bytes_to_base64_data from . import abc, utils from .asset import Asset from .automod import AutoModAction, AutoModRule, AutoModTriggerMetadata @@ -472,7 +473,7 @@ def _from_data(self, guild: GuildPayload) -> None: ) self.features: list[GuildFeature] = guild.get("features", []) self._splash: str | None = guild.get("splash") - self._system_channel_id: int | None = utils._get_as_snowflake(guild, "system_channel_id") + self._system_channel_id: int | None = get_as_snowflake(guild, "system_channel_id") self.description: str | None = guild.get("description") self.max_presences: int | None = guild.get("max_presences") self.max_members: int | None = guild.get("max_members") @@ -483,8 +484,8 @@ def _from_data(self, guild: GuildPayload) -> None: self._system_channel_flags: int = guild.get("system_channel_flags", 0) self.preferred_locale: str | None = guild.get("preferred_locale") self._discovery_splash: str | None = guild.get("discovery_splash") - self._rules_channel_id: int | None = utils._get_as_snowflake(guild, "rules_channel_id") - self._public_updates_channel_id: int | None = utils._get_as_snowflake(guild, "public_updates_channel_id") + self._rules_channel_id: int | None = get_as_snowflake(guild, "rules_channel_id") + self._public_updates_channel_id: int | None = get_as_snowflake(guild, "public_updates_channel_id") self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get("nsfw_level", 0)) self.approximate_presence_count = guild.get("approximate_presence_count") self.approximate_member_count = guild.get("approximate_member_count") @@ -510,8 +511,8 @@ def _from_data(self, guild: GuildPayload) -> None: self._sync(guild) self._large: bool | None = None if self._member_count is None else self._member_count >= 250 - self.owner_id: int | None = utils._get_as_snowflake(guild, "owner_id") - self.afk_channel: VoiceChannel | None = self.get_channel(utils._get_as_snowflake(guild, "afk_channel_id")) # type: ignore + self.owner_id: int | None = get_as_snowflake(guild, "owner_id") + self.afk_channel: VoiceChannel | None = self.get_channel(get_as_snowflake(guild, "afk_channel_id")) # type: ignore for obj in guild.get("voice_states", []): self._update_voice_state(obj, int(obj["channel_id"])) @@ -1019,7 +1020,7 @@ def get_member_named(self, name: str, /) -> Member | None: # do the actual lookup and return if found # if it isn't found then we'll do a full name lookup below. - result = utils.get(members, name=name[:-5], discriminator=potential_discriminator) + result = utils.find(lambda m: m.name == name[:-5] and discriminator == potential_discriminator, members) if result is not None: return result @@ -1713,24 +1714,24 @@ async def edit( fields["afk_timeout"] = afk_timeout if icon is not MISSING: - fields["icon"] = icon if icon is None else utils._bytes_to_base64_data(icon) + fields["icon"] = icon if icon is None else bytes_to_base64_data(icon) if banner is not MISSING: if banner is None: fields["banner"] = banner else: - fields["banner"] = utils._bytes_to_base64_data(banner) + fields["banner"] = bytes_to_base64_data(banner) if splash is not MISSING: if splash is None: fields["splash"] = splash else: - fields["splash"] = utils._bytes_to_base64_data(splash) + fields["splash"] = bytes_to_base64_data(splash) if discovery_splash is not MISSING: if discovery_splash is None: fields["discovery_splash"] = discovery_splash else: - fields["discovery_splash"] = utils._bytes_to_base64_data(discovery_splash) + fields["discovery_splash"] = bytes_to_base64_data(discovery_splash) if default_notifications is not MISSING: if not isinstance(default_notifications, NotificationLevel): @@ -2648,7 +2649,7 @@ async def create_custom_emoji( The created emoji. """ - img = utils._bytes_to_base64_data(image) + img = bytes_to_base64_data(image) role_ids = [role.id for role in roles] if roles else [] data = await self._state.http.create_custom_emoji(self.id, name, img, roles=role_ids, reason=reason) return self._state.store_emoji(self, data) @@ -2875,7 +2876,7 @@ async def create_role( if icon is None: fields["icon"] = None else: - fields["icon"] = utils._bytes_to_base64_data(icon) + fields["icon"] = bytes_to_base64_data(icon) fields["unicode_emoji"] = None if unicode_emoji is not MISSING: @@ -3681,7 +3682,7 @@ async def create_scheduled_event( payload["scheduled_end_time"] = end_time.isoformat() if image is not MISSING: - payload["image"] = utils._bytes_to_base64_data(image) + payload["image"] = bytes_to_base64_data(image) data = await self._state.http.create_scheduled_event(guild_id=self.id, reason=reason, **payload) event = ScheduledEvent(state=self._state, guild=self, creator=self.me, data=data) diff --git a/discord/http.py b/discord/http.py index 4600c045c6..205b9b253c 100644 --- a/discord/http.py +++ b/discord/http.py @@ -34,6 +34,7 @@ import aiohttp +from .utils.private import get_mime_type_for_image, to_json, from_json from . import __version__, utils from .errors import ( DiscordServerError, @@ -46,7 +47,8 @@ ) from .file import VoiceMessage from .gateway import DiscordClientWebSocketResponse -from .utils import MISSING, warn_deprecated +from .utils import MISSING +from .utils.private import warn_deprecated _log = logging.getLogger(__name__) @@ -96,7 +98,7 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str text = await response.text(encoding="utf-8") try: if response.headers["content-type"] == "application/json": - return utils._from_json(text) + return from_json(text) except KeyError: # Thanks Cloudflare pass @@ -259,7 +261,7 @@ async def request( # some checking if it's a JSON request if "json" in kwargs: headers["Content-Type"] = "application/json" - kwargs["data"] = utils._to_json(kwargs.pop("json")) + kwargs["data"] = to_json(kwargs.pop("json")) try: reason = kwargs.pop("reason") @@ -567,7 +569,7 @@ def send_multipart_helper( } ) payload["attachments"] = attachments - form[0]["value"] = utils._to_json(payload) + form[0]["value"] = to_json(payload) return self.request(route, form=form, files=files) def send_files( @@ -640,7 +642,7 @@ def edit_multipart_helper( payload["attachments"] = attachments else: payload["attachments"].extend(attachments) - form[0]["value"] = utils._to_json(payload) + form[0]["value"] = to_json(payload) return self.request(route, form=form, files=files) @@ -1173,6 +1175,7 @@ def start_forum_thread( allowed_mentions: message.AllowedMentions | None = None, stickers: list[sticker.StickerItem] | None = None, components: list[components.Component] | None = None, + flags: int | None = None, ) -> Response[threads.Thread]: payload: dict[str, Any] = { "name": name, @@ -1209,6 +1212,9 @@ def start_forum_thread( if stickers: message["sticker_ids"] = stickers + if flags: + message["flags"] = flags + if message != {}: payload["message"] = message @@ -1240,7 +1246,7 @@ def start_forum_thread( ) payload["attachments"] = attachments - form[0]["value"] = utils._to_json(payload) + form[0]["value"] = to_json(payload) return self.request(route, form=form, reason=reason) return self.request(route, json=payload, reason=reason) @@ -1649,7 +1655,7 @@ def create_guild_sticker( initial_bytes = file.fp.read(16) try: - mime_type = utils._get_mime_type_for_image(initial_bytes) + mime_type = get_mime_type_for_image(initial_bytes) except InvalidArgument: if initial_bytes.startswith(b"{"): mime_type = "application/json" @@ -2587,7 +2593,7 @@ def _edit_webhook_helper( form: list[dict[str, Any]] = [ { "name": "payload_json", - "value": utils._to_json(payload), + "value": to_json(payload), } ] diff --git a/discord/integrations.py b/discord/integrations.py index 8bb1bd80e7..709ad1b586 100644 --- a/discord/integrations.py +++ b/discord/integrations.py @@ -31,7 +31,8 @@ from .enums import ExpireBehaviour, try_enum from .errors import InvalidArgument from .user import User -from .utils import MISSING, _get_as_snowflake, parse_time +from .utils import MISSING +from .utils.private import get_as_snowflake, parse_time from discord import utils __all__ = ( @@ -207,7 +208,7 @@ def _from_data(self, data: StreamIntegrationPayload) -> None: self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data["expire_behavior"]) self.expire_grace_period: int = data["expire_grace_period"] self.synced_at: datetime.datetime = parse_time(data["synced_at"]) - self._role_id: int | None = _get_as_snowflake(data, "role_id") + self._role_id: int | None = get_as_snowflake(data, "role_id") self.syncing: bool = data["syncing"] self.enable_emoticons: bool = data["enable_emoticons"] self.subscriber_count: int = data["subscriber_count"] diff --git a/discord/interactions.py b/discord/interactions.py index 48ed74e8fe..c0d9882c46 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -26,9 +26,13 @@ from __future__ import annotations import asyncio +from collections.abc import Sequence import datetime from typing import TYPE_CHECKING, Any, Coroutine, Union +from .components import AnyComponent + +from .utils.private import get_as_snowflake, deprecated, delay_task, cached_slot_property from . import utils from .channel import ChannelType, PartialMessageable, _threaded_channel_factory from .enums import ( @@ -76,7 +80,7 @@ VoiceChannel, ) from .client import Client - from .commands import OptionChoice + from .commands import ApplicationCommand, OptionChoice from .embeds import Embed from .mentions import AllowedMentions from .poll import Poll @@ -86,8 +90,6 @@ from .types.interactions import InteractionData from .types.interactions import InteractionMetadata as InteractionMetadataPayload from .types.interactions import MessageInteraction as MessageInteractionPayload - from .ui.modal import Modal - from .ui.view import View InteractionChannel = Union[ VoiceChannel, @@ -153,6 +155,18 @@ class Interaction: The context in which this command was executed. .. versionadded:: 2.6 + command: Optional[:class:`ApplicationCommand`] + The command that this interaction belongs to. + + .. versionadded:: 2.7 + view: Optional[:class:`View`] + The view that this interaction belongs to. + + .. versionadded:: 2.7 + modal: Optional[:class:`Modal`] + The modal that this interaction belongs to. + + .. versionadded:: 2.7 """ __slots__: tuple[str, ...] = ( @@ -173,6 +187,7 @@ class Interaction: "entitlements", "context", "authorizing_integration_owners", + "command", "_channel_data", "_message_data", "_guild_data", @@ -200,8 +215,8 @@ def _from_data(self, data: InteractionPayload): self.data: InteractionData | None = data.get("data") self.token: str = data["token"] self.version: int = data["version"] - self.channel_id: int | None = utils._get_as_snowflake(data, "channel_id") - self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") + self.channel_id: int | None = get_as_snowflake(data, "channel_id") + self.guild_id: int | None = get_as_snowflake(data, "guild_id") self.application_id: int = int(data["application_id"]) self.locale: str | None = data.get("locale") self.guild_locale: str | None = data.get("guild_locale") @@ -219,6 +234,8 @@ def _from_data(self, data: InteractionPayload): try_enum(InteractionContextType, data["context"]) if "context" in data else None ) + self.command: ApplicationCommand | None = None + self.message: Message | None = None self.channel = None @@ -296,8 +313,8 @@ def is_component(self) -> bool: """Indicates whether the interaction is a message component.""" return self.type == InteractionType.component - @utils.cached_slot_property("_cs_channel") - @utils.deprecated("Interaction.channel", "2.7", stacklevel=4) + @cached_slot_property("_cs_channel") + @deprecated("Interaction.channel", "2.7", stacklevel=4) def cached_channel(self) -> InteractionChannel | None: """The cached channel from which the interaction was sent. DM channels are not resolved. These are :class:`PartialMessageable` instead. @@ -321,12 +338,12 @@ def permissions(self) -> Permissions: """ return Permissions(self._permissions) - @utils.cached_slot_property("_cs_app_permissions") + @cached_slot_property("_cs_app_permissions") def app_permissions(self) -> Permissions: """The resolved permissions of the application in the channel, including overwrites.""" return Permissions(self._app_permissions) - @utils.cached_slot_property("_cs_response") + @cached_slot_property("_cs_response") def response(self) -> InteractionResponse: """Returns an object responsible for handling responding to the interaction. @@ -335,7 +352,7 @@ def response(self) -> InteractionResponse: """ return InteractionResponse(self) - @utils.cached_slot_property("_cs_followup") + @cached_slot_property("_cs_followup") def followup(self) -> Webhook: """Returns the followup webhook for followup interactions.""" payload = { @@ -433,7 +450,7 @@ async def original_response(self) -> InteractionMessage: self._original_response = message return message - @utils.deprecated("Interaction.original_response", "2.2") + @deprecated("Interaction.original_response", "2.2") async def original_message(self): """An alias for :meth:`original_response`. @@ -460,7 +477,7 @@ async def edit_original_response( file: File | utils.Undefined = MISSING, files: list[File] | utils.Undefined = MISSING, attachments: list[Attachment] | utils.Undefined = MISSING, - view: View | None | utils.Undefined = MISSING, + components: Sequence[AnyComponent] | None | utils.Undefined = MISSING, allowed_mentions: AllowedMentions | None = None, delete_after: float | None = None, suppress: bool = False, @@ -495,9 +512,11 @@ async def edit_original_response( allowed_mentions: :class:`AllowedMentions` Controls the mentions being processed in this message. See :meth:`.abc.Messageable.send` for more information. - view: Optional[:class:`~discord.ui.View`] - The updated view to update this message with. If ``None`` is passed then - the view is removed. + components: Optional[Sequence[AnyComponent]] + The updated components to update this message with. If ``None`` is passed then + the components are removed. + + ..versionadded:: 3.0 delete_after: Optional[:class:`float`] If provided, the number of seconds to wait in the background before deleting the message we just edited. If the deletion fails, @@ -553,14 +572,16 @@ async def edit_original_response( message = InteractionMessage(state=state, channel=self.channel, data=data) # type: ignore if view and not view.is_finished(): view.message = message - self._state.store_view(view, message.id) + view.refresh(message.components) + if view.is_dispatchable(): + self._state.store_view(view, message.id) if delete_after is not None: await self.delete_original_response(delay=delete_after) return message - @utils.deprecated("Interaction.edit_original_response", "2.2") + @deprecated("Interaction.edit_original_response", "2.2") async def edit_original_message(self, **kwargs): """An alias for :meth:`edit_original_response`. @@ -614,11 +635,11 @@ async def delete_original_response(self, *, delay: float | None = None) -> None: ) if delay is not None: - utils.delay_task(delay, func) + delay_task(delay, func) else: await func - @utils.deprecated("Interaction.delete_original_response", "2.2") + @deprecated("Interaction.delete_original_response", "2.2") async def delete_original_message(self, **kwargs): """An alias for :meth:`delete_original_response`. @@ -908,7 +929,7 @@ async def send_message( HTTPException Sending the message failed. TypeError - You specified both ``embed`` and ``embeds``. + You specified both ``embed`` and ``embeds``, or sent content or embeds with V2 components. ValueError The length of ``embeds`` was invalid. InteractionResponded @@ -939,6 +960,10 @@ async def send_message( if view is not None: payload["components"] = view.to_components() + if view.is_components_v2(): + if embeds or content: + raise TypeError("cannot send embeds or content with a view using v2 component logic") + flags.is_components_v2 = True if poll is not None: payload["poll"] = poll.to_dict() @@ -998,7 +1023,8 @@ async def send_message( view.timeout = 15 * 60.0 view.parent = self._parent - self._parent._state.store_view(view) + if view.is_dispatchable(): + self._parent._state.store_view(view) self._responded = True if delete_after is not None: @@ -1243,7 +1269,7 @@ async def send_modal(self, modal: Modal) -> Interaction: self._parent._state.store_modal(modal, self._parent.user.id) return self._parent - @utils.deprecated("a button with type ButtonType.premium", "2.6") + @deprecated("a button with type ButtonType.premium", "2.6") async def premium_required(self) -> Interaction: """|coro| @@ -1532,8 +1558,8 @@ def __init__(self, *, data: InteractionMetadataPayload, state: ConnectionState): self.authorizing_integration_owners: AuthorizingIntegrationOwners = AuthorizingIntegrationOwners( data["authorizing_integration_owners"], state ) - self.original_response_message_id: int | None = utils._get_as_snowflake(data, "original_response_message_id") - self.interacted_message_id: int | None = utils._get_as_snowflake(data, "interacted_message_id") + self.original_response_message_id: int | None = get_as_snowflake(data, "original_response_message_id") + self.interacted_message_id: int | None = get_as_snowflake(data, "interacted_message_id") self.triggering_interaction_metadata: InteractionMetadata | None = None if tim := data.get("triggering_interaction_metadata"): self.triggering_interaction_metadata = InteractionMetadata(data=tim, state=state) @@ -1541,7 +1567,7 @@ def __init__(self, *, data: InteractionMetadataPayload, state: ConnectionState): def __repr__(self): return f"" - @utils.cached_slot_property("_cs_original_response_message") + @cached_slot_property("_cs_original_response_message") def original_response_message(self) -> Message | None: """Optional[:class:`Message`]: The original response message. Returns ``None`` if the message is not in cache, or if :attr:`original_response_message_id` is ``None``. @@ -1550,7 +1576,7 @@ def original_response_message(self) -> Message | None: return None return self._state._get_message(self.original_response_message_id) - @utils.cached_slot_property("_cs_interacted_message") + @cached_slot_property("_cs_interacted_message") def interacted_message(self) -> Message | None: """Optional[:class:`Message`]: The message that triggered the interaction. Returns ``None`` if the message is not in cache, or if :attr:`interacted_message_id` is ``None``. @@ -1596,7 +1622,7 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) - @utils.cached_slot_property("_cs_user") + @cached_slot_property("_cs_user") def user(self) -> User | None: """Optional[:class:`User`]: The user that authorized the integration. Returns ``None`` if the user is not in cache, or if :attr:`user_id` is ``None``. @@ -1605,7 +1631,7 @@ def user(self) -> User | None: return None return self._state.get_user(self.user_id) - @utils.cached_slot_property("_cs_guild") + @cached_slot_property("_cs_guild") def guild(self) -> Guild | None: """Optional[:class:`Guild`]: The guild that authorized the integration. Returns ``None`` if the guild is not in cache, or if :attr:`guild_id` is ``0`` or ``None``. diff --git a/discord/invite.py b/discord/invite.py index b301a5c3e0..b7ef549cfa 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -32,7 +32,8 @@ from .enums import ChannelType, InviteTarget, VerificationLevel, try_enum from .mixins import Hashable from .object import Object -from .utils import _get_as_snowflake, parse_time, snowflake_time +from .utils import snowflake_time +from .utils.private import get_as_snowflake, parse_time __all__ = ( "PartialInviteChannel", @@ -413,7 +414,7 @@ def from_incomplete(cls: type[I], *, state: ConnectionState, data: InvitePayload @classmethod def from_gateway(cls: type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I: - guild_id: int | None = _get_as_snowflake(data, "guild_id") + guild_id: int | None = get_as_snowflake(data, "guild_id") guild: Guild | Object | None = state._get_guild(guild_id) channel_id = int(data["channel_id"]) if guild is not None: diff --git a/discord/iterators.py b/discord/iterators.py index a5dc53286b..d04827c5ab 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -41,7 +41,8 @@ from .audit_logs import AuditLogEntry from .errors import NoMoreItems from .object import Object -from .utils import maybe_coroutine, snowflake_time, time_snowflake +from .utils import generate_snowflake, snowflake_time +from .utils.private import maybe_awaitable __all__ = ( "ReactionIterator", @@ -84,20 +85,6 @@ class _AsyncIterator(AsyncIterator[T]): async def next(self) -> T: raise NotImplementedError - def get(self, **attrs: Any) -> Awaitable[T | None]: - def predicate(elem: T): - for attr, val in attrs.items(): - nested = attr.split("__") - obj = elem - for attribute in nested: - obj = getattr(obj, attribute) - - if obj != val: - return False - return True - - return self.find(predicate) - async def find(self, predicate: _Func[T, bool]) -> T | None: while True: try: @@ -105,7 +92,7 @@ async def find(self, predicate: _Func[T, bool]) -> T | None: except NoMoreItems: return None - ret = await maybe_coroutine(predicate, elem) + ret = await maybe_awaitable(predicate, elem) if ret: return elem @@ -163,7 +150,7 @@ def __init__(self, iterator, func): async def next(self) -> T: # this raises NoMoreItems and will propagate appropriately item = await self.iterator.next() - return await maybe_coroutine(self.func, item) + return await maybe_awaitable(self.func, item) class _FilteredAsyncIterator(_AsyncIterator[T]): @@ -181,7 +168,7 @@ async def next(self) -> T: while True: # propagate NoMoreItems similar to _MappedAsyncIterator item = await getter() - ret = await maybe_coroutine(pred, item) + ret = await maybe_awaitable(pred, item) if ret: return item @@ -341,11 +328,11 @@ def __init__( oldest_first=None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) if isinstance(around, datetime.datetime): - around = Object(id=time_snowflake(around)) + around = Object(id=generate_snowflake(around)) self.reverse = after is not None if oldest_first is None else oldest_first self.messageable = messageable @@ -467,9 +454,9 @@ def __init__( action_type=None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.guild = guild self.loop = guild._state.loop @@ -578,9 +565,9 @@ class GuildIterator(_AsyncIterator["Guild"]): def __init__(self, bot, limit, before=None, after=None, with_counts=True): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.bot = bot self.limit = limit @@ -665,7 +652,7 @@ async def _retrieve_guilds_after_strategy(self, retrieve): class MemberIterator(_AsyncIterator["Member"]): def __init__(self, guild, limit=1000, after=None): if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.guild = guild self.limit = limit @@ -797,7 +784,7 @@ def __init__( self.before = None elif isinstance(before, datetime.datetime): if joined: - self.before = str(time_snowflake(before, high=False)) + self.before = str(generate_snowflake(before, high=False)) else: self.before = before.isoformat() else: @@ -873,9 +860,9 @@ def __init__( after: datetime.datetime | int | None = None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.event = event self.limit = limit @@ -967,9 +954,9 @@ def __init__( self.sku_ids = sku_ids if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.before = before self.after = after @@ -1081,9 +1068,9 @@ def __init__( user_id: int | None = None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.state = state self.sku_id = sku_id diff --git a/discord/member.py b/discord/member.py index 0775f2670f..3d20c2a4fc 100644 --- a/discord/member.py +++ b/discord/member.py @@ -44,6 +44,7 @@ from .permissions import Permissions from .user import BaseUser, User, _UserTag from .utils import MISSING +from .utils.private import parse_time, SnowflakeList, copy_doc __all__ = ( "VoiceState", @@ -148,7 +149,7 @@ def _update( self.mute: bool = data.get("mute", False) self.deaf: bool = data.get("deaf", False) self.suppress: bool = data.get("suppress", False) - self.requested_to_speak_at: datetime.datetime | None = utils.parse_time(data.get("request_to_speak_timestamp")) + self.requested_to_speak_at: datetime.datetime | None = parse_time(data.get("request_to_speak_timestamp")) self.channel: VocalGuildChannel | None = channel def __repr__(self) -> str: @@ -200,7 +201,7 @@ def general(self, *args, **kwargs): return general func = generate_function(attr) - func = utils.copy_doc(value)(func) + func = copy_doc(value)(func) setattr(cls, attr, func) return cls @@ -309,16 +310,16 @@ def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: Connecti self._state: ConnectionState = state self._user: User = state.store_user(data["user"]) self.guild: Guild = guild - self.joined_at: datetime.datetime | None = utils.parse_time(data.get("joined_at")) - self.premium_since: datetime.datetime | None = utils.parse_time(data.get("premium_since")) - self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"])) + self.joined_at: datetime.datetime | None = parse_time(data.get("joined_at")) + self.premium_since: datetime.datetime | None = parse_time(data.get("premium_since")) + self._roles: SnowflakeList = SnowflakeList(map(int, data["roles"])) self._client_status: dict[str | None, str] = {None: "offline"} self.activities: tuple[ActivityTypes, ...] = () self.nick: str | None = data.get("nick", None) self.pending: bool = data.get("pending", False) self._avatar: str | None = data.get("avatar") self._banner: str | None = data.get("banner") - self.communication_disabled_until: datetime.datetime | None = utils.parse_time( + self.communication_disabled_until: datetime.datetime | None = parse_time( data.get("communication_disabled_until") ) self.flags: MemberFlags = MemberFlags._from_value(data.get("flags", 0)) @@ -359,9 +360,9 @@ def _from_message(cls: type[M], *, message: Message, data: MemberPayload) -> M: return cls(data=data, guild=message.guild, state=message._state) # type: ignore def _update_from_message(self, data: MemberPayload) -> None: - self.joined_at = utils.parse_time(data.get("joined_at")) - self.premium_since = utils.parse_time(data.get("premium_since")) - self._roles = utils.SnowflakeList(map(int, data["roles"])) + self.joined_at = parse_time(data.get("joined_at")) + self.premium_since = parse_time(data.get("premium_since")) + self._roles = SnowflakeList(map(int, data["roles"])) self.nick = data.get("nick", None) self.pending = data.get("pending", False) @@ -386,7 +387,7 @@ def _try_upgrade( def _copy(cls: type[M], member: M) -> M: self: M = cls.__new__(cls) # to bypass __init__ - self._roles = utils.SnowflakeList(member._roles, is_sorted=True) + self._roles = SnowflakeList(member._roles, is_sorted=True) self.joined_at = member.joined_at self.premium_since = member.premium_since self._client_status = member._client_status.copy() @@ -422,11 +423,11 @@ def _update(self, data: MemberPayload) -> None: except KeyError: pass - self.premium_since = utils.parse_time(data.get("premium_since")) - self._roles = utils.SnowflakeList(map(int, data["roles"])) + self.premium_since = parse_time(data.get("premium_since")) + self._roles = SnowflakeList(map(int, data["roles"])) self._avatar = data.get("avatar") self._banner = data.get("banner") - self.communication_disabled_until = utils.parse_time(data.get("communication_disabled_until")) + self.communication_disabled_until = parse_time(data.get("communication_disabled_until")) self.flags = MemberFlags._from_value(data.get("flags", 0)) def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> tuple[User, User] | None: @@ -1047,7 +1048,7 @@ async def add_roles(self, *roles: Snowflake, reason: str | None = None, atomic: """ if not atomic: - new_roles = utils._unique(Object(id=r.id) for s in (self.roles[1:], roles) for r in s) + new_roles = list({Object(id=r.id) for s in (self.roles[1:], roles) for r in s}) await self.edit(roles=new_roles, reason=reason) else: req = self._state.http.add_role diff --git a/discord/message.py b/discord/message.py index 30bcc9e4dc..c220fc5f0b 100644 --- a/discord/message.py +++ b/discord/message.py @@ -34,16 +34,17 @@ Any, Callable, ClassVar, - Sequence, TypeVar, Union, overload, ) from urllib.parse import parse_qs, urlparse +from collections.abc import Sequence +from .utils.private import get_as_snowflake, parse_time, warn_deprecated, delay_task, cached_slot_property from . import utils from .channel import PartialMessageable -from .components import _component_factory +from .components import _component_factory, AnyComponent from .embeds import Embed from .emoji import AppEmoji, GuildEmoji from .enums import ChannelType, MessageReferenceType, MessageType, try_enum @@ -91,8 +92,8 @@ from .types.snowflake import SnowflakeList from .types.threads import ThreadArchiveDuration from .types.user import User as UserPayload - from .ui.view import View from .user import User + from .components import Component MR = TypeVar("MR", bound="MessageReference") EmojiInputType = Union[GuildEmoji, AppEmoji, PartialEmoji, str] @@ -537,9 +538,9 @@ def __init__( def with_state(cls: type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR: self = cls.__new__(cls) self.type = try_enum(MessageReferenceType, data.get("type")) or MessageReferenceType.default - self.message_id = utils._get_as_snowflake(data, "message_id") - self.channel_id = utils._get_as_snowflake(data, "channel_id") - self.guild_id = utils._get_as_snowflake(data, "guild_id") + self.message_id = get_as_snowflake(data, "message_id") + self.channel_id = get_as_snowflake(data, "channel_id") + self.guild_id = get_as_snowflake(data, "guild_id") self.fail_if_not_exists = data.get("fail_if_not_exists", True) self._state = state self.resolved = None @@ -630,7 +631,7 @@ class MessageCall: def __init__(self, state: ConnectionState, data: MessageCallPayload): self._state: ConnectionState = state self._participants: SnowflakeList = data.get("participants", []) - self._ended_timestamp: datetime.datetime | None = utils.parse_time(data["ended_timestamp"]) + self._ended_timestamp: datetime.datetime | None = parse_time(data["ended_timestamp"]) @property def participants(self) -> list[User | Object]: @@ -706,7 +707,7 @@ def __init__( self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0)) self.stickers: list[StickerItem] = [StickerItem(data=d, state=state) for d in data.get("sticker_items", [])] self.components: list[Component] = [_component_factory(d) for d in data.get("components", [])] - self._edited_timestamp: datetime.datetime | None = utils.parse_time(data["edited_timestamp"]) + self._edited_timestamp: datetime.datetime | None = parse_time(data["edited_timestamp"]) @property def created_at(self) -> datetime.datetime: @@ -968,14 +969,14 @@ def __init__( self._state: ConnectionState = state self._raw_data: MessagePayload = data self.id: int = int(data["id"]) - self.webhook_id: int | None = utils._get_as_snowflake(data, "webhook_id") + self.webhook_id: int | None = get_as_snowflake(data, "webhook_id") self.reactions: list[Reaction] = [Reaction(message=self, data=d) for d in data.get("reactions", [])] self.attachments: list[Attachment] = [Attachment(data=a, state=self._state) for a in data["attachments"]] self.embeds: list[Embed] = [Embed.from_dict(a) for a in data["embeds"]] self.application: MessageApplicationPayload | None = data.get("application") self.activity: MessageActivityPayload | None = data.get("activity") self.channel: MessageableChannel = channel - self._edited_timestamp: datetime.datetime | None = utils.parse_time(data["edited_timestamp"]) + self._edited_timestamp: datetime.datetime | None = parse_time(data["edited_timestamp"]) self.type: MessageType = try_enum(MessageType, data["type"]) self.pinned: bool = data["pinned"] self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0)) @@ -984,13 +985,13 @@ def __init__( self.content: str = data["content"] self.nonce: int | str | None = data.get("nonce") self.stickers: list[StickerItem] = [StickerItem(data=d, state=state) for d in data.get("sticker_items", [])] - self.components: list[Component] = [_component_factory(d) for d in data.get("components", [])] + self.components: list[Component] = [_component_factory(d, state=state) for d in data.get("components", [])] try: # if the channel doesn't have a guild attribute, we handle that self.guild = channel.guild # type: ignore except AttributeError: - self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id")) + self.guild = state._get_guild(get_as_snowflake(data, "guild_id")) try: ref = data["message_reference"] @@ -1149,7 +1150,7 @@ def _update(self, data): pass def _handle_edited_timestamp(self, value: str) -> None: - self._edited_timestamp = utils.parse_time(value) + self._edited_timestamp = parse_time(value) def _handle_pinned(self, value: bool) -> None: self.pinned = value @@ -1236,7 +1237,7 @@ def _handle_mention_roles(self, role_mentions: list[int]) -> None: self.role_mentions.append(role) def _handle_components(self, components: list[ComponentPayload]): - self.components = [_component_factory(d) for d in components] + self.components = [_component_factory(d, state=self._state) for d in components] def _rebind_cached_references(self, new_guild: Guild, new_channel: TextChannel | Thread) -> None: self.guild = new_guild @@ -1244,7 +1245,7 @@ def _rebind_cached_references(self, new_guild: Guild, new_channel: TextChannel | @property def interaction(self) -> MessageInteraction | None: - utils.warn_deprecated( + warn_deprecated( "interaction", "interaction_metadata", "2.6", @@ -1254,7 +1255,7 @@ def interaction(self) -> MessageInteraction | None: @interaction.setter def interaction(self, value: MessageInteraction | None) -> None: - utils.warn_deprecated( + warn_deprecated( "interaction", "interaction_metadata", "2.6", @@ -1262,7 +1263,7 @@ def interaction(self, value: MessageInteraction | None) -> None: ) self._interaction = value - @utils.cached_slot_property("_cs_raw_mentions") + @cached_slot_property("_cs_raw_mentions") def raw_mentions(self) -> list[int]: """A property that returns an array of user IDs matched with the syntax of ``<@user_id>`` in the message content. @@ -1272,28 +1273,28 @@ def raw_mentions(self) -> list[int]: """ return [int(x) for x in re.findall(r"<@!?([0-9]{15,20})>", self.content)] - @utils.cached_slot_property("_cs_raw_channel_mentions") + @cached_slot_property("_cs_raw_channel_mentions") def raw_channel_mentions(self) -> list[int]: """A property that returns an array of channel IDs matched with the syntax of ``<#channel_id>`` in the message content. """ return [int(x) for x in re.findall(r"<#([0-9]{15,20})>", self.content)] - @utils.cached_slot_property("_cs_raw_role_mentions") + @cached_slot_property("_cs_raw_role_mentions") def raw_role_mentions(self) -> list[int]: """A property that returns an array of role IDs matched with the syntax of ``<@&role_id>`` in the message content. """ return [int(x) for x in re.findall(r"<@&([0-9]{15,20})>", self.content)] - @utils.cached_slot_property("_cs_channel_mentions") + @cached_slot_property("_cs_channel_mentions") def channel_mentions(self) -> list[GuildChannel]: if self.guild is None: return [] it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions)) - return utils._unique(it) + return list(dict.fromkeys(it)) # using dict.fromkeys and not set to preserve order - @utils.cached_slot_property("_cs_clean_content") + @cached_slot_property("_cs_clean_content") def clean_content(self) -> str: """A property that returns the content in a "cleaned up" manner. This basically means that mentions are transformed @@ -1370,7 +1371,7 @@ def is_system(self) -> bool: MessageType.thread_starter_message, ) - @utils.cached_slot_property("_cs_system_content") + @cached_slot_property("_cs_system_content") def system_content(self) -> str: r"""A property that returns the content that is rendered regardless of the :attr:`Message.type`. @@ -1538,7 +1539,7 @@ async def delete(self, *, delay: float | None = None, reason: str | None = None) """ del_func = self._state.http.delete_message(self.channel.id, self.id, reason=reason) if delay is not None: - utils.delay_task(delay, del_func) + delay_task(delay, del_func) else: await del_func @@ -1555,7 +1556,7 @@ async def edit( suppress: bool = ..., delete_after: float | None = ..., allowed_mentions: AllowedMentions | None = ..., - view: View | None = ..., + components: Sequence[AnyComponent] | None | utils.Undefined = MISSING, ) -> Message: ... async def edit( @@ -1569,7 +1570,7 @@ async def edit( suppress: bool | utils.Undefined = MISSING, delete_after: float | None = None, allowed_mentions: AllowedMentions | None | utils.Undefined = MISSING, - view: View | None | utils.Undefined = MISSING, + components: Sequence[AnyComponent] | None | utils.Undefined = MISSING, ) -> Message: """|coro| @@ -1618,9 +1619,8 @@ async def edit( are used instead. .. versionadded:: 1.4 - view: Optional[:class:`~discord.ui.View`] - The updated view to update this message with. If ``None`` is passed then - the view is removed. + components: Optional[Sequence[:class:`AnyComponent`]] + The new components to replace the originals with. If ``None`` is passed then the components are removed. Raises ------ @@ -1648,10 +1648,10 @@ async def edit( elif embeds is not MISSING: payload["embeds"] = [e.to_dict() for e in embeds] + flags = MessageFlags._from_value(self.flags.value) + if suppress is not MISSING: - flags = MessageFlags._from_value(self.flags.value) flags.suppress_embeds = suppress - payload["flags"] = flags.value if allowed_mentions is MISSING: if self._state.allowed_mentions is not None and self.author.id == self._state.self_id: @@ -1665,12 +1665,20 @@ async def edit( if attachments is not MISSING: payload["attachments"] = [a.to_dict() for a in attachments] - if view is not MISSING: - self._state.prevent_view_updates_for(self.id) - payload["components"] = view.to_components() if view else [] + if components is not MISSING: + payload["components"] = [] + if components: + for c in components: + if c.any_is_v2()(): + flags.is_components_v2 = True + payload["components"].append(c.to_dict()) + if file is not MISSING and files is not MISSING: raise InvalidArgument("cannot pass both file and files parameter to edit()") + if flags.value != self.flags.value: + payload["flags"] = flags.value + if file is not MISSING or files is not MISSING: if file is not MISSING: if not isinstance(file, File): @@ -1699,10 +1707,6 @@ async def edit( data = await self._state.http.edit_message(self.channel.id, self.id, **payload) message = Message(state=self._state, channel=self.channel, data=data) - if view and not view.is_finished(): - view.message = message - self._state.store_view(view, self.id) - if delete_after is not None: await self.delete(delay=delete_after) @@ -2184,7 +2188,7 @@ def created_at(self) -> datetime.datetime: def poll(self) -> Poll | None: return self._state._polls.get(self.id) - @utils.cached_slot_property("_cs_guild") + @cached_slot_property("_cs_guild") def guild(self) -> Guild | None: """The guild that the partial message belongs to, if applicable.""" return getattr(self.channel, "guild", None) @@ -2248,11 +2252,12 @@ async def edit(self, **fields: Any) -> Message | None: to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` are used instead. - view: Optional[:class:`~discord.ui.View`] - The updated view to update this message with. If ``None`` is passed then - the view is removed. + components: Optional[Sequence[AnyComponent]] + The new components to replace the originals with. If ``None`` is passed then the components + are removed. - .. versionadded:: 2.0 + ..versionchanged:: 3.0 + Changed from view to components. Returns ------- @@ -2305,10 +2310,14 @@ async def edit(self, **fields: Any) -> Message | None: self._state.allowed_mentions.to_dict() if self._state.allowed_mentions else None ) - view = fields.pop("view", MISSING) - if view is not MISSING: - self._state.prevent_view_updates_for(self.id) - fields["components"] = view.to_components() if view else [] + components = fields.pop("components", MISSING) + if components is not MISSING: + fields["components"] = [] + if components: + for c in components: + if c.any_is_v2(): + flags.is_components_v2 = True + fields["components"].append(c.to_dict()) if fields: data = await self._state.http.edit_message(self.channel.id, self.id, **fields) @@ -2319,9 +2328,6 @@ async def edit(self, **fields: Any) -> Message | None: if fields: # data isn't unbound msg = self._state.create_message(channel=self.channel, data=data) # type: ignore - if view and not view.is_finished(): - view.message = msg - self._state.store_view(view, self.id) return msg async def end_poll(self) -> Message: diff --git a/discord/monetization.py b/discord/monetization.py index 8ca4f23c40..18b55894d0 100644 --- a/discord/monetization.py +++ b/discord/monetization.py @@ -31,7 +31,8 @@ from .flags import SKUFlags from .iterators import SubscriptionIterator from .mixins import Hashable -from .utils import MISSING, _get_as_snowflake, parse_time +from .utils import MISSING +from .utils.private import get_as_snowflake, parse_time if TYPE_CHECKING: from datetime import datetime @@ -226,12 +227,12 @@ def __init__(self, *, data: EntitlementPayload, state: ConnectionState) -> None: self.id: int = int(data["id"]) self.sku_id: int = int(data["sku_id"]) self.application_id: int = int(data["application_id"]) - self.user_id: int | MISSING = _get_as_snowflake(data, "user_id") or MISSING + self.user_id: int | MISSING = get_as_snowflake(data, "user_id") or MISSING self.type: EntitlementType = try_enum(EntitlementType, data["type"]) self.deleted: bool = data["deleted"] self.starts_at: datetime | MISSING = parse_time(data.get("starts_at")) or MISSING self.ends_at: datetime | MISSING | None = parse_time(ea) if (ea := data.get("ends_at")) is not None else MISSING - self.guild_id: int | MISSING = _get_as_snowflake(data, "guild_id") or MISSING + self.guild_id: int | MISSING = get_as_snowflake(data, "guild_id") or MISSING self.consumed: bool = data.get("consumed", False) def __repr__(self) -> str: diff --git a/discord/onboarding.py b/discord/onboarding.py index 868a5c8ae3..06d865128d 100644 --- a/discord/onboarding.py +++ b/discord/onboarding.py @@ -25,11 +25,13 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any +from functools import cached_property + +from discord import utils from .enums import OnboardingMode, PromptType, try_enum from .partial_emoji import PartialEmoji -from .utils import MISSING, cached_property, generate_snowflake, get -from discord import utils +from .utils import MISSING, generate_snowflake, find if TYPE_CHECKING: from .abc import Snowflake @@ -81,7 +83,7 @@ def __init__( id: int | None = None, ): # ID is required when making edits, but it can be any snowflake that isn't already used by another prompt during edits - self.id: int = int(id) if id else generate_snowflake() + self.id: int = int(id) if id else generate_snowflake(mode="realistic") self.title: str = title self.channels: list[Snowflake] = channels or [] self.roles: list[Snowflake] = roles or [] @@ -127,7 +129,7 @@ def _from_dict(cls, data: PromptOptionPayload, guild: Guild) -> PromptOption: # Emoji object is {'id': None, 'name': None, 'animated': False} ... emoji = PartialEmoji.from_dict(_emoji) if emoji.id: - emoji = get(guild.emojis, id=emoji.id) or emoji + emoji = find(lambda e: e.id == emoji.id, guild.emojis) or emoji else: emoji = None @@ -169,7 +171,7 @@ def __init__( id: int | None = None, # Currently optional as users can manually create these ): # ID is required when making edits, but it can be any snowflake that isn't already used by another prompt during edits - self.id: int = int(id) if id else generate_snowflake() + self.id: int = int(id) if id else generate_snowflake(mode="realistic") self.type: PromptType = type if isinstance(self.type, int): @@ -431,7 +433,7 @@ def get_prompt( The matching prompt, or None if it didn't exist. """ - return get(self.prompts, id=id) + return find(lambda p: p.id == id, self.prompts) async def delete_prompt( self, diff --git a/discord/partial_emoji.py b/discord/partial_emoji.py index 9bf304bc75..aacb0d8778 100644 --- a/discord/partial_emoji.py +++ b/discord/partial_emoji.py @@ -29,6 +29,7 @@ from typing import TYPE_CHECKING, Any, TypedDict, TypeVar from . import utils +from .utils.private import get_as_snowflake from .asset import Asset, AssetMixin from .errors import InvalidArgument @@ -108,7 +109,7 @@ def __init__(self, *, name: str | None, animated: bool = False, id: int | None = def from_dict(cls: type[PE], data: PartialEmojiPayload | dict[str, Any]) -> PE: return cls( animated=data.get("animated", False), - id=utils._get_as_snowflake(data, "id"), + id=get_as_snowflake(data, "id"), name=data.get("name") or "", ) diff --git a/discord/poll.py b/discord/poll.py index 5bddce29dd..7d2d426bcb 100644 --- a/discord/poll.py +++ b/discord/poll.py @@ -26,7 +26,9 @@ import datetime from typing import TYPE_CHECKING, Any +from functools import cached_property +from .utils.private import parse_time from . import utils from .enums import PollLayoutType, try_enum from .iterators import VoteIterator @@ -143,7 +145,7 @@ def count(self) -> int | None: return None if self._poll.results is None: return None # Unknown vote count. - _count = self._poll.results and utils.get(self._poll.results.answer_counts, id=self.id) + _count = self._poll.results and utils.find(lambda p: p.id == id, self._poll.results.answer_counts) if _count: return _count.count return 0 # If an answer isn't in answer_counts, it has 0 votes. @@ -341,10 +343,10 @@ def __init__( self._expiry = None self._message = None - @utils.cached_property + @cached_property def expiry(self) -> datetime.datetime | None: """An aware datetime object that specifies the date and time in UTC when the poll will end.""" - return utils.parse_time(self._expiry) + return parse_time(self._expiry) def to_dict(self) -> PollPayload: dict_ = { @@ -425,7 +427,7 @@ def get_answer(self, id) -> PollAnswer | None: Optional[:class:`.PollAnswer`] The returned answer or ``None`` if not found. """ - return utils.get(self.answers, id=id) + return utils.find(lambda a: a.id == id, self.answers) def add_answer( self, diff --git a/discord/role.py b/discord/role.py index b39d5d3ae6..c9af02d7e8 100644 --- a/discord/role.py +++ b/discord/role.py @@ -34,7 +34,8 @@ from .flags import RoleFlags from .mixins import Hashable from .permissions import Permissions -from .utils import MISSING, _bytes_to_base64_data, _get_as_snowflake, snowflake_time +from .utils import MISSING, snowflake_time +from .utils.private import get_as_snowflake, bytes_to_base64_data __all__ = ( "RoleTags", @@ -89,9 +90,9 @@ class RoleTags: ) def __init__(self, data: RoleTagPayload): - self.bot_id: int | None = _get_as_snowflake(data, "bot_id") - self.integration_id: int | None = _get_as_snowflake(data, "integration_id") - self.subscription_listing_id: int | None = _get_as_snowflake(data, "subscription_listing_id") + self.bot_id: int | None = get_as_snowflake(data, "bot_id") + self.integration_id: int | None = get_as_snowflake(data, "integration_id") + self.subscription_listing_id: int | None = get_as_snowflake(data, "subscription_listing_id") # NOTE: The API returns "null" for each of the following tags if they are True, and omits them if False. # However, "null" corresponds to None. # This is different from other fields where "null" means "not there". @@ -526,7 +527,7 @@ async def edit( if icon is None: payload["icon"] = None else: - payload["icon"] = _bytes_to_base64_data(icon) + payload["icon"] = bytes_to_base64_data(icon) payload["unicode_emoji"] = None if unicode_emoji is not MISSING: diff --git a/discord/scheduled_events.py b/discord/scheduled_events.py index b6b3847682..2ce40b2ab2 100644 --- a/discord/scheduled_events.py +++ b/discord/scheduled_events.py @@ -39,7 +39,7 @@ from .iterators import ScheduledEventSubscribersIterator from .mixins import Hashable from .object import Object -from .utils import warn_deprecated +from .utils.private import warn_deprecated, get_as_snowflake, bytes_to_base64_data __all__ = ( "ScheduledEvent", @@ -206,7 +206,7 @@ def __init__( self.end_time: datetime.datetime | None = end_time self.status: ScheduledEventStatus = try_enum(ScheduledEventStatus, data.get("status")) self.subscriber_count: int | None = data.get("user_count", None) - self.creator_id: int | None = utils._get_as_snowflake(data, "creator_id") + self.creator_id: int | None = get_as_snowflake(data, "creator_id") self.creator: Member | None = creator entity_metadata = data.get("entity_metadata") @@ -361,7 +361,7 @@ async def edit( if image is None: payload["image"] = None else: - payload["image"] = utils._bytes_to_base64_data(image) + payload["image"] = bytes_to_base64_data(image) if location is not MISSING: if not isinstance(location, (ScheduledEventLocation, utils.Undefined)): diff --git a/discord/stage_instance.py b/discord/stage_instance.py index d5703f69c2..8fdc508f41 100644 --- a/discord/stage_instance.py +++ b/discord/stage_instance.py @@ -30,7 +30,8 @@ from .enums import StagePrivacyLevel, try_enum from .errors import InvalidArgument from .mixins import Hashable -from .utils import MISSING, Undefined, cached_slot_property +from .utils import MISSING, Undefined +from .utils.private import cached_slot_property __all__ = ("StageInstance",) diff --git a/discord/state.py b/discord/state.py index 7b9841c447..d49f4865e6 100644 --- a/discord/state.py +++ b/discord/state.py @@ -43,7 +43,9 @@ Union, ) +from .utils.private import parse_time, sane_wait_for from . import utils +from .utils.private import get_as_snowflake, parse_time, sane_wait_for from .activity import BaseActivity from .audit_logs import AuditLogEntry from .automod import AutoModRule @@ -69,8 +71,6 @@ from .stage_instance import StageInstance from .sticker import GuildSticker from .threads import Thread, ThreadMember -from .ui.modal import Modal, ModalStore -from .ui.view import View, ViewStore from .user import ClientUser, User if TYPE_CHECKING: @@ -180,7 +180,7 @@ def __init__( self.hooks: dict[str, Callable] = hooks self.shard_count: int | None = None self._ready_task: asyncio.Task | None = None - self.application_id: int | None = utils._get_as_snowflake(options, "application_id") + self.application_id: int | None = get_as_snowflake(options, "application_id") self.heartbeat_timeout: float = options.get("heartbeat_timeout", 60.0) self.guild_ready_timeout: float = options.get("guild_ready_timeout", 2.0) if self.guild_ready_timeout < 0: @@ -246,7 +246,7 @@ def __init__( self.clear() - def clear(self, *, views: bool = True) -> None: + def clear(self) -> None: self.user: ClientUser | None = None # Originally, this code used WeakValueDictionary to maintain references to the # global user mapping. @@ -265,9 +265,6 @@ def clear(self, *, views: bool = True) -> None: self._stickers: dict[int, GuildSticker] = {} self._guilds: dict[int, Guild] = {} self._polls: dict[int, Poll] = {} - if views: - self._view_store: ViewStore = ViewStore(self) - self._modal_store: ModalStore = ModalStore(self) self._voice_clients: dict[int, VoiceClient] = {} # LRU of max size 128 @@ -379,19 +376,6 @@ def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker - def store_view(self, view: View, message_id: int | None = None) -> None: - self._view_store.add_view(view, message_id) - - def store_modal(self, modal: Modal, message_id: int) -> None: - self._modal_store.add_modal(modal, message_id) - - def prevent_view_updates_for(self, message_id: int) -> View | None: - return self._view_store.remove_message_tracking(message_id) - - @property - def persistent_views(self) -> Sequence[View]: - return self._view_store.persistent_views - @property def guilds(self) -> list[Guild]: return list(self._guilds.values()) @@ -626,7 +610,7 @@ def parse_ready(self, data) -> None: self._ready_task.cancel() self._ready_state = asyncio.Queue() - self.clear(views=False) + self.clear() self.user = ClientUser(state=self, data=data["user"]) self.store_user(data["user"]) @@ -636,7 +620,7 @@ def parse_ready(self, data) -> None: except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, "id") + self.application_id = get_as_snowflake(application, "id") # flags will always be present here self.application_flags = ApplicationFlags._from_value(application["flags"]) # type: ignore @@ -750,12 +734,9 @@ def parse_message_update(self, data) -> None: self.store_raw_poll(poll_data, raw) self.dispatch("raw_message_edit", raw) - if "components" in data and self._view_store.is_message_tracked(raw.message_id): - self._view_store.update_from_message(raw.message_id, data["components"]) - def parse_message_reaction_add(self, data) -> None: emoji = data["emoji"] - emoji_id = utils._get_as_snowflake(emoji, "id") + emoji_id = get_as_snowflake(emoji, "id") emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get("animated", False), name=emoji["name"]) raw = RawReactionActionEvent(data, emoji, "REACTION_ADD") @@ -792,7 +773,7 @@ def parse_message_reaction_remove_all(self, data) -> None: def parse_message_reaction_remove(self, data) -> None: emoji = data["emoji"] - emoji_id = utils._get_as_snowflake(emoji, "id") + emoji_id = get_as_snowflake(emoji, "id") emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji["name"]) raw = RawReactionActionEvent(data, emoji, "REACTION_REMOVE") @@ -822,7 +803,7 @@ def parse_message_reaction_remove(self, data) -> None: def parse_message_reaction_remove_emoji(self, data) -> None: emoji = data["emoji"] - emoji_id = utils._get_as_snowflake(emoji, "id") + emoji_id = get_as_snowflake(emoji, "id") emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji["name"]) raw = RawReactionClearEmojiEvent(data, emoji) self.dispatch("raw_reaction_clear_emoji", raw) @@ -890,18 +871,18 @@ def parse_interaction_create(self, data) -> None: if data["type"] == 3: # interaction component custom_id = interaction.data["custom_id"] # type: ignore component_type = interaction.data["component_type"] # type: ignore - self._view_store.dispatch(component_type, custom_id, interaction) + # TODO: components interactions if interaction.type == InteractionType.modal_submit: user_id, custom_id = ( interaction.user.id, interaction.data["custom_id"], ) - asyncio.create_task(self._modal_store.dispatch(user_id, custom_id, interaction)) + # TODO: modal interactions self.dispatch("interaction", interaction) def parse_presence_update(self, data) -> None: - guild_id = utils._get_as_snowflake(data, "guild_id") + guild_id = get_as_snowflake(data, "guild_id") # guild_id won't be None here guild = self._get_guild(guild_id) if guild is None: @@ -945,7 +926,7 @@ def parse_invite_delete(self, data) -> None: self.dispatch("invite_delete", invite) def parse_channel_delete(self, data) -> None: - guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) + guild = self._get_guild(get_as_snowflake(data, "guild_id")) channel_id = int(data["id"]) if guild is not None: channel = guild.get_channel(channel_id) @@ -964,7 +945,7 @@ def parse_channel_update(self, data) -> None: self.dispatch("private_channel_update", old_channel, channel) return - guild_id = utils._get_as_snowflake(data, "guild_id") + guild_id = get_as_snowflake(data, "guild_id") guild = self._get_guild(guild_id) if guild is not None: channel = guild.get_channel(channel_id) @@ -992,7 +973,7 @@ def parse_channel_create(self, data) -> None: ) return - guild_id = utils._get_as_snowflake(data, "guild_id") + guild_id = get_as_snowflake(data, "guild_id") guild = self._get_guild(guild_id) if guild is not None: # the factory can't be a DMChannel or GroupChannel here @@ -1023,7 +1004,7 @@ def parse_channel_pins_update(self, data) -> None: ) return - last_pin = utils.parse_time(data["last_pin_timestamp"]) if data["last_pin_timestamp"] else None + last_pin = parse_time(data["last_pin_timestamp"]) if data["last_pin_timestamp"] else None if guild is None: self.dispatch("private_channel_pins_update", channel, last_pin) @@ -1737,8 +1718,8 @@ def parse_stage_instance_delete(self, data) -> None: ) def parse_voice_state_update(self, data) -> None: - guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) - channel_id = utils._get_as_snowflake(data, "channel_id") + guild = self._get_guild(get_as_snowflake(data, "guild_id")) + channel_id = get_as_snowflake(data, "channel_id") flags = self.member_cache_flags # self.user is *always* cached when this is called self_id = self.user.id # type: ignore @@ -1837,7 +1818,7 @@ def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> User return self.get_user(user_id) def get_reaction_emoji(self, data) -> GuildEmoji | AppEmoji | PartialEmoji: - emoji_id = utils._get_as_snowflake(data, "id") + emoji_id = get_as_snowflake(data, "id") if not emoji_id: return data["name"] @@ -1935,7 +1916,7 @@ async def _delay_ready(self) -> None: ) if len(current_bucket) >= max_concurrency: try: - await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) + await sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) except asyncio.TimeoutError: fmt = "Shard ID %s failed to wait for chunks from a sub-bucket with length %d" _log.warning(fmt, guild.shard_id, len(current_bucket)) @@ -1957,7 +1938,7 @@ async def _delay_ready(self) -> None: # 110 reqs/minute w/ 1 req/guild plus some buffer timeout = 61 * (len(children) / 110) try: - await utils.sane_wait_for(futures, timeout=timeout) + await sane_wait_for(futures, timeout=timeout) except asyncio.TimeoutError: _log.warning( ("Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds"), @@ -2005,7 +1986,7 @@ def parse_ready(self, data) -> None: except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, "id") + self.application_id = get_as_snowflake(application, "id") self.application_flags = ApplicationFlags._from_value(application["flags"]) for guild_data in data["guilds"]: diff --git a/discord/sticker.py b/discord/sticker.py index 66d8d95ef9..6fb3e5ba38 100644 --- a/discord/sticker.py +++ b/discord/sticker.py @@ -32,7 +32,8 @@ from .enums import StickerFormatType, StickerType, try_enum from .errors import InvalidData from .mixins import Hashable -from .utils import MISSING, Undefined, cached_slot_property, find, get, snowflake_time +from .utils import MISSING, Undefined, find, snowflake_time +from .utils.private import cached_slot_property __all__ = ( "StickerPack", @@ -119,7 +120,7 @@ def _from_data(self, data: StickerPackPayload) -> None: self.name: str = data["name"] self.sku_id: int = int(data["sku_id"]) self.cover_sticker_id: int = int(data["cover_sticker_id"]) - self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore + self.cover_sticker: StandardSticker = find(lambda s: s.id == self.cover_sticker_id, self.stickers) # type: ignore self.description: str = data["description"] self._banner: int = int(data["banner_asset_id"]) diff --git a/discord/team.py b/discord/team.py index d04e47e3ff..a7d890e15d 100644 --- a/discord/team.py +++ b/discord/team.py @@ -28,6 +28,7 @@ from typing import TYPE_CHECKING from . import utils +from .utils.private import get_as_snowflake from .asset import Asset from .enums import TeamMembershipState, try_enum from .user import BaseUser @@ -68,7 +69,7 @@ def __init__(self, state: ConnectionState, data: TeamPayload): self.id: int = int(data["id"]) self.name: str = data["name"] self._icon: str | None = data["icon"] - self.owner_id: int | None = utils._get_as_snowflake(data, "owner_user_id") + self.owner_id: int | None = get_as_snowflake(data, "owner_user_id") self.members: list[TeamMember] = [TeamMember(self, self._state, member) for member in data["members"]] def __repr__(self) -> str: @@ -84,7 +85,7 @@ def icon(self) -> Asset | None: @property def owner(self) -> TeamMember | None: """The team's owner.""" - return utils.get(self.members, id=self.owner_id) + return utils.find(lambda m: m.id == self.owner_id, self.members) class TeamMember(BaseUser): diff --git a/discord/template.py b/discord/template.py index 4eaa71e9b7..e5336fd5ed 100644 --- a/discord/template.py +++ b/discord/template.py @@ -28,7 +28,8 @@ from typing import TYPE_CHECKING, Any from .guild import Guild -from .utils import MISSING, Undefined, _bytes_to_base64_data, parse_time +from .utils import MISSING, Undefined +from .utils.private import bytes_to_base64_data, parse_time __all__ = ("Template",) @@ -195,7 +196,7 @@ async def create_guild(self, name: str, icon: Any = None) -> Guild: Invalid icon image format given. Must be PNG or JPG. """ if icon is not None: - icon = _bytes_to_base64_data(icon) + icon = bytes_to_base64_data(icon) data = await self._state.http.create_from_template(self.code, name, icon) return Guild(data=data, state=self._state) diff --git a/discord/threads.py b/discord/threads.py index a0270e54d9..7442b0ec01 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -32,7 +32,8 @@ from .errors import ClientException from .flags import ChannelFlags from .mixins import Hashable -from .utils import MISSING, _get_as_snowflake, parse_time +from .utils import MISSING +from .utils.private import get_as_snowflake, parse_time from discord import utils __all__ = ( @@ -189,7 +190,7 @@ def _from_data(self, data: ThreadPayload): # This data may be missing depending on how this object is being created self.owner_id = int(data.get("owner_id")) if data.get("owner_id", None) is not None else None - self.last_message_id = _get_as_snowflake(data, "last_message_id") + self.last_message_id = get_as_snowflake(data, "last_message_id") self.slowmode_delay = data.get("rate_limit_per_user", 0) self.message_count = data.get("message_count", None) self.member_count = data.get("member_count", None) @@ -283,7 +284,7 @@ def applied_tags(self) -> list[ForumTag]: This is only available for threads in forum or media channels. """ - from .channel import ForumChannel # to prevent circular import # noqa: PLC0415 + from .channel import ForumChannel # noqa: PLC0415 # to prevent circular import if isinstance(self.parent, ForumChannel): return [tag for tag_id in self._applied_tags if (tag := self.parent.get_tag(tag_id)) is not None] diff --git a/discord/types/components.py b/discord/types/components.py index 7b05f8bf08..8aa75079d3 100644 --- a/discord/types/components.py +++ b/discord/types/components.py @@ -25,7 +25,7 @@ from __future__ import annotations -from typing import Literal, Union +from typing import Literal, Union, Generic, TypeVar from typing_extensions import NotRequired, TypedDict @@ -33,34 +33,35 @@ from .emoji import PartialEmoji from .snowflake import Snowflake -ComponentType = Literal[1, 2, 3, 4] +ComponentType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17] ButtonStyle = Literal[1, 2, 3, 4, 5, 6] InputTextStyle = Literal[1, 2] +SeparatorSpacingSize = Literal[1, 2] -class ActionRow(TypedDict): - type: Literal[1] - components: list[Component] +class BaseComponent(TypedDict): + type: ComponentType + id: NotRequired[int] -class ButtonComponent(TypedDict): +class ButtonComponent(BaseComponent): + type: Literal[2] # pyright: ignore[reportIncompatibleVariableOverride] + style: ButtonStyle + label: NotRequired[str] + emoji: NotRequired[PartialEmoji] custom_id: NotRequired[str] url: NotRequired[str] disabled: NotRequired[bool] - emoji: NotRequired[PartialEmoji] - label: NotRequired[str] - type: Literal[2] - style: ButtonStyle - sku_id: Snowflake + sku_id: NotRequired[Snowflake] -class InputText(TypedDict): +class InputText(BaseComponent): + type: Literal[4] # pyright: ignore[reportIncompatibleVariableOverride] min_length: NotRequired[int] max_length: NotRequired[int] required: NotRequired[bool] placeholder: NotRequired[str] value: NotRequired[str] - type: Literal[4] style: InputTextStyle custom_id: str label: str @@ -74,15 +75,159 @@ class SelectOption(TypedDict): default: bool -class SelectMenu(TypedDict): +T = TypeVar("T", bound=Literal["user", "role", "channel"]) + + +class SelectDefaultValue(TypedDict, Generic[T]): + id: int + type: T + + +class StringSelect(BaseComponent): + type: Literal[3] # pyright: ignore[reportIncompatibleVariableOverride] + custom_id: str + options: list[SelectOption] placeholder: NotRequired[str] min_values: NotRequired[int] max_values: NotRequired[int] disabled: NotRequired[bool] - channel_types: NotRequired[list[ChannelType]] - options: NotRequired[list[SelectOption]] - type: Literal[3, 5, 6, 7, 8] + + +class UserSelect(BaseComponent): + type: Literal[5] # pyright: ignore[reportIncompatibleVariableOverride] + custom_id: str + placeholder: NotRequired[str] + default_values: NotRequired[list[SelectDefaultValue[Literal["user"]]]] + min_values: NotRequired[int] + max_values: NotRequired[int] + disabled: NotRequired[bool] + + +class RoleSelect(BaseComponent): + type: Literal[6] # pyright: ignore[reportIncompatibleVariableOverride] + custom_id: str + placeholder: NotRequired[str] + default_values: NotRequired[list[SelectDefaultValue[Literal["role"]]]] + min_values: NotRequired[int] + max_values: NotRequired[int] + disabled: NotRequired[bool] + + +class MentionableSelect(BaseComponent): + type: Literal[7] # pyright: ignore[reportIncompatibleVariableOverride] + custom_id: str + placeholder: NotRequired[str] + default_values: NotRequired[list[SelectDefaultValue[Literal["role", "user"]]]] + min_values: NotRequired[int] + max_values: NotRequired[int] + disabled: NotRequired[bool] + + +class ChannelSelect(BaseComponent): + type: Literal[8] # pyright: ignore[reportIncompatibleVariableOverride] custom_id: str + channel_types: NotRequired[list[ChannelType]] + placeholder: NotRequired[str] + default_values: NotRequired[list[SelectDefaultValue[Literal["channel"]]]] + min_values: NotRequired[int] + max_values: NotRequired[int] + disabled: NotRequired[bool] + + +class SectionComponent(BaseComponent): + type: Literal[9] # pyright: ignore[reportIncompatibleVariableOverride] + components: list[TextDisplayComponent] + accessory: NotRequired[ThumbnailComponent | ButtonComponent] + + +class TextDisplayComponent(BaseComponent): + type: Literal[10] # pyright: ignore[reportIncompatibleVariableOverride] + content: str + + +class UnfurledMediaItem(TypedDict): + url: str + proxy_url: str + height: NotRequired[int | None] + width: NotRequired[int | None] + content_type: NotRequired[str] + flags: NotRequired[int] + attachment_id: NotRequired[Snowflake] + + +class ThumbnailComponent(BaseComponent): + type: Literal[11] # pyright: ignore[reportIncompatibleVariableOverride] + media: UnfurledMediaItem + description: NotRequired[str] + spoiler: NotRequired[bool] + + +class MediaGalleryItem(TypedDict): + media: UnfurledMediaItem + description: NotRequired[str] + spoiler: NotRequired[bool] + + +class MediaGalleryComponent(BaseComponent): + type: Literal[12] # pyright: ignore[reportIncompatibleVariableOverride] + items: list[MediaGalleryItem] + + +class FileComponent(BaseComponent): + type: Literal[13] # pyright: ignore[reportIncompatibleVariableOverride] + file: UnfurledMediaItem + spoiler: NotRequired[bool] + name: str + size: int + + +class SeparatorComponent(BaseComponent): + type: Literal[14] # pyright: ignore[reportIncompatibleVariableOverride] + divider: NotRequired[bool] + spacing: NotRequired[SeparatorSpacingSize] + + +AllowedActionRowComponents = Union[ + ButtonComponent, InputText, StringSelect, UserSelect, RoleSelect, MentionableSelect, ChannelSelect +] + + +class ActionRow(BaseComponent): + type: Literal[1] # pyright: ignore[reportIncompatibleVariableOverride] + components: list[AllowedActionRowComponents] + + +AllowedContainerComponents = Union[ + ActionRow, + TextDisplayComponent, + MediaGalleryComponent, + FileComponent, + SeparatorComponent, + SectionComponent, +] + + +class ContainerComponent(BaseComponent): + type: Literal[17] # pyright: ignore[reportIncompatibleVariableOverride] + accent_color: NotRequired[int] + spoiler: NotRequired[bool] + components: list[AllowedContainerComponents] -Component = Union[ActionRow, ButtonComponent, SelectMenu, InputText] +Component = Union[ + ActionRow, + ButtonComponent, + StringSelect, + UserSelect, + RoleSelect, + MentionableSelect, + ChannelSelect, + InputText, + TextDisplayComponent, + SectionComponent, + ThumbnailComponent, + MediaGalleryComponent, + FileComponent, + SeparatorComponent, + ContainerComponent, +] diff --git a/discord/ui/__init__.py b/discord/ui/__init__.py deleted file mode 100644 index fa1767d220..0000000000 --- a/discord/ui/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -discord.ui -~~~~~~~~~~ - -UI Kit helper for the Discord API - -:copyright: (c) 2015-2021 Rapptz & (c) 2021-present Pycord Development -:license: MIT, see LICENSE for more details. -""" - -from .button import * -from .input_text import * -from .item import * -from .modal import * -from .select import * -from .view import * diff --git a/discord/ui/button.py b/discord/ui/button.py deleted file mode 100644 index e9c59d8cfa..0000000000 --- a/discord/ui/button.py +++ /dev/null @@ -1,326 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -from __future__ import annotations - -import inspect -import os -from typing import TYPE_CHECKING, Callable, TypeVar - -from ..components import Button as ButtonComponent -from ..enums import ButtonStyle, ComponentType -from ..partial_emoji import PartialEmoji, _EmojiTag -from .item import Item, ItemCallbackType - -__all__ = ( - "Button", - "button", -) - -if TYPE_CHECKING: - from ..emoji import AppEmoji, GuildEmoji - from .view import View - -B = TypeVar("B", bound="Button") -V = TypeVar("V", bound="View", covariant=True) - - -class Button(Item[V]): - """Represents a UI button. - - .. versionadded:: 2.0 - - Parameters - ---------- - style: :class:`discord.ButtonStyle` - The style of the button. - custom_id: Optional[:class:`str`] - The ID of the button that gets received during an interaction. - If this button is for a URL, it does not have a custom ID. - url: Optional[:class:`str`] - The URL this button sends you to. - disabled: :class:`bool` - Whether the button is disabled or not. - label: Optional[:class:`str`] - The label of the button, if any. Maximum of 80 chars. - emoji: Optional[Union[:class:`.PartialEmoji`, :class:`GuildEmoji`, :class:`AppEmoji`, :class:`str`]] - The emoji of the button, if available. - sku_id: Optional[Union[:class:`int`]] - The ID of the SKU this button refers to. - row: Optional[:class:`int`] - The relative row this button belongs to. A Discord component can only have 5 - rows. By default, items are arranged automatically into those 5 rows. If you'd - like to control the relative positioning of the row then passing an index is advised. - For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic - ordering. The row number must be between 0 and 4 (i.e. zero indexed). - """ - - __item_repr_attributes__: tuple[str, ...] = ( - "style", - "url", - "disabled", - "label", - "emoji", - "sku_id", - "row", - ) - - def __init__( - self, - *, - style: ButtonStyle = ButtonStyle.secondary, - label: str | None = None, - disabled: bool = False, - custom_id: str | None = None, - url: str | None = None, - emoji: str | GuildEmoji | AppEmoji | PartialEmoji | None = None, - sku_id: int | None = None, - row: int | None = None, - ): - super().__init__() - if label and len(str(label)) > 80: - raise ValueError("label must be 80 characters or fewer") - if custom_id is not None and len(str(custom_id)) > 100: - raise ValueError("custom_id must be 100 characters or fewer") - if custom_id is not None and url is not None: - raise TypeError("cannot mix both url and custom_id with Button") - if sku_id is not None and url is not None: - raise TypeError("cannot mix both url and sku_id with Button") - if custom_id is not None and sku_id is not None: - raise TypeError("cannot mix both sku_id and custom_id with Button") - - if not isinstance(custom_id, str) and custom_id is not None: - raise TypeError(f"expected custom_id to be str, not {custom_id.__class__.__name__}") - - self._provided_custom_id = custom_id is not None - if url is None and custom_id is None and sku_id is None: - custom_id = os.urandom(16).hex() - - if url is not None: - style = ButtonStyle.link - if sku_id is not None: - style = ButtonStyle.premium - - if emoji is not None: - if isinstance(emoji, str): - emoji = PartialEmoji.from_str(emoji) - elif isinstance(emoji, _EmojiTag): - emoji = emoji._to_partial() - else: - raise TypeError( - f"expected emoji to be str, GuildEmoji, AppEmoji, or PartialEmoji not {emoji.__class__}" - ) - - self._underlying = ButtonComponent._raw_construct( - type=ComponentType.button, - custom_id=custom_id, - url=url, - disabled=disabled, - label=label, - style=style, - emoji=emoji, - sku_id=sku_id, - ) - self.row = row - - @property - def style(self) -> ButtonStyle: - """The style of the button.""" - return self._underlying.style - - @style.setter - def style(self, value: ButtonStyle): - self._underlying.style = value - - @property - def custom_id(self) -> str | None: - """The ID of the button that gets received during an interaction. - - If this button is for a URL, it does not have a custom ID. - """ - return self._underlying.custom_id - - @custom_id.setter - def custom_id(self, value: str | None): - if value is not None and not isinstance(value, str): - raise TypeError("custom_id must be None or str") - if value and len(value) > 100: - raise ValueError("custom_id must be 100 characters or fewer") - self._underlying.custom_id = value - - @property - def url(self) -> str | None: - """The URL this button sends you to.""" - return self._underlying.url - - @url.setter - def url(self, value: str | None): - if value is not None and not isinstance(value, str): - raise TypeError("url must be None or str") - self._underlying.url = value - - @property - def disabled(self) -> bool: - """Whether the button is disabled or not.""" - return self._underlying.disabled - - @disabled.setter - def disabled(self, value: bool): - self._underlying.disabled = bool(value) - - @property - def label(self) -> str | None: - """The label of the button, if available.""" - return self._underlying.label - - @label.setter - def label(self, value: str | None): - if value and len(str(value)) > 80: - raise ValueError("label must be 80 characters or fewer") - self._underlying.label = str(value) if value is not None else value - - @property - def emoji(self) -> PartialEmoji | None: - """The emoji of the button, if available.""" - return self._underlying.emoji - - @emoji.setter - def emoji(self, value: str | GuildEmoji | AppEmoji | PartialEmoji | None): # type: ignore - if value is None: - self._underlying.emoji = None - elif isinstance(value, str): - self._underlying.emoji = PartialEmoji.from_str(value) - elif isinstance(value, _EmojiTag): - self._underlying.emoji = value._to_partial() - else: - raise TypeError(f"expected str, GuildEmoji, AppEmoji, or PartialEmoji, received {value.__class__} instead") - - @property - def sku_id(self) -> int | None: - """The ID of the SKU this button refers to.""" - return self._underlying.sku_id - - @sku_id.setter - def sku_id(self, value: int | None): # type: ignore - if value is None: - self._underlying.sku_id = None - elif isinstance(value, int): - self._underlying.sku_id = value - else: - raise TypeError(f"expected int or None, received {value.__class__} instead") - - @classmethod - def from_component(cls: type[B], button: ButtonComponent) -> B: - return cls( - style=button.style, - label=button.label, - disabled=button.disabled, - custom_id=button.custom_id, - url=button.url, - emoji=button.emoji, - sku_id=button.sku_id, - row=None, - ) - - @property - def type(self) -> ComponentType: - return self._underlying.type - - def to_component_dict(self): - return self._underlying.to_dict() - - def is_dispatchable(self) -> bool: - return self.custom_id is not None - - def is_persistent(self) -> bool: - if self.style is ButtonStyle.link: - return self.url is not None - return super().is_persistent() - - def refresh_component(self, button: ButtonComponent) -> None: - self._underlying = button - - -def button( - *, - label: str | None = None, - custom_id: str | None = None, - disabled: bool = False, - style: ButtonStyle = ButtonStyle.secondary, - emoji: str | GuildEmoji | AppEmoji | PartialEmoji | None = None, - row: int | None = None, -) -> Callable[[ItemCallbackType], ItemCallbackType]: - """A decorator that attaches a button to a component. - - The function being decorated should have three parameters, ``self`` representing - the :class:`discord.ui.View`, the :class:`discord.ui.Button` being pressed and - the :class:`discord.Interaction` you receive. - - .. note:: - - Premium and link buttons cannot be created with this decorator. Consider - creating a :class:`Button` object manually instead. These types of - buttons do not have a callback associated since Discord doesn't handle - them when clicked. - - Parameters - ---------- - label: Optional[:class:`str`] - The label of the button, if any. - custom_id: Optional[:class:`str`] - The ID of the button that gets received during an interaction. - It is recommended not to set this parameter to prevent conflicts. - style: :class:`.ButtonStyle` - The style of the button. Defaults to :attr:`.ButtonStyle.grey`. - disabled: :class:`bool` - Whether the button is disabled or not. Defaults to ``False``. - emoji: Optional[Union[:class:`str`, :class:`GuildEmoji`, :class:`AppEmoji`, :class:`.PartialEmoji`]] - The emoji of the button. This can be in string form or a :class:`.PartialEmoji` - or a full :class:`GuildEmoji` or :class:`AppEmoji`. - row: Optional[:class:`int`] - The relative row this button belongs to. A Discord component can only have 5 - rows. By default, items are arranged automatically into those 5 rows. If you'd - like to control the relative positioning of the row then passing an index is advised. - For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic - ordering. The row number must be between 0 and 4 (i.e. zero indexed). - """ - - def decorator(func: ItemCallbackType) -> ItemCallbackType: - if not inspect.iscoroutinefunction(func): - raise TypeError("button function must be a coroutine function") - - func.__discord_ui_model_type__ = Button - func.__discord_ui_model_kwargs__ = { - "style": style, - "custom_id": custom_id, - "url": None, - "disabled": disabled, - "label": label, - "emoji": emoji, - "row": row, - } - return func - - return decorator diff --git a/discord/ui/input_text.py b/discord/ui/input_text.py deleted file mode 100644 index 4298324442..0000000000 --- a/discord/ui/input_text.py +++ /dev/null @@ -1,207 +0,0 @@ -from __future__ import annotations - -import os -from typing import TYPE_CHECKING - -from ..components import InputText as InputTextComponent -from ..enums import ComponentType, InputTextStyle - -__all__ = ("InputText",) - -if TYPE_CHECKING: - from ..types.components import InputText as InputTextComponentPayload - - -class InputText: - """Represents a UI text input field. - - .. versionadded:: 2.0 - - Parameters - ---------- - style: :class:`~discord.InputTextStyle` - The style of the input text field. - custom_id: Optional[:class:`str`] - The ID of the input text field that gets received during an interaction. - label: :class:`str` - The label for the input text field. - Must be 45 characters or fewer. - placeholder: Optional[:class:`str`] - The placeholder text that is shown if nothing is selected, if any. - Must be 100 characters or fewer. - min_length: Optional[:class:`int`] - The minimum number of characters that must be entered. - Defaults to 0 and must be less than 4000. - max_length: Optional[:class:`int`] - The maximum number of characters that can be entered. - Must be between 1 and 4000. - required: Optional[:class:`bool`] - Whether the input text field is required or not. Defaults to ``True``. - value: Optional[:class:`str`] - Pre-fills the input text field with this value. - Must be 4000 characters or fewer. - row: Optional[:class:`int`] - The relative row this input text field belongs to. A modal dialog can only have 5 - rows. By default, items are arranged automatically into those 5 rows. If you'd - like to control the relative positioning of the row then passing an index is advised. - For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic - ordering. The row number must be between 0 and 4 (i.e. zero indexed). - """ - - def __init__( - self, - *, - style: InputTextStyle = InputTextStyle.short, - custom_id: str | None = None, - label: str, - placeholder: str | None = None, - min_length: int | None = None, - max_length: int | None = None, - required: bool | None = True, - value: str | None = None, - row: int | None = None, - ): - super().__init__() - if len(str(label)) > 45: - raise ValueError("label must be 45 characters or fewer") - if min_length and (min_length < 0 or min_length > 4000): - raise ValueError("min_length must be between 0 and 4000") - if max_length and (max_length < 0 or max_length > 4000): - raise ValueError("max_length must be between 1 and 4000") - if value and len(str(value)) > 4000: - raise ValueError("value must be 4000 characters or fewer") - if placeholder and len(str(placeholder)) > 100: - raise ValueError("placeholder must be 100 characters or fewer") - if not isinstance(custom_id, str) and custom_id is not None: - raise TypeError(f"expected custom_id to be str, not {custom_id.__class__.__name__}") - custom_id = os.urandom(16).hex() if custom_id is None else custom_id - - self._underlying = InputTextComponent._raw_construct( - type=ComponentType.input_text, - style=style, - custom_id=custom_id, - label=label, - placeholder=placeholder, - min_length=min_length, - max_length=max_length, - required=required, - value=value, - ) - self._input_value = False - self.row = row - self._rendered_row: int | None = None - - @property - def type(self) -> ComponentType: - return self._underlying.type - - @property - def style(self) -> InputTextStyle: - """The style of the input text field.""" - return self._underlying.style - - @style.setter - def style(self, value: InputTextStyle): - if not isinstance(value, InputTextStyle): - raise TypeError(f"style must be of type InputTextStyle not {value.__class__.__name__}") - self._underlying.style = value - - @property - def custom_id(self) -> str: - """The ID of the input text field that gets received during an interaction.""" - return self._underlying.custom_id - - @custom_id.setter - def custom_id(self, value: str): - if not isinstance(value, str): - raise TypeError(f"custom_id must be None or str not {value.__class__.__name__}") - self._underlying.custom_id = value - - @property - def label(self) -> str: - """The label of the input text field.""" - return self._underlying.label - - @label.setter - def label(self, value: str): - if not isinstance(value, str): - raise TypeError(f"label should be str not {value.__class__.__name__}") - if len(value) > 45: - raise ValueError("label must be 45 characters or fewer") - self._underlying.label = value - - @property - def placeholder(self) -> str | None: - """The placeholder text that is shown before anything is entered, if any.""" - return self._underlying.placeholder - - @placeholder.setter - def placeholder(self, value: str | None): - if value and not isinstance(value, str): - raise TypeError(f"placeholder must be None or str not {value.__class__.__name__}") # type: ignore - if value and len(value) > 100: - raise ValueError("placeholder must be 100 characters or fewer") - self._underlying.placeholder = value - - @property - def min_length(self) -> int | None: - """The minimum number of characters that must be entered. Defaults to 0.""" - return self._underlying.min_length - - @min_length.setter - def min_length(self, value: int | None): - if value and not isinstance(value, int): - raise TypeError(f"min_length must be None or int not {value.__class__.__name__}") # type: ignore - if value and (value < 0 or value) > 4000: - raise ValueError("min_length must be between 0 and 4000") - self._underlying.min_length = value - - @property - def max_length(self) -> int | None: - """The maximum number of characters that can be entered.""" - return self._underlying.max_length - - @max_length.setter - def max_length(self, value: int | None): - if value and not isinstance(value, int): - raise TypeError(f"min_length must be None or int not {value.__class__.__name__}") # type: ignore - if value and (value <= 0 or value > 4000): - raise ValueError("max_length must be between 1 and 4000") - self._underlying.max_length = value - - @property - def required(self) -> bool | None: - """Whether the input text field is required or not. Defaults to ``True``.""" - return self._underlying.required - - @required.setter - def required(self, value: bool | None): - if not isinstance(value, bool): - raise TypeError(f"required must be bool not {value.__class__.__name__}") # type: ignore - self._underlying.required = bool(value) - - @property - def value(self) -> str | None: - """The value entered in the text field.""" - if self._input_value is not False: - # only False on init, otherwise the value was either set or cleared - return self._input_value # type: ignore - return self._underlying.value - - @value.setter - def value(self, value: str | None): - if value and not isinstance(value, str): - raise TypeError(f"value must be None or str not {value.__class__.__name__}") # type: ignore - if value and len(str(value)) > 4000: - raise ValueError("value must be 4000 characters or fewer") - self._underlying.value = value - - @property - def width(self) -> int: - return 5 - - def to_component_dict(self) -> InputTextComponentPayload: - return self._underlying.to_dict() - - def refresh_state(self, data) -> None: - self._input_value = data["value"] diff --git a/discord/ui/item.py b/discord/ui/item.py deleted file mode 100644 index 77338bf29b..0000000000 --- a/discord/ui/item.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, TypeVar - -from ..interactions import Interaction - -__all__ = ("Item",) - -if TYPE_CHECKING: - from ..components import Component - from ..enums import ComponentType - from .view import View - -I = TypeVar("I", bound="Item") -V = TypeVar("V", bound="View", covariant=True) -ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]] - - -class Item(Generic[V]): - """Represents the base UI item that all UI components inherit from. - - The current UI items supported are: - - - :class:`discord.ui.Button` - - :class:`discord.ui.Select` - - .. versionadded:: 2.0 - """ - - __item_repr_attributes__: tuple[str, ...] = ("row",) - - def __init__(self): - self._view: V | None = None - self._row: int | None = None - self._rendered_row: int | None = None - # This works mostly well but there is a gotcha with - # the interaction with from_component, since that technically provides - # a custom_id most dispatchable items would get this set to True even though - # it might not be provided by the library user. However, this edge case doesn't - # actually affect the intended purpose of this check because from_component is - # only called upon edit and we're mainly interested during initial creation time. - self._provided_custom_id: bool = False - - def to_component_dict(self) -> dict[str, Any]: - raise NotImplementedError - - def refresh_component(self, component: Component) -> None: - return None - - def refresh_state(self, interaction: Interaction) -> None: - return None - - @classmethod - def from_component(cls: type[I], component: Component) -> I: - return cls() - - @property - def type(self) -> ComponentType: - raise NotImplementedError - - def is_dispatchable(self) -> bool: - return False - - def is_persistent(self) -> bool: - return self._provided_custom_id - - def __repr__(self) -> str: - attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__item_repr_attributes__) - return f"<{self.__class__.__name__} {attrs}>" - - @property - def row(self) -> int | None: - """Gets or sets the row position of this item within its parent view. - - The row position determines the vertical placement of the item in the UI. - The value must be an integer between 0 and 4 (inclusive), or ``None`` to indicate - that no specific row is set. - - Returns - ------- - Optional[:class:`int`] - The row position of the item, or ``None`` if not explicitly set. - - Raises - ------ - ValueError - If the row value is not ``None`` and is outside the range [0, 4]. - """ - return self._row - - @row.setter - def row(self, value: int | None): - if value is None: - self._row = None - elif 5 > value >= 0: - self._row = value - else: - raise ValueError("row cannot be negative or greater than or equal to 5") - - @property - def width(self) -> int: - """Gets the width of the item in the UI layout. - - The width determines how much horizontal space this item occupies within its row. - - Returns - ------- - :class:`int` - The width of the item. Defaults to 1. - """ - return 1 - - @property - def view(self) -> V | None: - """Gets the parent view associated with this item. - - The view refers to the container that holds this item. This is typically set - automatically when the item is added to a view. - - Returns - ------- - Optional[:class:`View`] - The parent view of this item, or ``None`` if the item is not attached to any view. - """ - return self._view - - async def callback(self, interaction: Interaction): - """|coro| - - The callback associated with this UI item. - - This can be overridden by subclasses. - - Parameters - ---------- - interaction: :class:`.Interaction` - The interaction that triggered this UI item. - """ diff --git a/discord/ui/modal.py b/discord/ui/modal.py deleted file mode 100644 index b2c428d134..0000000000 --- a/discord/ui/modal.py +++ /dev/null @@ -1,331 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -import sys -import time -import traceback -from functools import partial -from itertools import groupby -from typing import TYPE_CHECKING, Any, Callable - -from .input_text import InputText - -__all__ = ( - "Modal", - "ModalStore", -) - - -if TYPE_CHECKING: - from ..interactions import Interaction - from ..state import ConnectionState - - -class Modal: - """Represents a UI Modal dialog. - - This object must be inherited to create a UI within Discord. - - .. versionadded:: 2.0 - - Parameters - ---------- - children: :class:`InputText` - The initial InputText fields that are displayed in the modal dialog. - title: :class:`str` - The title of the modal dialog. - Must be 45 characters or fewer. - custom_id: Optional[:class:`str`] - The ID of the modal dialog that gets received during an interaction. - Must be 100 characters or fewer. - timeout: Optional[:class:`float`] - Timeout in seconds from last interaction with the UI before no longer accepting input. - If ``None`` then there is no timeout. - """ - - def __init__( - self, - *children: InputText, - title: str, - custom_id: str | None = None, - timeout: float | None = None, - ) -> None: - self.timeout: float | None = timeout - if not isinstance(custom_id, str) and custom_id is not None: - raise TypeError(f"expected custom_id to be str, not {custom_id.__class__.__name__}") - self._custom_id: str | None = custom_id or os.urandom(16).hex() - if len(title) > 45: - raise ValueError("title must be 45 characters or fewer") - self._title = title - self._children: list[InputText] = list(children) - self._weights = _ModalWeights(self._children) - loop = asyncio.get_running_loop() - self._stopped: asyncio.Future[bool] = loop.create_future() - self.__cancel_callback: Callable[[Modal], None] | None = None - self.__timeout_expiry: float | None = None - self.__timeout_task: asyncio.Task[None] | None = None - self.loop = asyncio.get_event_loop() - - def _start_listening_from_store(self, store: ModalStore) -> None: - self.__cancel_callback = partial(store.remove_modal) - if self.timeout: - loop = asyncio.get_running_loop() - if self.__timeout_task is not None: - self.__timeout_task.cancel() - - self.__timeout_expiry = time.monotonic() + self.timeout - self.__timeout_task = loop.create_task(self.__timeout_task_impl()) - - async def __timeout_task_impl(self) -> None: - while True: - # Guard just in case someone changes the value of the timeout at runtime - if self.timeout is None: - return - - if self.__timeout_expiry is None: - return self._dispatch_timeout() - - # Check if we've elapsed our currently set timeout - now = time.monotonic() - if now >= self.__timeout_expiry: - return self._dispatch_timeout() - - # Wait N seconds to see if timeout data has been refreshed - await asyncio.sleep(self.__timeout_expiry - now) - - @property - def _expires_at(self) -> float | None: - if self.timeout: - return time.monotonic() + self.timeout - return None - - def _dispatch_timeout(self): - if self._stopped.done(): - return - - self._stopped.set_result(True) - self.loop.create_task(self.on_timeout(), name=f"discord-ui-view-timeout-{self.custom_id}") - - @property - def title(self) -> str: - """The title of the modal dialog.""" - return self._title - - @title.setter - def title(self, value: str): - if len(value) > 45: - raise ValueError("title must be 45 characters or fewer") - if not isinstance(value, str): - raise TypeError(f"expected title to be str, not {value.__class__.__name__}") - self._title = value - - @property - def children(self) -> list[InputText]: - """The child components associated with the modal dialog.""" - return self._children - - @children.setter - def children(self, value: list[InputText]): - for item in value: - if not isinstance(item, InputText): - raise TypeError(f"all Modal children must be InputText, not {item.__class__.__name__}") - self._weights = _ModalWeights(self._children) - self._children = value - - @property - def custom_id(self) -> str: - """The ID of the modal dialog that gets received during an interaction.""" - return self._custom_id - - @custom_id.setter - def custom_id(self, value: str): - if not isinstance(value, str): - raise TypeError(f"expected custom_id to be str, not {value.__class__.__name__}") - if len(value) > 100: - raise ValueError("custom_id must be 100 characters or fewer") - self._custom_id = value - - async def callback(self, interaction: Interaction): - """|coro| - - The coroutine that is called when the modal dialog is submitted. - Should be overridden to handle the values submitted by the user. - - Parameters - ---------- - interaction: :class:`~discord.Interaction` - The interaction that submitted the modal dialog. - """ - self.stop() - - def to_components(self) -> list[dict[str, Any]]: - def key(item: InputText) -> int: - return item._rendered_row or 0 - - children = sorted(self._children, key=key) - components: list[dict[str, Any]] = [] - for _, group in groupby(children, key=key): - children = [item.to_component_dict() for item in group] - if not children: - continue - - components.append( - { - "type": 1, - "components": children, - } - ) - - return components - - def add_item(self, item: InputText): - """Adds an InputText component to the modal dialog. - - Parameters - ---------- - item: :class:`InputText` - The item to add to the modal dialog - """ - - if len(self._children) > 5: - raise ValueError("You can only have up to 5 items in a modal dialog.") - - if not isinstance(item, InputText): - raise TypeError(f"expected InputText not {item.__class__!r}") - - self._weights.add_item(item) - self._children.append(item) - - def remove_item(self, item: InputText): - """Removes an InputText component from the modal dialog. - - Parameters - ---------- - item: :class:`InputText` - The item to remove from the modal dialog. - """ - try: - self._children.remove(item) - except ValueError: - pass - - def stop(self) -> None: - """Stops listening to interaction events from the modal dialog.""" - if not self._stopped.done(): - self._stopped.set_result(True) - self.__timeout_expiry = None - if self.__timeout_task is not None: - self.__timeout_task.cancel() - self.__timeout_task = None - - async def wait(self) -> bool: - """Waits for the modal dialog to be submitted.""" - return await self._stopped - - def to_dict(self): - return { - "title": self.title, - "custom_id": self.custom_id, - "components": self.to_components(), - } - - async def on_error(self, error: Exception, interaction: Interaction) -> None: - """|coro| - - A callback that is called when the modal's callback fails with an error. - - The default implementation prints the traceback to stderr. - - Parameters - ---------- - error: :class:`Exception` - The exception that was raised. - interaction: :class:`~discord.Interaction` - The interaction that led to the failure. - """ - print(f"Ignoring exception in modal {self}:", file=sys.stderr) - traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) - - async def on_timeout(self) -> None: - """|coro| - - A callback that is called when a modal's timeout elapses without being explicitly stopped. - """ - - -class _ModalWeights: - __slots__ = ("weights",) - - def __init__(self, children: list[InputText]): - self.weights: list[int] = [0, 0, 0, 0, 0] - - key = lambda i: sys.maxsize if i.row is None else i.row - children = sorted(children, key=key) - for row, group in groupby(children, key=key): - for item in group: - self.add_item(item) - - def find_open_space(self, item: InputText) -> int: - for index, weight in enumerate(self.weights): - if weight + item.width <= 5: - return index - - raise ValueError("could not find open space for item") - - def add_item(self, item: InputText) -> None: - if item.row is not None: - total = self.weights[item.row] + item.width - if total > 5: - raise ValueError(f"item would not fit at row {item.row} ({total} > 5 width)") - self.weights[item.row] = total - item._rendered_row = item.row - else: - index = self.find_open_space(item) - self.weights[index] += item.width - item._rendered_row = index - - def remove_item(self, item: InputText) -> None: - if item._rendered_row is not None: - self.weights[item._rendered_row] -= item.width - item._rendered_row = None - - def clear(self) -> None: - self.weights = [0, 0, 0, 0, 0] - - -class ModalStore: - def __init__(self, state: ConnectionState) -> None: - # (user_id, custom_id) : Modal - self._modals: dict[tuple[int, str], Modal] = {} - self._state: ConnectionState = state - - def add_modal(self, modal: Modal, user_id: int): - self._modals[(user_id, modal.custom_id)] = modal - modal._start_listening_from_store(self) - - def remove_modal(self, modal: Modal, user_id): - modal.stop() - self._modals.pop((user_id, modal.custom_id)) - - async def dispatch(self, user_id: int, custom_id: str, interaction: Interaction): - key = (user_id, custom_id) - value = self._modals.get(key) - if value is None: - return - - try: - components = [ - component - for parent_component in interaction.data["components"] - for component in parent_component["components"] - ] - for component in components: - for child in value.children: - if child.custom_id == component["custom_id"]: # type: ignore - child.refresh_state(component) - break - await value.callback(interaction) - self.remove_modal(value, user_id) - except Exception as e: - return await value.on_error(e, interaction) diff --git a/discord/ui/select.py b/discord/ui/select.py deleted file mode 100644 index 6e0e43751a..0000000000 --- a/discord/ui/select.py +++ /dev/null @@ -1,649 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -from __future__ import annotations - -import inspect -import os -from typing import TYPE_CHECKING, Callable, TypeVar - -from ..channel import _threaded_guild_channel_factory -from ..components import SelectMenu, SelectOption -from ..emoji import AppEmoji, GuildEmoji -from ..enums import ChannelType, ComponentType -from ..errors import InvalidArgument -from ..interactions import Interaction -from ..member import Member -from ..partial_emoji import PartialEmoji -from ..role import Role -from ..threads import Thread -from ..user import User -from ..utils import MISSING -from .item import Item, ItemCallbackType -from discord import utils - -__all__ = ( - "Select", - "select", - "string_select", - "user_select", - "role_select", - "mentionable_select", - "channel_select", -) - -if TYPE_CHECKING: - from ..abc import GuildChannel - from ..types.components import SelectMenu as SelectMenuPayload - from ..types.interactions import ComponentInteractionData - from .view import View - -S = TypeVar("S", bound="Select") -V = TypeVar("V", bound="View", covariant=True) - - -class Select(Item[V]): - """Represents a UI select menu. - - This is usually represented as a drop down menu. - - In order to get the selected items that the user has chosen, use :attr:`Select.values`. - - .. versionadded:: 2.0 - - .. versionchanged:: 2.3 - - Added support for :attr:`discord.ComponentType.string_select`, :attr:`discord.ComponentType.user_select`, - :attr:`discord.ComponentType.role_select`, :attr:`discord.ComponentType.mentionable_select`, - and :attr:`discord.ComponentType.channel_select`. - - Parameters - ---------- - select_type: :class:`discord.ComponentType` - The type of select to create. Must be one of - :attr:`discord.ComponentType.string_select`, :attr:`discord.ComponentType.user_select`, - :attr:`discord.ComponentType.role_select`, :attr:`discord.ComponentType.mentionable_select`, - or :attr:`discord.ComponentType.channel_select`. - custom_id: :class:`str` - The ID of the select menu that gets received during an interaction. - If not given then one is generated for you. - placeholder: Optional[:class:`str`] - The placeholder text that is shown if nothing is selected, if any. - min_values: :class:`int` - The minimum number of items that must be chosen for this select menu. - Defaults to 1 and must be between 1 and 25. - max_values: :class:`int` - The maximum number of items that must be chosen for this select menu. - Defaults to 1 and must be between 1 and 25. - options: List[:class:`discord.SelectOption`] - A list of options that can be selected in this menu. - Only valid for selects of type :attr:`discord.ComponentType.string_select`. - channel_types: List[:class:`discord.ChannelType`] - A list of channel types that can be selected in this menu. - Only valid for selects of type :attr:`discord.ComponentType.channel_select`. - disabled: :class:`bool` - Whether the select is disabled or not. - row: Optional[:class:`int`] - The relative row this select menu belongs to. A Discord component can only have 5 - rows. By default, items are arranged automatically into those 5 rows. If you'd - like to control the relative positioning of the row then passing an index is advised. - For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic - ordering. The row number must be between 0 and 4 (i.e. zero indexed). - """ - - __item_repr_attributes__: tuple[str, ...] = ( - "type", - "placeholder", - "min_values", - "max_values", - "options", - "channel_types", - "disabled", - ) - - def __init__( - self, - select_type: ComponentType = ComponentType.string_select, - *, - custom_id: str | None = None, - placeholder: str | None = None, - min_values: int = 1, - max_values: int = 1, - options: list[SelectOption] | None = None, - channel_types: list[ChannelType] | None = None, - disabled: bool = False, - row: int | None = None, - ) -> None: - if options and select_type is not ComponentType.string_select: - raise InvalidArgument("options parameter is only valid for string selects") - if channel_types and select_type is not ComponentType.channel_select: - raise InvalidArgument("channel_types parameter is only valid for channel selects") - super().__init__() - self._selected_values: list[str] = [] - self._interaction: Interaction | None = None - if min_values < 0 or min_values > 25: - raise ValueError("min_values must be between 0 and 25") - if max_values < 1 or max_values > 25: - raise ValueError("max_values must be between 1 and 25") - if placeholder and len(placeholder) > 150: - raise ValueError("placeholder must be 150 characters or fewer") - if not isinstance(custom_id, str) and custom_id is not None: - raise TypeError(f"expected custom_id to be str, not {custom_id.__class__.__name__}") - - self._provided_custom_id = custom_id is not None - custom_id = os.urandom(16).hex() if custom_id is None else custom_id - self._underlying: SelectMenu = SelectMenu._raw_construct( - custom_id=custom_id, - type=select_type, - placeholder=placeholder, - min_values=min_values, - max_values=max_values, - disabled=disabled, - options=options or [], - channel_types=channel_types or [], - ) - self.row = row - - @property - def custom_id(self) -> str: - """The ID of the select menu that gets received during an interaction.""" - return self._underlying.custom_id - - @custom_id.setter - def custom_id(self, value: str): - if not isinstance(value, str): - raise TypeError("custom_id must be None or str") - if len(value) > 100: - raise ValueError("custom_id must be 100 characters or fewer") - self._underlying.custom_id = value - - @property - def placeholder(self) -> str | None: - """The placeholder text that is shown if nothing is selected, if any.""" - return self._underlying.placeholder - - @placeholder.setter - def placeholder(self, value: str | None): - if value is not None and not isinstance(value, str): - raise TypeError("placeholder must be None or str") - if value and len(value) > 150: - raise ValueError("placeholder must be 150 characters or fewer") - - self._underlying.placeholder = value - - @property - def min_values(self) -> int: - """The minimum number of items that must be chosen for this select menu.""" - return self._underlying.min_values - - @min_values.setter - def min_values(self, value: int): - if value < 0 or value > 25: - raise ValueError("min_values must be between 0 and 25") - self._underlying.min_values = int(value) - - @property - def max_values(self) -> int: - """The maximum number of items that must be chosen for this select menu.""" - return self._underlying.max_values - - @max_values.setter - def max_values(self, value: int): - if value < 1 or value > 25: - raise ValueError("max_values must be between 1 and 25") - self._underlying.max_values = int(value) - - @property - def disabled(self) -> bool: - """Whether the select is disabled or not.""" - return self._underlying.disabled - - @disabled.setter - def disabled(self, value: bool): - self._underlying.disabled = bool(value) - - @property - def channel_types(self) -> list[ChannelType]: - """A list of channel types that can be selected in this menu.""" - return self._underlying.channel_types - - @channel_types.setter - def channel_types(self, value: list[ChannelType]): - if self._underlying.type is not ComponentType.channel_select: - raise InvalidArgument("channel_types can only be set on channel selects") - self._underlying.channel_types = value - - @property - def options(self) -> list[SelectOption]: - """A list of options that can be selected in this menu.""" - return self._underlying.options - - @options.setter - def options(self, value: list[SelectOption]): - if self._underlying.type is not ComponentType.string_select: - raise InvalidArgument("options can only be set on string selects") - if not isinstance(value, list): - raise TypeError("options must be a list of SelectOption") - if not all(isinstance(obj, SelectOption) for obj in value): - raise TypeError("all list items must subclass SelectOption") - - self._underlying.options = value - - def add_option( - self, - *, - label: str, - value: str | utils.Undefined = MISSING, - description: str | None = None, - emoji: str | GuildEmoji | AppEmoji | PartialEmoji | None = None, - default: bool = False, - ): - """Adds an option to the select menu. - - To append a pre-existing :class:`discord.SelectOption` use the - :meth:`append_option` method instead. - - Parameters - ---------- - label: :class:`str` - The label of the option. This is displayed to users. - Can only be up to 100 characters. - value: :class:`str` - The value of the option. This is not displayed to users. - If not given, defaults to the label. Can only be up to 100 characters. - description: Optional[:class:`str`] - An additional description of the option, if any. - Can only be up to 100 characters. - emoji: Optional[Union[:class:`str`, :class:`GuildEmoji`, :class:`AppEmoji`, :class:`.PartialEmoji`]] - The emoji of the option, if available. This can either be a string representing - the custom or unicode emoji or an instance of :class:`.PartialEmoji`, :class:`GuildEmoji`, or :class:`AppEmoji`. - default: :class:`bool` - Whether this option is selected by default. - - Raises - ------ - ValueError - The number of options exceeds 25. - """ - if self._underlying.type is not ComponentType.string_select: - raise Exception("options can only be set on string selects") - - option = SelectOption( - label=label, - value=value, - description=description, - emoji=emoji, - default=default, - ) - - self.append_option(option) - - def append_option(self, option: SelectOption): - """Appends an option to the select menu. - - Parameters - ---------- - option: :class:`discord.SelectOption` - The option to append to the select menu. - - Raises - ------ - ValueError - The number of options exceeds 25. - """ - if self._underlying.type is not ComponentType.string_select: - raise Exception("options can only be set on string selects") - - if len(self._underlying.options) > 25: - raise ValueError("maximum number of options already provided") - - self._underlying.options.append(option) - - @property - def values( - self, - ) -> list[str] | list[Member | User] | list[Role] | list[Member | User | Role] | list[GuildChannel | Thread]: - """List[:class:`str`] | List[:class:`discord.Member` | :class:`discord.User`]] | List[:class:`discord.Role`]] | - List[:class:`discord.Member` | :class:`discord.User` | :class:`discord.Role`]] | List[:class:`discord.abc.GuildChannel`] | None: - A list of values that have been selected by the user. This will be ``None`` if the select has not been interacted with yet. - """ - if self._interaction is None: - # The select has not been interacted with yet - return None - select_type = self._underlying.type - if select_type is ComponentType.string_select: - return self._selected_values - resolved = [] - selected_values = list(self._selected_values) - state = self._interaction._state - guild = self._interaction.guild - resolved_data = self._interaction.data.get("resolved", {}) - if select_type is ComponentType.channel_select: - for channel_id, _data in resolved_data.get("channels", {}).items(): - if channel_id not in selected_values: - continue - if int(channel_id) in guild._channels or int(channel_id) in guild._threads: - result = guild.get_channel_or_thread(int(channel_id)) - _data["_invoke_flag"] = True - (result._update(_data) if isinstance(result, Thread) else result._update(guild, _data)) - else: - # NOTE: - # This is a fallback in case the channel/thread is not found in the - # guild's channels/threads. For channels, if this fallback occurs, at the very minimum, - # permissions will be incorrect due to a lack of permission_overwrite data. - # For threads, if this fallback occurs, info like thread owner id, message count, - # flags, and more will be missing due to a lack of data sent by Discord. - obj_type = _threaded_guild_channel_factory(_data["type"])[0] - result = obj_type(state=state, data=_data, guild=guild) - resolved.append(result) - elif select_type in ( - ComponentType.user_select, - ComponentType.mentionable_select, - ): - cache_flag = state.member_cache_flags.interaction - resolved_user_data = resolved_data.get("users", {}) - resolved_member_data = resolved_data.get("members", {}) - for _id in selected_values: - if (_data := resolved_user_data.get(_id)) is not None: - if (_member_data := resolved_member_data.get(_id)) is not None: - member = dict(_member_data) - member["user"] = _data - _data = member - result = guild._get_and_update_member(_data, int(_id), cache_flag) - else: - result = User(state=state, data=_data) - resolved.append(result) - if select_type in (ComponentType.role_select, ComponentType.mentionable_select): - for role_id, _data in resolved_data.get("roles", {}).items(): - if role_id not in selected_values: - continue - resolved.append(Role(guild=guild, state=state, data=_data)) - return resolved - - @property - def width(self) -> int: - return 5 - - def to_component_dict(self) -> SelectMenuPayload: - return self._underlying.to_dict() - - def refresh_component(self, component: SelectMenu) -> None: - self._underlying = component - - def refresh_state(self, interaction: Interaction) -> None: - data: ComponentInteractionData = interaction.data # type: ignore - self._selected_values = data.get("values", []) - self._interaction = interaction - - @classmethod - def from_component(cls: type[S], component: SelectMenu) -> S: - return cls( - select_type=component.type, - custom_id=component.custom_id, - placeholder=component.placeholder, - min_values=component.min_values, - max_values=component.max_values, - options=component.options, - channel_types=component.channel_types, - disabled=component.disabled, - row=None, - ) - - @property - def type(self) -> ComponentType: - return self._underlying.type - - def is_dispatchable(self) -> bool: - return True - - -_select_types = ( - ComponentType.string_select, - ComponentType.user_select, - ComponentType.role_select, - ComponentType.mentionable_select, - ComponentType.channel_select, -) - - -def select( - select_type: ComponentType = ComponentType.string_select, - *, - placeholder: str | None = None, - custom_id: str | None = None, - min_values: int = 1, - max_values: int = 1, - options: list[SelectOption] | utils.Undefined = MISSING, - channel_types: list[ChannelType] | utils.Undefined = MISSING, - disabled: bool = False, - row: int | None = None, -) -> Callable[[ItemCallbackType], ItemCallbackType]: - """A decorator that attaches a select menu to a component. - - The function being decorated should have three parameters, ``self`` representing - the :class:`discord.ui.View`, the :class:`discord.ui.Select` being pressed and - the :class:`discord.Interaction` you receive. - - In order to get the selected items that the user has chosen within the callback - use :attr:`Select.values`. - - .. versionchanged:: 2.3 - - Creating select menus of different types is now supported. - - Parameters - ---------- - select_type: :class:`discord.ComponentType` - The type of select to create. Must be one of - :attr:`discord.ComponentType.string_select`, :attr:`discord.ComponentType.user_select`, - :attr:`discord.ComponentType.role_select`, :attr:`discord.ComponentType.mentionable_select`, - or :attr:`discord.ComponentType.channel_select`. - placeholder: Optional[:class:`str`] - The placeholder text that is shown if nothing is selected, if any. - custom_id: :class:`str` - The ID of the select menu that gets received during an interaction. - It is recommended not to set this parameter to prevent conflicts. - row: Optional[:class:`int`] - The relative row this select menu belongs to. A Discord component can only have 5 - rows. By default, items are arranged automatically into those 5 rows. If you'd - like to control the relative positioning of the row then passing an index is advised. - For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic - ordering. The row number must be between 0 and 4 (i.e. zero indexed). - min_values: :class:`int` - The minimum number of items that must be chosen for this select menu. - Defaults to 1 and must be between 0 and 25. - max_values: :class:`int` - The maximum number of items that must be chosen for this select menu. - Defaults to 1 and must be between 1 and 25. - options: List[:class:`discord.SelectOption`] - A list of options that can be selected in this menu. - Only valid for the :attr:`discord.ComponentType.string_select` type. - channel_types: List[:class:`discord.ChannelType`] - The channel types that should be selectable. - Only valid for the :attr:`discord.ComponentType.channel_select` type. - Defaults to all channel types. - disabled: :class:`bool` - Whether the select is disabled or not. Defaults to ``False``. - """ - if select_type not in _select_types: - raise ValueError("select_type must be one of " + ", ".join([i.name for i in _select_types])) - - if options is not MISSING and select_type not in ( - ComponentType.select, - ComponentType.string_select, - ): - raise TypeError("options may only be specified for string selects") - - if channel_types is not MISSING and select_type is not ComponentType.channel_select: - raise TypeError("channel_types may only be specified for channel selects") - - def decorator(func: ItemCallbackType) -> ItemCallbackType: - if not inspect.iscoroutinefunction(func): - raise TypeError("select function must be a coroutine function") - - model_kwargs = { - "select_type": select_type, - "placeholder": placeholder, - "custom_id": custom_id, - "row": row, - "min_values": min_values, - "max_values": max_values, - "disabled": disabled, - } - if options: - model_kwargs["options"] = options - if channel_types: - model_kwargs["channel_types"] = channel_types - - func.__discord_ui_model_type__ = Select - func.__discord_ui_model_kwargs__ = model_kwargs - - return func - - return decorator - - -def string_select( - *, - placeholder: str | None = None, - custom_id: str | None = None, - min_values: int = 1, - max_values: int = 1, - options: list[SelectOption] | utils.Undefined = MISSING, - disabled: bool = False, - row: int | None = None, -) -> Callable[[ItemCallbackType], ItemCallbackType]: - """A shortcut for :meth:`discord.ui.select` with select type :attr:`discord.ComponentType.string_select`. - - .. versionadded:: 2.3 - """ - return select( - ComponentType.string_select, - placeholder=placeholder, - custom_id=custom_id, - min_values=min_values, - max_values=max_values, - options=options, - disabled=disabled, - row=row, - ) - - -def user_select( - *, - placeholder: str | None = None, - custom_id: str | None = None, - min_values: int = 1, - max_values: int = 1, - disabled: bool = False, - row: int | None = None, -) -> Callable[[ItemCallbackType], ItemCallbackType]: - """A shortcut for :meth:`discord.ui.select` with select type :attr:`discord.ComponentType.user_select`. - - .. versionadded:: 2.3 - """ - return select( - ComponentType.user_select, - placeholder=placeholder, - custom_id=custom_id, - min_values=min_values, - max_values=max_values, - disabled=disabled, - row=row, - ) - - -def role_select( - *, - placeholder: str | None = None, - custom_id: str | None = None, - min_values: int = 1, - max_values: int = 1, - disabled: bool = False, - row: int | None = None, -) -> Callable[[ItemCallbackType], ItemCallbackType]: - """A shortcut for :meth:`discord.ui.select` with select type :attr:`discord.ComponentType.role_select`. - - .. versionadded:: 2.3 - """ - return select( - ComponentType.role_select, - placeholder=placeholder, - custom_id=custom_id, - min_values=min_values, - max_values=max_values, - disabled=disabled, - row=row, - ) - - -def mentionable_select( - *, - placeholder: str | None = None, - custom_id: str | None = None, - min_values: int = 1, - max_values: int = 1, - disabled: bool = False, - row: int | None = None, -) -> Callable[[ItemCallbackType], ItemCallbackType]: - """A shortcut for :meth:`discord.ui.select` with select type :attr:`discord.ComponentType.mentionable_select`. - - .. versionadded:: 2.3 - """ - return select( - ComponentType.mentionable_select, - placeholder=placeholder, - custom_id=custom_id, - min_values=min_values, - max_values=max_values, - disabled=disabled, - row=row, - ) - - -def channel_select( - *, - placeholder: str | None = None, - custom_id: str | None = None, - min_values: int = 1, - max_values: int = 1, - disabled: bool = False, - channel_types: list[ChannelType] | utils.Undefined = MISSING, - row: int | None = None, -) -> Callable[[ItemCallbackType], ItemCallbackType]: - """A shortcut for :meth:`discord.ui.select` with select type :attr:`discord.ComponentType.channel_select`. - - .. versionadded:: 2.3 - """ - return select( - ComponentType.channel_select, - placeholder=placeholder, - custom_id=custom_id, - min_values=min_values, - max_values=max_values, - disabled=disabled, - channel_types=channel_types, - row=row, - ) diff --git a/discord/ui/view.py b/discord/ui/view.py deleted file mode 100644 index 1b45077fb8..0000000000 --- a/discord/ui/view.py +++ /dev/null @@ -1,621 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -from __future__ import annotations - -import asyncio -import os -import sys -import time -import traceback -from functools import partial -from itertools import groupby -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterator, Sequence, TypeVar - -from ..components import ActionRow as ActionRowComponent -from ..components import Button as ButtonComponent -from ..components import Component -from ..components import SelectMenu as SelectComponent -from ..components import _component_factory -from ..utils import get -from .item import Item, ItemCallbackType - -__all__ = ("View",) - - -if TYPE_CHECKING: - from ..interactions import Interaction, InteractionMessage - from ..message import Message - from ..state import ConnectionState - from ..types.components import Component as ComponentPayload - -V = TypeVar("V", bound="View", covariant=True) - - -def _walk_all_components(components: list[Component]) -> Iterator[Component]: - for item in components: - if isinstance(item, ActionRowComponent): - yield from item.children - else: - yield item - - -def _component_to_item(component: Component) -> Item[V]: - if isinstance(component, ButtonComponent): - from .button import Button # noqa: PLC0415 - - return Button.from_component(component) - if isinstance(component, SelectComponent): - from .select import Select # noqa: PLC0415 - - return Select.from_component(component) - return Item.from_component(component) - - -class _ViewWeights: - __slots__ = ("weights",) - - def __init__(self, children: list[Item[V]]): - self.weights: list[int] = [0, 0, 0, 0, 0] - - key = lambda i: sys.maxsize if i.row is None else i.row - children = sorted(children, key=key) - for row, group in groupby(children, key=key): - for item in group: - self.add_item(item) - - def find_open_space(self, item: Item[V]) -> int: - for index, weight in enumerate(self.weights): - if weight + item.width <= 5: - return index - - raise ValueError("could not find open space for item") - - def add_item(self, item: Item[V]) -> None: - if item.row is not None: - total = self.weights[item.row] + item.width - if total > 5: - raise ValueError(f"item would not fit at row {item.row} ({total} > 5 width)") - self.weights[item.row] = total - item._rendered_row = item.row - else: - index = self.find_open_space(item) - self.weights[index] += item.width - item._rendered_row = index - - def remove_item(self, item: Item[V]) -> None: - if item._rendered_row is not None: - self.weights[item._rendered_row] -= item.width - item._rendered_row = None - - def clear(self) -> None: - self.weights = [0, 0, 0, 0, 0] - - -class View: - """Represents a UI view. - - This object must be inherited to create a UI within Discord. - - .. versionadded:: 2.0 - - Parameters - ---------- - *items: :class:`Item` - The initial items attached to this view. - timeout: Optional[:class:`float`] - Timeout in seconds from last interaction with the UI before no longer accepting input. Defaults to 180.0. - If ``None`` then there is no timeout. - - Attributes - ---------- - timeout: Optional[:class:`float`] - Timeout from last interaction with the UI before no longer accepting input. - If ``None`` then there is no timeout. - children: List[:class:`Item`] - The list of children attached to this view. - disable_on_timeout: :class:`bool` - Whether to disable the view when the timeout is reached. Defaults to ``False``. - message: Optional[:class:`.Message`] - The message that this view is attached to. - If ``None`` then the view has not been sent with a message. - parent: Optional[:class:`.Interaction`] - The parent interaction which this view was sent from. - If ``None`` then the view was not sent using :meth:`InteractionResponse.send_message`. - """ - - __discord_ui_view__: ClassVar[bool] = True - __view_children_items__: ClassVar[list[ItemCallbackType]] = [] - - def __init_subclass__(cls) -> None: - children: list[ItemCallbackType] = [] - for base in reversed(cls.__mro__): - for member in base.__dict__.values(): - if hasattr(member, "__discord_ui_model_type__"): - children.append(member) - - if len(children) > 25: - raise TypeError("View cannot have more than 25 children") - - cls.__view_children_items__ = children - - def __init__( - self, - *items: Item[V], - timeout: float | None = 180.0, - disable_on_timeout: bool = False, - ): - self.timeout = timeout - self.disable_on_timeout = disable_on_timeout - self.children: list[Item[V]] = [] - for func in self.__view_children_items__: - item: Item[V] = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__) - item.callback = partial(func, self, item) - item._view = self - setattr(self, func.__name__, item) - self.children.append(item) - - self.__weights = _ViewWeights(self.children) - for item in items: - self.add_item(item) - - loop = asyncio.get_running_loop() - self.id: str = os.urandom(16).hex() - self.__cancel_callback: Callable[[View], None] | None = None - self.__timeout_expiry: float | None = None - self.__timeout_task: asyncio.Task[None] | None = None - self.__stopped: asyncio.Future[bool] = loop.create_future() - self._message: Message | InteractionMessage | None = None - self.parent: Interaction | None = None - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>" - - async def __timeout_task_impl(self) -> None: - while True: - # Guard just in case someone changes the value of the timeout at runtime - if self.timeout is None: - return - - if self.__timeout_expiry is None: - return self._dispatch_timeout() - - # Check if we've elapsed our currently set timeout - now = time.monotonic() - if now >= self.__timeout_expiry: - return self._dispatch_timeout() - - # Wait N seconds to see if timeout data has been refreshed - await asyncio.sleep(self.__timeout_expiry - now) - - def to_components(self) -> list[dict[str, Any]]: - def key(item: Item[V]) -> int: - return item._rendered_row or 0 - - children = sorted(self.children, key=key) - components: list[dict[str, Any]] = [] - for _, group in groupby(children, key=key): - children = [item.to_component_dict() for item in group] - if not children: - continue - - components.append( - { - "type": 1, - "components": children, - } - ) - - return components - - @classmethod - def from_message(cls, message: Message, /, *, timeout: float | None = 180.0) -> View: - """Converts a message's components into a :class:`View`. - - The :attr:`.Message.components` of a message are read-only - and separate types from those in the ``discord.ui`` namespace. - In order to modify and edit message components they must be - converted into a :class:`View` first. - - Parameters - ---------- - message: :class:`.Message` - The message with components to convert into a view. - timeout: Optional[:class:`float`] - The timeout of the converted view. - - Returns - ------- - :class:`View` - The converted view. This always returns a :class:`View` and not - one of its subclasses. - """ - view = View(timeout=timeout) - for component in _walk_all_components(message.components): - view.add_item(_component_to_item(component)) - return view - - @property - def _expires_at(self) -> float | None: - if self.timeout: - return time.monotonic() + self.timeout - return None - - def add_item(self, item: Item[V]) -> None: - """Adds an item to the view. - - Parameters - ---------- - item: :class:`Item` - The item to add to the view. - - Raises - ------ - TypeError - An :class:`Item` was not passed. - ValueError - Maximum number of children has been exceeded (25) - or the row the item is trying to be added to is full. - """ - - if len(self.children) > 25: - raise ValueError("maximum number of children exceeded") - - if not isinstance(item, Item): - raise TypeError(f"expected Item not {item.__class__!r}") - - self.__weights.add_item(item) - - item._view = self - self.children.append(item) - - def remove_item(self, item: Item[V]) -> None: - """Removes an item from the view. - - Parameters - ---------- - item: :class:`Item` - The item to remove from the view. - """ - - try: - self.children.remove(item) - except ValueError: - pass - else: - self.__weights.remove_item(item) - - def clear_items(self) -> None: - """Removes all items from the view.""" - self.children.clear() - self.__weights.clear() - - def get_item(self, custom_id: str) -> Item[V] | None: - """Get an item from the view with the given custom ID. Alias for `utils.get(view.children, custom_id=custom_id)`. - - Parameters - ---------- - custom_id: :class:`str` - The custom_id of the item to get - - Returns - ------- - Optional[:class:`Item`] - The item with the matching ``custom_id`` if it exists. - """ - return get(self.children, custom_id=custom_id) - - async def interaction_check(self, interaction: Interaction) -> bool: - """|coro| - - A callback that is called when an interaction happens within the view - that checks whether the view should process item callbacks for the interaction. - - This is useful to override if, for example, you want to ensure that the - interaction author is a given user. - - The default implementation of this returns ``True``. - - If this returns ``False``, :meth:`on_check_failure` is called. - - .. note:: - - If an exception occurs within the body then the check - is considered a failure and :meth:`on_error` is called. - - Parameters - ---------- - interaction: :class:`~discord.Interaction` - The interaction that occurred. - - Returns - ------- - :class:`bool` - Whether the view children's callbacks should be called. - """ - return True - - async def on_timeout(self) -> None: - """|coro| - - A callback that is called when a view's timeout elapses without being explicitly stopped. - """ - if self.disable_on_timeout: - self.disable_all_items() - - if not self._message or self._message.flags.ephemeral: - message = self.parent - else: - message = self.message - - if message: - m = await message.edit(view=self) - if m: - self._message = m - - async def on_check_failure(self, interaction: Interaction) -> None: - """|coro| - A callback that is called when a :meth:`View.interaction_check` returns ``False``. - This can be used to send a response when a check failure occurs. - - Parameters - ---------- - interaction: :class:`~discord.Interaction` - The interaction that occurred. - """ - - async def on_error(self, error: Exception, item: Item[V], interaction: Interaction) -> None: - """|coro| - - A callback that is called when an item's callback or :meth:`interaction_check` - fails with an error. - - The default implementation prints the traceback to stderr. - - Parameters - ---------- - error: :class:`Exception` - The exception that was raised. - item: :class:`Item` - The item that failed the dispatch. - interaction: :class:`~discord.Interaction` - The interaction that led to the failure. - """ - print(f"Ignoring exception in view {self} for item {item}:", file=sys.stderr) - traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) - - async def _scheduled_task(self, item: Item[V], interaction: Interaction): - try: - if self.timeout: - self.__timeout_expiry = time.monotonic() + self.timeout - - allow = await self.interaction_check(interaction) - if not allow: - return await self.on_check_failure(interaction) - - await item.callback(interaction) - except Exception as e: - return await self.on_error(e, item, interaction) - - def _start_listening_from_store(self, store: ViewStore) -> None: - self.__cancel_callback = partial(store.remove_view) - if self.timeout: - loop = asyncio.get_running_loop() - if self.__timeout_task is not None: - self.__timeout_task.cancel() - - self.__timeout_expiry = time.monotonic() + self.timeout - self.__timeout_task = loop.create_task(self.__timeout_task_impl()) - - def _dispatch_timeout(self): - if self.__stopped.done(): - return - - self.__stopped.set_result(True) - asyncio.create_task(self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}") - - def _dispatch_item(self, item: Item[V], interaction: Interaction): - if self.__stopped.done(): - return - - if interaction.message: - self.message = interaction.message - - asyncio.create_task( - self._scheduled_task(item, interaction), - name=f"discord-ui-view-dispatch-{self.id}", - ) - - def refresh(self, components: list[Component]): - # This is pretty hacky at the moment - old_state: dict[tuple[int, str], Item[V]] = { - (item.type.value, item.custom_id): item - for item in self.children - if item.is_dispatchable() # type: ignore - } - children: list[Item[V]] = [item for item in self.children if not item.is_dispatchable()] - for component in _walk_all_components(components): - try: - older = old_state[(component.type.value, component.custom_id)] # type: ignore - except (KeyError, AttributeError): - item = _component_to_item(component) - if not item.is_dispatchable(): - continue - children.append(item) - else: - older.refresh_component(component) - children.append(older) - - self.children = children - - def stop(self) -> None: - """Stops listening to interaction events from this view. - - This operation cannot be undone. - """ - if not self.__stopped.done(): - self.__stopped.set_result(False) - - self.__timeout_expiry = None - if self.__timeout_task is not None: - self.__timeout_task.cancel() - self.__timeout_task = None - - if self.__cancel_callback: - self.__cancel_callback(self) - self.__cancel_callback = None - - def is_finished(self) -> bool: - """Whether the view has finished interacting.""" - return self.__stopped.done() - - def is_dispatching(self) -> bool: - """Whether the view has been added for dispatching purposes.""" - return self.__cancel_callback is not None - - def is_persistent(self) -> bool: - """Whether the view is set up as persistent. - - A persistent view has all their components with a set ``custom_id`` and - a :attr:`timeout` set to ``None``. - """ - return self.timeout is None and all(item.is_persistent() for item in self.children) - - async def wait(self) -> bool: - """Waits until the view has finished interacting. - - A view is considered finished when :meth:`stop` - is called, or it times out. - - Returns - ------- - :class:`bool` - If ``True``, then the view timed out. If ``False`` then - the view finished normally. - """ - return await self.__stopped - - def disable_all_items(self, *, exclusions: list[Item[V]] | None = None) -> None: - """ - Disables all items in the view. - - Parameters - ---------- - exclusions: Optional[List[:class:`Item`]] - A list of items in `self.children` to not disable from the view. - """ - for child in self.children: - if exclusions is None or child not in exclusions: - child.disabled = True - - def enable_all_items(self, *, exclusions: list[Item[V]] | None = None) -> None: - """ - Enables all items in the view. - - Parameters - ---------- - exclusions: Optional[List[:class:`Item`]] - A list of items in `self.children` to not enable from the view. - """ - for child in self.children: - if exclusions is None or child not in exclusions: - child.disabled = False - - @property - def message(self): - return self._message - - @message.setter - def message(self, value): - self._message = value - - -class ViewStore: - def __init__(self, state: ConnectionState): - # (component_type, message_id, custom_id): (View, Item) - self._views: dict[tuple[int, int | None, str], tuple[View, Item[V]]] = {} - # message_id: View - self._synced_message_views: dict[int, View] = {} - self._state: ConnectionState = state - - @property - def persistent_views(self) -> Sequence[View]: - views = {view.id: view for (_, (view, _)) in self._views.items() if view.is_persistent()} - return list(views.values()) - - def __verify_integrity(self): - to_remove: list[tuple[int, int | None, str]] = [] - for k, (view, _) in self._views.items(): - if view.is_finished(): - to_remove.append(k) - - for k in to_remove: - del self._views[k] - - def add_view(self, view: View, message_id: int | None = None): - self.__verify_integrity() - - view._start_listening_from_store(self) - for item in view.children: - if item.is_dispatchable(): - self._views[(item.type.value, message_id, item.custom_id)] = (view, item) # type: ignore - - if message_id is not None: - self._synced_message_views[message_id] = view - - def remove_view(self, view: View): - for item in view.children: - if item.is_dispatchable(): - self._views.pop((item.type.value, item.custom_id), None) # type: ignore - - for key, value in self._synced_message_views.items(): - if value.id == view.id: - del self._synced_message_views[key] - break - - def dispatch(self, component_type: int, custom_id: str, interaction: Interaction): - self.__verify_integrity() - message_id: int | None = interaction.message and interaction.message.id - key = (component_type, message_id, custom_id) - # Fallback to None message_id searches in case a persistent view - # was added without an associated message_id - value = self._views.get(key) or self._views.get((component_type, None, custom_id)) - if value is None: - return - - view, item = value - item.refresh_state(interaction) - view._dispatch_item(item, interaction) - - def is_message_tracked(self, message_id: int): - return message_id in self._synced_message_views - - def remove_message_tracking(self, message_id: int) -> View | None: - return self._synced_message_views.pop(message_id, None) - - def update_from_message(self, message_id: int, components: list[ComponentPayload]): - # pre-req: is_message_tracked == true - view = self._synced_message_views[message_id] - view.refresh([_component_factory(d) for d in components]) diff --git a/discord/user.py b/discord/user.py index 10d2889048..e9aa9a465e 100644 --- a/discord/user.py +++ b/discord/user.py @@ -34,7 +34,8 @@ from .flags import PublicUserFlags from .iterators import EntitlementIterator from .monetization import Entitlement -from .utils import MISSING, Undefined, _bytes_to_base64_data, snowflake_time +from .utils import MISSING, Undefined, snowflake_time +from .utils.private import bytes_to_base64_data if TYPE_CHECKING: from datetime import datetime @@ -470,12 +471,12 @@ async def edit( if avatar is None: payload["avatar"] = None elif avatar is not MISSING: - payload["avatar"] = _bytes_to_base64_data(avatar) + payload["avatar"] = bytes_to_base64_data(avatar) if banner is None: payload["banner"] = None elif banner is not MISSING: - payload["banner"] = _bytes_to_base64_data(banner) + payload["banner"] = bytes_to_base64_data(banner) data: UserPayload = await self._state.http.edit_profile(payload) return ClientUser(state=self._state, data=data) diff --git a/discord/utils.py b/discord/utils.py deleted file mode 100644 index 18a53a9ed0..0000000000 --- a/discord/utils.py +++ /dev/null @@ -1,1403 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -from __future__ import annotations - -import array -import asyncio -import collections.abc -import datetime -from enum import Enum, auto -import functools -import itertools -import json -import re -import sys -import types -import unicodedata -import warnings -from base64 import b64encode -from bisect import bisect_left -from inspect import isawaitable as _isawaitable -from inspect import signature as _signature -from operator import attrgetter -from typing import ( - TYPE_CHECKING, - Any, - AsyncIterator, - Awaitable, - Callable, - Coroutine, - ForwardRef, - Generic, - Iterable, - Iterator, - Literal, - Mapping, - Protocol, - Sequence, - TypeVar, - Union, - overload, -) - -from .errors import HTTPException, InvalidArgument - -try: - import msgspec -except ModuleNotFoundError: - HAS_MSGSPEC = False -else: - HAS_MSGSPEC = True - - -__all__ = ( - "parse_time", - "warn_deprecated", - "deprecated", - "oauth_url", - "snowflake_time", - "time_snowflake", - "find", - "get", - "get_or_fetch", - "sleep_until", - "utcnow", - "resolve_invite", - "resolve_template", - "remove_markdown", - "escape_markdown", - "escape_mentions", - "raw_mentions", - "raw_channel_mentions", - "raw_role_mentions", - "as_chunks", - "format_dt", - "generate_snowflake", - "basic_autocomplete", - "filter_params", -) - -DISCORD_EPOCH = 1420070400000 - - -class Undefined(Enum): - MISSING = auto() - - def __bool__(self) -> Literal[False]: - return False - - -MISSING: Literal[Undefined.MISSING] = Undefined.MISSING - - -class _cached_property: - def __init__(self, function): - self.function = function - self.__doc__ = getattr(function, "__doc__") - - def __get__(self, instance, owner): - if instance is None: - return self - - value = self.function(instance) - setattr(instance, self.function.__name__, value) - - return value - - -if TYPE_CHECKING: - from typing_extensions import ParamSpec - - from .abc import Snowflake - from .commands.context import AutocompleteContext - from .commands.options import OptionChoice - from .invite import Invite - from .permissions import Permissions - from .template import Template - - class _RequestLike(Protocol): - headers: Mapping[str, Any] - - cached_property = property - - P = ParamSpec("P") - -else: - cached_property = _cached_property - AutocompleteContext = Any - OptionChoice = Any - - -T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) -_Iter = Union[Iterator[T], AsyncIterator[T]] - - -class CachedSlotProperty(Generic[T, T_co]): - def __init__(self, name: str, function: Callable[[T], T_co]) -> None: - self.name = name - self.function = function - self.__doc__ = getattr(function, "__doc__") - - @overload - def __get__(self, instance: None, owner: type[T]) -> CachedSlotProperty[T, T_co]: ... - - @overload - def __get__(self, instance: T, owner: type[T]) -> T_co: ... - - def __get__(self, instance: T | None, owner: type[T]) -> Any: - if instance is None: - return self - - try: - return getattr(instance, self.name) - except AttributeError: - value = self.function(instance) - setattr(instance, self.name, value) - return value - - -class classproperty(Generic[T_co]): - def __init__(self, fget: Callable[[Any], T_co]) -> None: - self.fget = fget - - def __get__(self, instance: Any | None, owner: type[Any]) -> T_co: - return self.fget(owner) - - def __set__(self, instance, value) -> None: - raise AttributeError("cannot set attribute") - - -def cached_slot_property( - name: str, -) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]: - def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]: - return CachedSlotProperty(name, func) - - return decorator - - -class SequenceProxy(Generic[T_co], collections.abc.Sequence): - """Read-only proxy of a Sequence.""" - - def __init__(self, proxied: Sequence[T_co]): - self.__proxied = proxied - - def __getitem__(self, idx: int) -> T_co: - return self.__proxied[idx] - - def __len__(self) -> int: - return len(self.__proxied) - - def __contains__(self, item: Any) -> bool: - return item in self.__proxied - - def __iter__(self) -> Iterator[T_co]: - return iter(self.__proxied) - - def __reversed__(self) -> Iterator[T_co]: - return reversed(self.__proxied) - - def index(self, value: Any, *args, **kwargs) -> int: - return self.__proxied.index(value, *args, **kwargs) - - def count(self, value: Any) -> int: - return self.__proxied.count(value) - - -def delay_task(delay: float, func: Coroutine): - async def inner_call(): - await asyncio.sleep(delay) - try: - await func - except HTTPException: - pass - - asyncio.create_task(inner_call()) - - -@overload -def parse_time(timestamp: None) -> None: ... - - -@overload -def parse_time(timestamp: str) -> datetime.datetime: ... - - -@overload -def parse_time(timestamp: str | None) -> datetime.datetime | None: ... - - -def parse_time(timestamp: str | None) -> datetime.datetime | None: - """A helper function to convert an ISO 8601 timestamp to a datetime object. - - Parameters - ---------- - timestamp: Optional[:class:`str`] - The timestamp to convert. - - Returns - ------- - Optional[:class:`datetime.datetime`] - The converted datetime object. - """ - if timestamp: - return datetime.datetime.fromisoformat(timestamp) - return None - - -def copy_doc(original: Callable) -> Callable[[T], T]: - def decorator(overridden: T) -> T: - overridden.__doc__ = original.__doc__ - overridden.__signature__ = _signature(original) # type: ignore - return overridden - - return decorator - - -def warn_deprecated( - name: str, - instead: str | None = None, - since: str | None = None, - removed: str | None = None, - reference: str | None = None, - stacklevel: int = 3, -) -> None: - """Warn about a deprecated function, with the ability to specify details about the deprecation. Emits a - DeprecationWarning. - - Parameters - ---------- - name: str - The name of the deprecated function. - instead: Optional[:class:`str`] - A recommended alternative to the function. - since: Optional[:class:`str`] - The version in which the function was deprecated. This should be in the format ``major.minor(.patch)``, where - the patch version is optional. - removed: Optional[:class:`str`] - The version in which the function is planned to be removed. This should be in the format - ``major.minor(.patch)``, where the patch version is optional. - reference: Optional[:class:`str`] - A reference that explains the deprecation, typically a URL to a page such as a changelog entry or a GitHub - issue/PR. - stacklevel: :class:`int` - The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3. - """ - warnings.simplefilter("always", DeprecationWarning) # turn off filter - message = f"{name} is deprecated" - if since: - message += f" since version {since}" - if removed: - message += f" and will be removed in version {removed}" - if instead: - message += f", consider using {instead} instead" - message += "." - if reference: - message += f" See {reference} for more information." - - warnings.warn(message, stacklevel=stacklevel, category=DeprecationWarning) - warnings.simplefilter("default", DeprecationWarning) # reset filter - - -def deprecated( - instead: str | None = None, - since: str | None = None, - removed: str | None = None, - reference: str | None = None, - stacklevel: int = 3, - *, - use_qualname: bool = True, -) -> Callable[[Callable[[P], T]], Callable[[P], T]]: - """A decorator implementation of :func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when - the decorated function is called. - - Parameters - ---------- - instead: Optional[:class:`str`] - A recommended alternative to the function. - since: Optional[:class:`str`] - The version in which the function was deprecated. This should be in the format ``major.minor(.patch)``, where - the patch version is optional. - removed: Optional[:class:`str`] - The version in which the function is planned to be removed. This should be in the format - ``major.minor(.patch)``, where the patch version is optional. - reference: Optional[:class:`str`] - A reference that explains the deprecation, typically a URL to a page such as a changelog entry or a GitHub - issue/PR. - stacklevel: :class:`int` - The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3. - use_qualname: :class:`bool` - Whether to use the qualified name of the function in the deprecation warning. If ``False``, the short name of - the function will be used instead. For example, __qualname__ will display as ``Client.login`` while __name__ - will display as ``login``. Defaults to ``True``. - """ - - def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]: - @functools.wraps(func) - def decorated(*args: P.args, **kwargs: P.kwargs) -> T: - warn_deprecated( - name=func.__qualname__ if use_qualname else func.__name__, - instead=instead, - since=since, - removed=removed, - reference=reference, - stacklevel=stacklevel, - ) - return func(*args, **kwargs) - - return decorated - - return actual_decorator - - -def oauth_url( - client_id: int | str, - *, - permissions: Permissions | Undefined = MISSING, - guild: Snowflake | Undefined = MISSING, - redirect_uri: str | Undefined = MISSING, - scopes: Iterable[str] | Undefined = MISSING, - disable_guild_select: bool = False, -) -> str: - """A helper function that returns the OAuth2 URL for inviting the bot - into guilds. - - Parameters - ---------- - client_id: Union[:class:`int`, :class:`str`] - The client ID for your bot. - permissions: :class:`~discord.Permissions` - The permissions you're requesting. If not given then you won't be requesting any - permissions. - guild: :class:`~discord.abc.Snowflake` - The guild to pre-select in the authorization screen, if available. - redirect_uri: :class:`str` - An optional valid redirect URI. - scopes: Iterable[:class:`str`] - An optional valid list of scopes. Defaults to ``('bot',)``. - - .. versionadded:: 1.7 - disable_guild_select: :class:`bool` - Whether to disallow the user from changing the guild dropdown. - - .. versionadded:: 2.0 - - Returns - ------- - :class:`str` - The OAuth2 URL for inviting the bot into guilds. - """ - url = f"https://discord.com/oauth2/authorize?client_id={client_id}" - url += f"&scope={'+'.join(scopes or ('bot',))}" - if permissions is not MISSING: - url += f"&permissions={permissions.value}" - if guild is not MISSING: - url += f"&guild_id={guild.id}" - if redirect_uri is not MISSING: - from urllib.parse import urlencode # noqa: PLC0415 - - url += f"&response_type=code&{urlencode({'redirect_uri': redirect_uri})}" - if disable_guild_select: - url += "&disable_guild_select=true" - return url - - -def snowflake_time(id: int) -> datetime.datetime: - """Converts a Discord snowflake ID to a UTC-aware datetime object. - - Parameters - ---------- - id: :class:`int` - The snowflake ID. - - Returns - ------- - :class:`datetime.datetime` - An aware datetime in UTC representing the creation time of the snowflake. - """ - timestamp = ((id >> 22) + DISCORD_EPOCH) / 1000 - return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) - - -def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: - """Returns a numeric snowflake pretending to be created at the given date. - - When using as the lower end of a range, use ``time_snowflake(high=False) - 1`` - to be inclusive, ``high=True`` to be exclusive. - - When using as the higher end of a range, use ``time_snowflake(high=True) + 1`` - to be inclusive, ``high=False`` to be exclusive - - Parameters - ---------- - dt: :class:`datetime.datetime` - A datetime object to convert to a snowflake. - If naive, the timezone is assumed to be local time. - high: :class:`bool` - Whether to set the lower 22 bit to high or low. - - Returns - ------- - :class:`int` - The snowflake representing the time given. - """ - discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) - return (discord_millis << 22) + (2**22 - 1 if high else 0) - - -def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> T | None: - """A helper to return the first element found in the sequence - that meets the predicate. For example: :: - - member = discord.utils.find(lambda m: m.name == "Mighty", channel.guild.members) - - would find the first :class:`~discord.Member` whose name is 'Mighty' and return it. - If an entry is not found, then ``None`` is returned. - - This is different from :func:`py:filter` due to the fact it stops the moment it finds - a valid entry. - - Parameters - ---------- - predicate - A function that returns a boolean-like result. - seq: :class:`collections.abc.Iterable` - The iterable to search through. - """ - - for element in seq: - if predicate(element): - return element - return None - - -def get(iterable: Iterable[T], **attrs: Any) -> T | None: - r"""A helper that returns the first element in the iterable that meets - all the traits passed in ``attrs``. This is an alternative for - :func:`~discord.utils.find`. - - When multiple attributes are specified, they are checked using - logical AND, not logical OR. Meaning they have to meet every - attribute passed in and not one of them. - - To have a nested attribute search (i.e. search by ``x.y``) then - pass in ``x__y`` as the keyword argument. - - If nothing is found that matches the attributes passed, then - ``None`` is returned. - - Examples - --------- - - Basic usage: - - .. code-block:: python3 - - member = discord.utils.get(message.guild.members, name="Foo") - - Multiple attribute matching: - - .. code-block:: python3 - - channel = discord.utils.get(guild.voice_channels, name="Foo", bitrate=64000) - - Nested attribute matching: - - .. code-block:: python3 - - channel = discord.utils.get(client.get_all_channels(), guild__name="Cool", name="general") - - Parameters - ----------- - iterable - An iterable to search through. - \*\*attrs - Keyword arguments that denote attributes to search with. - """ - - # global -> local - _all = all - attrget = attrgetter - - # Special case the single element call - if len(attrs) == 1: - k, v = attrs.popitem() - pred = attrget(k.replace("__", ".")) - for elem in iterable: - if pred(elem) == v: - return elem - return None - - converted = [(attrget(attr.replace("__", ".")), value) for attr, value in attrs.items()] - - for elem in iterable: - if _all(pred(elem) == value for pred, value in converted): - return elem - return None - - -async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any: - """|coro| - - Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it. - If the fetch also fails, an error will be raised. - - Parameters - ---------- - obj: Any - The object to use the get or fetch methods in - attr: :class:`str` - The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute. - id: :class:`int` - The ID of the object - default: Any - The default value to return if the object is not found, instead of raising an error. - - Returns - ------- - Any - The object found or the default value. - - Raises - ------ - :exc:`AttributeError` - The object is missing a ``get_`` or ``fetch_`` method - :exc:`NotFound` - Invalid ID for the object - :exc:`HTTPException` - An error occurred fetching the object - :exc:`Forbidden` - You do not have permission to fetch the object - - Examples - -------- - - Getting a guild from a guild ID: :: - - guild = await utils.get_or_fetch(client, "guild", guild_id) - - Getting a channel from the guild. If the channel is not found, return None: :: - - channel = await utils.get_or_fetch(guild, "channel", channel_id, default=None) - """ - getter = getattr(obj, f"get_{attr}")(id) - if getter is None: - try: - getter = await getattr(obj, f"fetch_{attr}")(id) - except AttributeError: - getter = await getattr(obj, f"_fetch_{attr}")(id) - if getter is None: - raise ValueError(f"Could not find {attr} with id {id} on {obj}") - except (HTTPException, ValueError): - if default is not MISSING: - return default - else: - raise - return getter - - -def _unique(iterable: Iterable[T]) -> list[T]: - return [x for x in dict.fromkeys(iterable)] - - -def _get_as_snowflake(data: Any, key: str) -> int | None: - try: - value = data[key] - except KeyError: - return None - else: - return value and int(value) - - -def _get_mime_type_for_image(data: bytes): - if data.startswith(b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"): - return "image/png" - elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"): - return "image/jpeg" - elif data.startswith((b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61")): - return "image/gif" - elif data.startswith(b"RIFF") and data[8:12] == b"WEBP": - return "image/webp" - else: - raise InvalidArgument("Unsupported image type given") - - -def _bytes_to_base64_data(data: bytes) -> str: - fmt = "data:{mime};base64,{data}" - mime = _get_mime_type_for_image(data) - b64 = b64encode(data).decode("ascii") - return fmt.format(mime=mime, data=b64) - - -if HAS_MSGSPEC: - - def _to_json(obj: Any) -> str: # type: ignore - return msgspec.json.encode(obj).decode("utf-8") - - _from_json = msgspec.json.decode # type: ignore - -else: - - def _to_json(obj: Any) -> str: - return json.dumps(obj, separators=(",", ":"), ensure_ascii=True) - - _from_json = json.loads - - -def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: - reset_after: str | None = request.headers.get("X-Ratelimit-Reset-After") - if not use_clock and reset_after: - return float(reset_after) - utc = datetime.timezone.utc - now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc) - return (reset - now).total_seconds() - - -async def maybe_coroutine(f, *args, **kwargs): - value = f(*args, **kwargs) - if _isawaitable(value): - return await value - else: - return value - - -async def async_all(gen, *, check=_isawaitable): - for elem in gen: - if check(elem): - elem = await elem - if not elem: - return False - return True - - -async def sane_wait_for(futures, *, timeout): - ensured = [asyncio.ensure_future(fut) for fut in futures] - done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) - - if len(pending) != 0: - raise asyncio.TimeoutError() - - return done - - -def get_slots(cls: type[Any]) -> Iterator[str]: - for mro in reversed(cls.__mro__): - try: - yield from mro.__slots__ - except AttributeError: - continue - - -def compute_timedelta(dt: datetime.datetime): - if dt.tzinfo is None: - dt = dt.astimezone() - now = datetime.datetime.now(datetime.timezone.utc) - return max((dt - now).total_seconds(), 0) - - -async def sleep_until(when: datetime.datetime, result: T | None = None) -> T | None: - """|coro| - - Sleep until a specified time. - - If the time supplied is in the past this function will yield instantly. - - .. versionadded:: 1.3 - - Parameters - ---------- - when: :class:`datetime.datetime` - The timestamp in which to sleep until. If the datetime is naive then - it is assumed to be local time. - result: Any - If provided is returned to the caller when the coroutine completes. - """ - delta = compute_timedelta(when) - return await asyncio.sleep(delta, result) - - -def utcnow() -> datetime.datetime: - """A helper function to return an aware UTC datetime representing the current time. - - This should be preferred to :meth:`datetime.datetime.utcnow` since it is an aware - datetime, compared to the naive datetime in the standard library. - - .. versionadded:: 2.0 - - Returns - ------- - :class:`datetime.datetime` - The current aware datetime in UTC. - """ - return datetime.datetime.now(datetime.timezone.utc) - - -def valid_icon_size(size: int) -> bool: - """Icons must be power of 2 within [16, 4096].""" - return not size & (size - 1) and 4096 >= size >= 16 - - -class SnowflakeList(array.array): - """Internal data storage class to efficiently store a list of snowflakes. - - This should have the following characteristics: - - - Low memory usage - - O(n) iteration (obviously) - - O(n log n) initial creation if data is unsorted - - O(log n) search and indexing - - O(n) insertion - """ - - __slots__ = () - - if TYPE_CHECKING: - - def __init__(self, data: Iterable[int], *, is_sorted: bool = False): ... - - def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): - return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore - - def add(self, element: int) -> None: - i = bisect_left(self, element) - self.insert(i, element) - - def get(self, element: int) -> int | None: - i = bisect_left(self, element) - return self[i] if i != len(self) and self[i] == element else None - - def has(self, element: int) -> bool: - i = bisect_left(self, element) - return i != len(self) and self[i] == element - - -_IS_ASCII = re.compile(r"^[\x00-\x7f]+$") - - -def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int: - """Returns string's width.""" - match = _IS_ASCII.match(string) - if match: - return match.endpos - - UNICODE_WIDE_CHAR_TYPE = "WFA" - func = unicodedata.east_asian_width - return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string) - - -def resolve_invite(invite: Invite | str) -> str: - """ - Resolves an invite from a :class:`~discord.Invite`, URL or code. - - Parameters - ---------- - invite: Union[:class:`~discord.Invite`, :class:`str`] - The invite. - - Returns - ------- - :class:`str` - The invite code. - """ - from .invite import Invite # circular import # noqa: PLC0415 - - if isinstance(invite, Invite): - return invite.code - rx = r"(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)" - m = re.match(rx, invite) - if m: - return m.group(1) - return invite - - -def resolve_template(code: Template | str) -> str: - """ - Resolves a template code from a :class:`~discord.Template`, URL or code. - - .. versionadded:: 1.4 - - Parameters - ---------- - code: Union[:class:`~discord.Template`, :class:`str`] - The code. - - Returns - ------- - :class:`str` - The template code. - """ - from .template import Template # circular import # noqa: PLC0415 - - if isinstance(code, Template): - return code.code - rx = r"(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)" - m = re.match(rx, code) - if m: - return m.group(1) - return code - - -_MARKDOWN_ESCAPE_SUBREGEX = "|".join(r"\{0}(?=([\s\S]*((?(?:>>)?\s|{_MARKDOWN_ESCAPE_LINKS}" - -_MARKDOWN_ESCAPE_REGEX = re.compile( - rf"(?P{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})", - re.MULTILINE | re.X, -) - -_URL_REGEX = r"(?P<[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])" - -_MARKDOWN_STOCK_REGEX = rf"(?P[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})" - - -def remove_markdown(text: str, *, ignore_links: bool = True) -> str: - """A helper function that removes markdown characters. - - .. versionadded:: 1.7 - - .. note:: - This function is not markdown aware and may remove meaning from the original text. For example, - if the input contains ``10 * 5`` then it will be converted into ``10 5``. - - Parameters - ---------- - text: :class:`str` - The text to remove markdown from. - ignore_links: :class:`bool` - Whether to leave links alone when removing markdown. For example, - if a URL in the text contains characters such as ``_`` then it will - be left alone. Defaults to ``True``. - - Returns - ------- - :class:`str` - The text with the markdown special characters removed. - """ - - def replacement(match): - groupdict = match.groupdict() - return groupdict.get("url", "") - - regex = _MARKDOWN_STOCK_REGEX - if ignore_links: - regex = f"(?:{_URL_REGEX}|{regex})" - return re.sub(regex, replacement, text, count=0, flags=re.MULTILINE) - - -def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str: - r"""A helper function that escapes Discord's markdown. - - Parameters - ----------- - text: :class:`str` - The text to escape markdown from. - as_needed: :class:`bool` - Whether to escape the markdown characters as needed. This - means that it does not escape extraneous characters if it's - not necessary, e.g. ``**hello**`` is escaped into ``\*\*hello**`` - instead of ``\*\*hello\*\*``. Note however that this can open - you up to some clever syntax abuse. Defaults to ``False``. - ignore_links: :class:`bool` - Whether to leave links alone when escaping markdown. For example, - if a URL in the text contains characters such as ``_`` then it will - be left alone. This option is not supported with ``as_needed``. - Defaults to ``True``. - - Returns - -------- - :class:`str` - The text with the markdown special characters escaped with a slash. - """ - - if not as_needed: - - def replacement(match): - groupdict = match.groupdict() - is_url = groupdict.get("url") - if is_url: - return is_url - return f"\\{groupdict['markdown']}" - - regex = _MARKDOWN_STOCK_REGEX - if ignore_links: - regex = f"(?:{_URL_REGEX}|{regex})" - return re.sub(regex, replacement, text, count=0, flags=re.MULTILINE | re.X) - else: - text = re.sub(r"\\", r"\\\\", text) - return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text) - - -def escape_mentions(text: str) -> str: - """A helper function that escapes everyone, here, role, and user mentions. - - .. note:: - - This does not include channel mentions. - - .. note:: - - For more granular control over what mentions should be escaped - within messages, refer to the :class:`~discord.AllowedMentions` - class. - - Parameters - ---------- - text: :class:`str` - The text to escape mentions from. - - Returns - ------- - :class:`str` - The text with the mentions removed. - """ - return re.sub(r"@(everyone|here|[!&]?[0-9]{17,20})", "@\u200b\\1", text) - - -def raw_mentions(text: str) -> list[int]: - """Returns a list of user IDs matching ``<@user_id>`` in the string. - - .. versionadded:: 2.2 - - Parameters - ---------- - text: :class:`str` - The string to get user mentions from. - - Returns - ------- - List[:class:`int`] - A list of user IDs found in the string. - """ - return [int(x) for x in re.findall(r"<@!?([0-9]+)>", text)] - - -def raw_channel_mentions(text: str) -> list[int]: - """Returns a list of channel IDs matching ``<@#channel_id>`` in the string. - - .. versionadded:: 2.2 - - Parameters - ---------- - text: :class:`str` - The string to get channel mentions from. - - Returns - ------- - List[:class:`int`] - A list of channel IDs found in the string. - """ - return [int(x) for x in re.findall(r"<#([0-9]+)>", text)] - - -def raw_role_mentions(text: str) -> list[int]: - """Returns a list of role IDs matching ``<@&role_id>`` in the string. - - .. versionadded:: 2.2 - - Parameters - ---------- - text: :class:`str` - The string to get role mentions from. - - Returns - ------- - List[:class:`int`] - A list of role IDs found in the string. - """ - return [int(x) for x in re.findall(r"<@&([0-9]+)>", text)] - - -def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[list[T]]: - ret = [] - n = 0 - for item in iterator: - ret.append(item) - n += 1 - if n == max_size: - yield ret - ret = [] - n = 0 - if ret: - yield ret - - -async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[list[T]]: - ret = [] - n = 0 - async for item in iterator: - ret.append(item) - n += 1 - if n == max_size: - yield ret - ret = [] - n = 0 - if ret: - yield ret - - -@overload -def as_chunks(iterator: Iterator[T], max_size: int) -> Iterator[list[T]]: ... - - -@overload -def as_chunks(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[list[T]]: ... - - -def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[list[T]]: - """A helper function that collects an iterator into chunks of a given size. - - .. versionadded:: 2.0 - - .. warning:: - - The last chunk collected may not be as large as ``max_size``. - - Parameters - ---------- - iterator: Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`] - The iterator to chunk, can be sync or async. - max_size: :class:`int` - The maximum chunk size. - - Returns - ------- - Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`] - A new iterator which yields chunks of a given size. - """ - if max_size <= 0: - raise ValueError("Chunk sizes must be greater than 0.") - - if isinstance(iterator, AsyncIterator): - return _achunk(iterator, max_size) - return _chunk(iterator, max_size) - - -PY_310 = sys.version_info >= (3, 10) - - -def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]: - params = [] - literal_cls = type(Literal[0]) - for p in parameters: - if isinstance(p, literal_cls): - params.extend(p.__args__) - else: - params.append(p) - return tuple(params) - - -def normalise_optional_params(parameters: Iterable[Any]) -> tuple[Any, ...]: - none_cls = type(None) - return tuple(p for p in parameters if p is not none_cls) + (none_cls,) - - -def evaluate_annotation( - tp: Any, - globals: dict[str, Any], - locals: dict[str, Any], - cache: dict[str, Any], - *, - implicit_str: bool = True, -): - if isinstance(tp, ForwardRef): - tp = tp.__forward_arg__ - # ForwardRefs always evaluate their internals - implicit_str = True - - if implicit_str and isinstance(tp, str): - if tp in cache: - return cache[tp] - evaluated = eval(tp, globals, locals) - cache[tp] = evaluated - return evaluate_annotation(evaluated, globals, locals, cache) - - if hasattr(tp, "__args__"): - implicit_str = True - is_literal = False - args = tp.__args__ - if not hasattr(tp, "__origin__"): - if PY_310 and tp.__class__ is types.UnionType: # type: ignore - converted = Union[args] # type: ignore - return evaluate_annotation(converted, globals, locals, cache) - - return tp - if tp.__origin__ is Union: - try: - if args.index(type(None)) != len(args) - 1: - args = normalise_optional_params(tp.__args__) - except ValueError: - pass - if tp.__origin__ is Literal: - if not PY_310: - args = flatten_literal_params(tp.__args__) - implicit_str = False - is_literal = True - - evaluated_args = tuple( - evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args - ) - - if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): - raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") - - if evaluated_args == args: - return tp - - try: - return tp.copy_with(evaluated_args) - except AttributeError: - return tp.__origin__[evaluated_args] - - return tp - - -def resolve_annotation( - annotation: Any, - globalns: dict[str, Any], - localns: dict[str, Any] | None, - cache: dict[str, Any] | None, -) -> Any: - if annotation is None: - return type(None) - if isinstance(annotation, str): - annotation = ForwardRef(annotation) - - locals = globalns if localns is None else localns - if cache is None: - cache = {} - return evaluate_annotation(annotation, globalns, locals, cache) - - -TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"] - - -def format_dt(dt: datetime.datetime | datetime.time, /, style: TimestampStyle | None = None) -> str: - """A helper function to format a :class:`datetime.datetime` for presentation within Discord. - - This allows for a locale-independent way of presenting data using Discord specific Markdown. - - +-------------+----------------------------+-----------------+ - | Style | Example Output | Description | - +=============+============================+=================+ - | t | 22:57 | Short Time | - +-------------+----------------------------+-----------------+ - | T | 22:57:58 | Long Time | - +-------------+----------------------------+-----------------+ - | d | 17/05/2016 | Short Date | - +-------------+----------------------------+-----------------+ - | D | 17 May 2016 | Long Date | - +-------------+----------------------------+-----------------+ - | f (default) | 17 May 2016 22:57 | Short Date Time | - +-------------+----------------------------+-----------------+ - | F | Tuesday, 17 May 2016 22:57 | Long Date Time | - +-------------+----------------------------+-----------------+ - | R | 5 years ago | Relative Time | - +-------------+----------------------------+-----------------+ - - Note that the exact output depends on the user's locale setting in the client. The example output - presented is using the ``en-GB`` locale. - - .. versionadded:: 2.0 - - Parameters - ---------- - dt: Union[:class:`datetime.datetime`, :class:`datetime.time`] - The datetime to format. - style: :class:`str` - The style to format the datetime with. - - Returns - ------- - :class:`str` - The formatted string. - """ - if isinstance(dt, datetime.time): - dt = datetime.datetime.combine(datetime.datetime.now(), dt) - if style is None: - return f"" - return f"" - - -def generate_snowflake(dt: datetime.datetime | None = None) -> int: - """Returns a numeric snowflake pretending to be created at the given date but more accurate and random - than :func:`time_snowflake`. If dt is not passed, it makes one from the current time using utcnow. - - Parameters - ---------- - dt: :class:`datetime.datetime` - A datetime object to convert to a snowflake. - If naive, the timezone is assumed to be local time. - - Returns - ------- - :class:`int` - The snowflake representing the time given. - """ - - dt = dt or utcnow() - return int(dt.timestamp() * 1000 - DISCORD_EPOCH) << 22 | 0x3FFFFF - - -V = Union[Iterable[OptionChoice], Iterable[str], Iterable[int], Iterable[float]] -AV = Awaitable[V] -Values = Union[V, Callable[[AutocompleteContext], Union[V, AV]], AV] -AutocompleteFunc = Callable[[AutocompleteContext], AV] -FilterFunc = Callable[[AutocompleteContext, Any], Union[bool, Awaitable[bool]]] - - -def basic_autocomplete(values: Values, *, filter: FilterFunc | None = None) -> AutocompleteFunc: - """A helper function to make a basic autocomplete for slash commands. This is a pretty standard autocomplete and - will return any options that start with the value from the user, case-insensitive. If the ``values`` parameter is - callable, it will be called with the AutocompleteContext. - - This is meant to be passed into the :attr:`discord.Option.autocomplete` attribute. - - Parameters - ---------- - values: Union[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Callable[[:class:`.AutocompleteContext`], Union[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] - Possible values for the option. Accepts an iterable of :class:`str`, a callable (sync or async) that takes a - single argument of :class:`.AutocompleteContext`, or a coroutine. Must resolve to an iterable of :class:`str`. - filter: Optional[Callable[[:class:`.AutocompleteContext`, Any], Union[:class:`bool`, Awaitable[:class:`bool`]]]] - An optional callable (sync or async) used to filter the autocomplete options. It accepts two arguments: - the :class:`.AutocompleteContext` and an item from ``values`` iteration treated as callback parameters. If ``None`` is provided, a default filter is used that includes items whose string representation starts with the user's input value, case-insensitive. - - .. versionadded:: 2.7 - - Returns - ------- - Callable[[:class:`.AutocompleteContext`], Awaitable[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] - A wrapped callback for the autocomplete. - - Examples - -------- - - Basic usage: - - .. code-block:: python3 - - Option(str, "color", autocomplete=basic_autocomplete(("red", "green", "blue"))) - - # or - - - async def autocomplete(ctx): - return "foo", "bar", "baz", ctx.interaction.user.name - - - Option(str, "name", autocomplete=basic_autocomplete(autocomplete)) - - With filter parameter: - - .. code-block:: python3 - - Option( - str, - "color", - autocomplete=basic_autocomplete(("red", "green", "blue"), filter=lambda c, i: str(c.value or "") in i), - ) - - .. versionadded:: 2.0 - - Note - ---- - Autocomplete cannot be used for options that have specified choices. - """ - - async def autocomplete_callback(ctx: AutocompleteContext) -> V: - _values = values # since we reassign later, python considers it local if we don't do this - - if callable(_values): - _values = _values(ctx) - if asyncio.iscoroutine(_values): - _values = await _values - - if filter is None: - - def _filter(ctx: AutocompleteContext, item: Any) -> bool: - item = getattr(item, "name", item) - return str(item).lower().startswith(str(ctx.value or "").lower()) - - gen = (val for val in _values if _filter(ctx, val)) - - elif asyncio.iscoroutinefunction(filter): - gen = (val for val in _values if await filter(ctx, val)) - - elif callable(filter): - gen = (val for val in _values if filter(ctx, val)) - - else: - raise TypeError("``filter`` must be callable.") - - return iter(itertools.islice(gen, 25)) - - return autocomplete_callback - - -def filter_params(params, **kwargs): - """A helper function to filter out and replace certain keyword parameters - - Parameters - ---------- - params: Dict[str, Any] - The initial parameters to filter. - **kwargs: Dict[str, Optional[str]] - Key to value pairs where the key's contents would be moved to the - value, or if the value is None, remove key's contents (see code example). - - Example - ------- - .. code-block:: python3 - - >>> params = {"param1": 12, "param2": 13} - >>> filter_params(params, param1="param3", param2=None) - {'param3': 12} - # values of 'param1' is moved to 'param3' - # and values of 'param2' are completely removed. - """ - for old_param, new_param in kwargs.items(): - if old_param in params: - if new_param is None: - params.pop(old_param) - else: - params[new_param] = params.pop(old_param) - - return params diff --git a/discord/utils/__init__.py b/discord/utils/__init__.py new file mode 100644 index 0000000000..41e39a9a6c --- /dev/null +++ b/discord/utils/__init__.py @@ -0,0 +1,137 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import ( + Any, + AsyncIterator, + Iterator, + Mapping, + Protocol, + TypeVar, + Union, +) + +from ..errors import HTTPException +from .public import ( + basic_autocomplete, + generate_snowflake, + utcnow, + find, + snowflake_time, + oauth_url, + Undefined, + MISSING, + format_dt, + escape_mentions, + raw_mentions, + raw_channel_mentions, + raw_role_mentions, + remove_markdown, + escape_markdown, +) + +DISCORD_EPOCH = 1420070400000 + + +__all__ = ( + "oauth_url", + "snowflake_time", + "find", + "get_or_fetch", + "utcnow", + "remove_markdown", + "escape_markdown", + "escape_mentions", + "raw_mentions", + "raw_channel_mentions", + "raw_role_mentions", + "format_dt", + "generate_snowflake", + "basic_autocomplete", + "Undefined", + "MISSING", +) + + +async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any: + """|coro| + + Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it. + If the fetch also fails, an error will be raised. + + Parameters + ---------- + obj: Any + The object to use the get or fetch methods in + attr: :class:`str` + The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute. + id: :class:`int` + The ID of the object + default: Any + The default value to return if the object is not found, instead of raising an error. + + Returns + ------- + Any + The object found or the default value. + + Raises + ------ + :exc:`AttributeError` + The object is missing a ``get_`` or ``fetch_`` method + :exc:`NotFound` + Invalid ID for the object + :exc:`HTTPException` + An error occurred fetching the object + :exc:`Forbidden` + You do not have permission to fetch the object + + Examples + -------- + + Getting a guild from a guild ID: :: + + guild = await utils.get_or_fetch(client, "guild", guild_id) + + Getting a channel from the guild. If the channel is not found, return None: :: + + channel = await utils.get_or_fetch(guild, "channel", channel_id, default=None) + """ + getter = getattr(obj, f"get_{attr}")(id) + if getter is None: + try: + getter = await getattr(obj, f"fetch_{attr}")(id) + except AttributeError: + getter = await getattr(obj, f"_fetch_{attr}")(id) + if getter is None: + raise ValueError(f"Could not find {attr} with id {id} on {obj}") + except (HTTPException, ValueError): + if default is not MISSING: + return default + else: + raise + return getter diff --git a/discord/utils/private.py b/discord/utils/private.py new file mode 100644 index 0000000000..491bb19e1e --- /dev/null +++ b/discord/utils/private.py @@ -0,0 +1,549 @@ +from __future__ import annotations + +import array +import asyncio +import collections.abc +import datetime +import functools +import re +import sys +import types +import unicodedata +import warnings +from _bisect import bisect_left +from base64 import b64encode +from inspect import isawaitable, signature +from typing import ( + TYPE_CHECKING, + Any, + overload, + Callable, + TypeVar, + ParamSpec, + Iterable, + Literal, + ForwardRef, + Union, + Coroutine, + Awaitable, + reveal_type, + Generic, + Sequence, + Iterator, +) + +from ..errors import InvalidArgument, HTTPException + +if TYPE_CHECKING: + from ..invite import Invite + from ..template import Template + +_IS_ASCII = re.compile(r"^[\x00-\x7f]+$") + +P = ParamSpec("P") +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) + + +def resolve_invite(invite: Invite | str) -> str: + """ + Resolves an invite from a :class:`~discord.Invite`, URL or code. + + Parameters + ---------- + invite: Union[:class:`~discord.Invite`, :class:`str`] + The invite. + + Returns + ------- + :class:`str` + The invite code. + """ + from ..invite import Invite # noqa: PLC0415 # circular import + + if isinstance(invite, Invite): + return invite.code + rx = r"(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)" + m = re.match(rx, invite) + if m: + return m.group(1) + return invite + + +__all__ = ("resolve_invite",) + + +def get_as_snowflake(data: Any, key: str) -> int | None: + try: + value = data[key] + except KeyError: + return None + else: + return value and int(value) + + +def get_mime_type_for_image(data: bytes): + if data.startswith(b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"): + return "image/png" + elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"): + return "image/jpeg" + elif data.startswith((b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61")): + return "image/gif" + elif data.startswith(b"RIFF") and data[8:12] == b"WEBP": + return "image/webp" + else: + raise InvalidArgument("Unsupported image type given") + + +def bytes_to_base64_data(data: bytes) -> str: + fmt = "data:{mime};base64,{data}" + mime = get_mime_type_for_image(data) + b64 = b64encode(data).decode("ascii") + return fmt.format(mime=mime, data=b64) + + +def parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: + reset_after: str | None = request.headers.get("X-Ratelimit-Reset-After") + if not use_clock and reset_after: + return float(reset_after) + utc = datetime.timezone.utc + now = datetime.datetime.now(utc) + reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc) + return (reset - now).total_seconds() + + +def string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int: + """Returns string's width.""" + match = _IS_ASCII.match(string) + if match: + return match.endpos + + UNICODE_WIDE_CHAR_TYPE = "WFA" + func = unicodedata.east_asian_width + return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string) + + +def resolve_template(code: Template | str) -> str: + """ + Resolves a template code from a :class:`~discord.Template`, URL or code. + + .. versionadded:: 1.4 + + Parameters + ---------- + code: Union[:class:`~discord.Template`, :class:`str`] + The code. + + Returns + ------- + :class:`str` + The template code. + """ + from ..template import Template # noqa: PLC0415 # circular import + + if isinstance(code, Template): + return code.code + rx = r"(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)" + m = re.match(rx, code) + if m: + return m.group(1) + return code + + +__all__ = ( + "resolve_invite", + "get_as_snowflake", + "get_mime_type_for_image", + "bytes_to_base64_data", + "parse_ratelimit_header", + "string_width", +) + + +@overload +def parse_time(timestamp: None) -> None: ... + + +@overload +def parse_time(timestamp: str) -> datetime.datetime: ... + + +@overload +def parse_time(timestamp: str | None) -> datetime.datetime | None: ... + + +def parse_time(timestamp: str | None) -> datetime.datetime | None: + """A helper function to convert an ISO 8601 timestamp to a datetime object. + + Parameters + ---------- + timestamp: Optional[:class:`str`] + The timestamp to convert. + + Returns + ------- + Optional[:class:`datetime.datetime`] + The converted datetime object. + """ + if timestamp: + return datetime.datetime.fromisoformat(timestamp) + return None + + +def warn_deprecated( + name: str, + instead: str | None = None, + since: str | None = None, + removed: str | None = None, + reference: str | None = None, + stacklevel: int = 3, +) -> None: + """Warn about a deprecated function, with the ability to specify details about the deprecation. Emits a + DeprecationWarning. + + Parameters + ---------- + name: str + The name of the deprecated function. + instead: Optional[:class:`str`] + A recommended alternative to the function. + since: Optional[:class:`str`] + The version in which the function was deprecated. This should be in the format ``major.minor(.patch)``, where + the patch version is optional. + removed: Optional[:class:`str`] + The version in which the function is planned to be removed. This should be in the format + ``major.minor(.patch)``, where the patch version is optional. + reference: Optional[:class:`str`] + A reference that explains the deprecation, typically a URL to a page such as a changelog entry or a GitHub + issue/PR. + stacklevel: :class:`int` + The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3. + """ + warnings.simplefilter("always", DeprecationWarning) # turn off filter + message = f"{name} is deprecated" + if since: + message += f" since version {since}" + if removed: + message += f" and will be removed in version {removed}" + if instead: + message += f", consider using {instead} instead" + message += "." + if reference: + message += f" See {reference} for more information." + + warnings.warn(message, stacklevel=stacklevel, category=DeprecationWarning) + warnings.simplefilter("default", DeprecationWarning) # reset filter + + +def deprecated( + instead: str | None = None, + since: str | None = None, + removed: str | None = None, + reference: str | None = None, + stacklevel: int = 3, + *, + use_qualname: bool = True, +) -> Callable[[Callable[[P], T]], Callable[[P], T]]: + """A decorator implementation of :func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when + the decorated function is called. + + Parameters + ---------- + instead: Optional[:class:`str`] + A recommended alternative to the function. + since: Optional[:class:`str`] + The version in which the function was deprecated. This should be in the format ``major.minor(.patch)``, where + the patch version is optional. + removed: Optional[:class:`str`] + The version in which the function is planned to be removed. This should be in the format + ``major.minor(.patch)``, where the patch version is optional. + reference: Optional[:class:`str`] + A reference that explains the deprecation, typically a URL to a page such as a changelog entry or a GitHub + issue/PR. + stacklevel: :class:`int` + The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3. + use_qualname: :class:`bool` + Whether to use the qualified name of the function in the deprecation warning. If ``False``, the short name of + the function will be used instead. For example, __qualname__ will display as ``Client.login`` while __name__ + will display as ``login``. Defaults to ``True``. + """ + + def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]: + @functools.wraps(func) + def decorated(*args: P.args, **kwargs: P.kwargs) -> T: + warn_deprecated( + name=func.__qualname__ if use_qualname else func.__name__, + instead=instead, + since=since, + removed=removed, + reference=reference, + stacklevel=stacklevel, + ) + return func(*args, **kwargs) + + return decorated + + return actual_decorator + + +PY_310 = sys.version_info >= (3, 10) + + +def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]: + params = [] + literal_cls = type(Literal[0]) + for p in parameters: + if isinstance(p, literal_cls): + params.extend(p.__args__) + else: + params.append(p) + return tuple(params) + + +def normalise_optional_params(parameters: Iterable[Any]) -> tuple[Any, ...]: + none_cls = type(None) + return tuple(p for p in parameters if p is not none_cls) + (none_cls,) + + +def evaluate_annotation( + tp: Any, + globals: dict[str, Any], + locals: dict[str, Any], + cache: dict[str, Any], + *, + implicit_str: bool = True, +): + if isinstance(tp, ForwardRef): + tp = tp.__forward_arg__ + # ForwardRefs always evaluate their internals + implicit_str = True + + if implicit_str and isinstance(tp, str): + if tp in cache: + return cache[tp] + evaluated = eval(tp, globals, locals) + cache[tp] = evaluated + return evaluate_annotation(evaluated, globals, locals, cache) + + if hasattr(tp, "__args__"): + implicit_str = True + is_literal = False + args = tp.__args__ + if not hasattr(tp, "__origin__"): + if PY_310 and tp.__class__ is types.UnionType: # type: ignore + converted = Union[args] # type: ignore + return evaluate_annotation(converted, globals, locals, cache) + + return tp + if tp.__origin__ is Union: + try: + if args.index(type(None)) != len(args) - 1: + args = normalise_optional_params(tp.__args__) + except ValueError: + pass + if tp.__origin__ is Literal: + if not PY_310: + args = flatten_literal_params(tp.__args__) + implicit_str = False + is_literal = True + + evaluated_args = tuple( + evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args + ) + + if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): + raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") + + if evaluated_args == args: + return tp + + try: + return tp.copy_with(evaluated_args) + except AttributeError: + return tp.__origin__[evaluated_args] + + return tp + + +def resolve_annotation( + annotation: Any, + globalns: dict[str, Any], + localns: dict[str, Any] | None, + cache: dict[str, Any] | None, +) -> Any: + if annotation is None: + return type(None) + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + + locals = globalns if localns is None else localns + if cache is None: + cache = {} + return evaluate_annotation(annotation, globalns, locals, cache) + + +def delay_task(delay: float, func: Coroutine): + async def inner_call(): + await asyncio.sleep(delay) + try: + await func + except HTTPException: + pass + + asyncio.create_task(inner_call()) + + +async def async_all(gen: Iterable[Any]) -> bool: + for elem in gen: + if isawaitable(elem): + elem = await elem + if not elem: + return False + return True + + +async def maybe_awaitable(f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: + value = f(*args, **kwargs) + if isawaitable(value): + reveal_type(f) + return await value + return value + + +async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Future[T]]: + ensured = [asyncio.ensure_future(fut) for fut in futures] + done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) + + if len(pending) != 0: + raise asyncio.TimeoutError() + + return done + + +class SnowflakeList(array.array): + """Internal data storage class to efficiently store a list of snowflakes. + + This should have the following characteristics: + + - Low memory usage + - O(n) iteration (obviously) + - O(n log n) initial creation if data is unsorted + - O(log n) search and indexing + - O(n) insertion + """ + + __slots__ = () + + if TYPE_CHECKING: + + def __init__(self, data: Iterable[int], *, is_sorted: bool = False): ... + + def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): + return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore + + def add(self, element: int) -> None: + i = bisect_left(self, element) + self.insert(i, element) + + def get(self, element: int) -> int | None: + i = bisect_left(self, element) + return self[i] if i != len(self) and self[i] == element else None + + def has(self, element: int) -> bool: + i = bisect_left(self, element) + return i != len(self) and self[i] == element + + +def copy_doc(original: Callable) -> Callable[[T], T]: + def decorator(overridden: T) -> T: + overridden.__doc__ = original.__doc__ + overridden.__signature__ = signature(original) # type: ignore + return overridden + + return decorator + + +class SequenceProxy(collections.abc.Sequence, Generic[T_co]): + """Read-only proxy of a Sequence.""" + + def __init__(self, proxied: Sequence[T_co]): + self.__proxied = proxied + + def __getitem__(self, idx: int) -> T_co: + return self.__proxied[idx] + + def __len__(self) -> int: + return len(self.__proxied) + + def __contains__(self, item: Any) -> bool: + return item in self.__proxied + + def __iter__(self) -> Iterator[T_co]: + return iter(self.__proxied) + + def __reversed__(self) -> Iterator[T_co]: + return reversed(self.__proxied) + + def index(self, value: Any, *args, **kwargs) -> int: + return self.__proxied.index(value, *args, **kwargs) + + def count(self, value: Any) -> int: + return self.__proxied.count(value) + + +class CachedSlotProperty(Generic[T, T_co]): + def __init__(self, name: str, function: Callable[[T], T_co]) -> None: + self.name = name + self.function = function + self.__doc__ = getattr(function, "__doc__") + + @overload + def __get__(self, instance: None, owner: type[T]) -> CachedSlotProperty[T, T_co]: ... + + @overload + def __get__(self, instance: T, owner: type[T]) -> T_co: ... + + def __get__(self, instance: T | None, owner: type[T]) -> Any: + if instance is None: + return self + + try: + return getattr(instance, self.name) + except AttributeError: + value = self.function(instance) + setattr(instance, self.name, value) + return value + + +def get_slots(cls: type[Any]) -> Iterator[str]: + for mro in reversed(cls.__mro__): + try: + yield from mro.__slots__ + except AttributeError: + continue + + +def cached_slot_property( + name: str, +) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]: + def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]: + return CachedSlotProperty(name, func) + + return decorator + + +try: + import msgspec + + def to_json(obj: Any) -> str: # type: ignore + return msgspec.json.encode(obj).decode("utf-8") + + from_json = msgspec.json.decode # type: ignore + +except ModuleNotFoundError: + import json + + def to_json(obj: Any) -> str: + return json.dumps(obj, separators=(",", ":"), ensure_ascii=True) + + from_json = json.loads diff --git a/discord/utils/public.py b/discord/utils/public.py new file mode 100644 index 0000000000..dd476d50d5 --- /dev/null +++ b/discord/utils/public.py @@ -0,0 +1,536 @@ +from __future__ import annotations + +import asyncio +import re +import datetime +from enum import Enum, auto +import itertools +from collections.abc import Awaitable, Callable, Iterable +from typing import TYPE_CHECKING, Any, Literal, TypeVar + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..commands.context import AutocompleteContext + from ..commands.options import OptionChoice + from ..permissions import Permissions + + +T = TypeVar("T") + + +class Undefined(Enum): + MISSING = auto() + + def __bool__(self) -> Literal[False]: + return False + + +MISSING: Literal[Undefined.MISSING] = Undefined.MISSING + +DISCORD_EPOCH = 1420070400000 + + +def utcnow() -> datetime.datetime: + """A helper function to return an aware UTC datetime representing the current time. + + This should be preferred to :meth:`datetime.datetime.utcnow` since it is an aware + datetime, compared to the naive datetime in the standard library. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`datetime.datetime` + The current aware datetime in UTC. + """ + return datetime.datetime.now(datetime.timezone.utc) + + +V = Iterable["OptionChoice"] | Iterable[str] | Iterable[int] | Iterable[float] +AV = Awaitable[V] +Values = V | Callable[["AutocompleteContext"], V | AV] | AV +AutocompleteFunc = Callable[["AutocompleteContext"], AV] +FilterFunc = Callable[["AutocompleteContext", Any], bool | Awaitable[bool]] + + +def basic_autocomplete(values: Values, *, filter: FilterFunc | None = None) -> AutocompleteFunc: + """A helper function to make a basic autocomplete for slash commands. This is a pretty standard autocomplete and + will return any options that start with the value from the user, case-insensitive. If the ``values`` parameter is + callable, it will be called with the AutocompleteContext. + + This is meant to be passed into the :attr:`discord.Option.autocomplete` attribute. + + Parameters + ---------- + values: Union[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Callable[[:class:`.AutocompleteContext`], Union[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] + Possible values for the option. Accepts an iterable of :class:`str`, a callable (sync or async) that takes a + single argument of :class:`.AutocompleteContext`, or a coroutine. Must resolve to an iterable of :class:`str`. + filter: Optional[Callable[[:class:`.AutocompleteContext`, Any], Union[:class:`bool`, Awaitable[:class:`bool`]]]] + An optional callable (sync or async) used to filter the autocomplete options. It accepts two arguments: + the :class:`.AutocompleteContext` and an item from ``values`` iteration treated as callback parameters. If ``None`` is provided, a default filter is used that includes items whose string representation starts with the user's input value, case-insensitive. + + .. versionadded:: 2.7 + + Returns + ------- + Callable[[:class:`.AutocompleteContext`], Awaitable[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] + A wrapped callback for the autocomplete. + + Examples + -------- + + Basic usage: + + .. code-block:: python3 + + Option(str, "color", autocomplete=basic_autocomplete(("red", "green", "blue"))) + + # or + + + async def autocomplete(ctx): + return "foo", "bar", "baz", ctx.interaction.user.name + + + Option(str, "name", autocomplete=basic_autocomplete(autocomplete)) + + With filter parameter: + + .. code-block:: python3 + + Option( + str, + "color", + autocomplete=basic_autocomplete(("red", "green", "blue"), filter=lambda c, i: str(c.value or "") in i), + ) + + .. versionadded:: 2.0 + + Note + ---- + Autocomplete cannot be used for options that have specified choices. + """ + + async def autocomplete_callback(ctx: AutocompleteContext) -> V: + _values = values # since we reassign later, python considers it local if we don't do this + + if callable(_values): + _values = _values(ctx) + if asyncio.iscoroutine(_values): + _values = await _values + + if filter is None: + + def _filter(ctx: AutocompleteContext, item: Any) -> bool: + item = getattr(item, "name", item) + return str(item).lower().startswith(str(ctx.value or "").lower()) + + gen = (val for val in _values if _filter(ctx, val)) + + elif asyncio.iscoroutinefunction(filter): + gen = (val for val in _values if await filter(ctx, val)) + + elif callable(filter): + gen = (val for val in _values if filter(ctx, val)) + + else: + raise TypeError("``filter`` must be callable.") + + return iter(itertools.islice(gen, 25)) + + return autocomplete_callback + + +def generate_snowflake( + dt: datetime.datetime | None = None, + *, + mode: Literal["boundary", "realistic"] = "boundary", + high: bool = False, +) -> int: + """Returns a numeric snowflake pretending to be created at the given date. + + This function can generate both realistic snowflakes (for general use) and + boundary snowflakes (for range queries). + + Parameters + ---------- + dt: :class:`datetime.datetime` + A datetime object to convert to a snowflake. + If naive, the timezone is assumed to be local time. + If None, uses current UTC time. + mode: :class:`str` + The type of snowflake to generate: + - "realistic": Creates a snowflake with random-like lower bits + - "boundary": Creates a snowflake for range queries (default) + high: :class:`bool` + Only used when mode="boundary". Whether to set the lower 22 bits + to high (True) or low (False). Default is False. + + Returns + ------- + :class:`int` + The snowflake representing the time given. + + Examples + -------- + # Generate realistic snowflake + snowflake = generate_snowflake(dt) + + # Generate boundary snowflakes + lower_bound = generate_snowflake(dt, mode="boundary", high=False) + upper_bound = generate_snowflake(dt, mode="boundary", high=True) + + # For inclusive ranges: + # Lower: generate_snowflake(dt, mode="boundary", high=False) - 1 + # Upper: generate_snowflake(dt, mode="boundary", high=True) + 1 + """ + dt = dt or utcnow() + discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) + + if mode == "realistic": + return (discord_millis << 22) | 0x3FFFFF + elif mode == "boundary": + return (discord_millis << 22) + (2**22 - 1 if high else 0) + else: + raise ValueError(f"Invalid mode '{mode}'. Must be 'realistic' or 'boundary'") + + +def snowflake_time(id: int) -> datetime.datetime: + """Converts a Discord snowflake ID to a UTC-aware datetime object. + + Parameters + ---------- + id: :class:`int` + The snowflake ID. + + Returns + ------- + :class:`datetime.datetime` + An aware datetime in UTC representing the creation time of the snowflake. + """ + timestamp = ((id >> 22) + DISCORD_EPOCH) / 1000 + return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) + + +def oauth_url( + client_id: int | str, + *, + permissions: Permissions | Undefined = MISSING, + guild: Snowflake | Undefined = MISSING, + redirect_uri: str | Undefined = MISSING, + scopes: Iterable[str] | Undefined = MISSING, + disable_guild_select: bool = False, +) -> str: + """A helper function that returns the OAuth2 URL for inviting the bot + into guilds. + + Parameters + ---------- + client_id: Union[:class:`int`, :class:`str`] + The client ID for your bot. + permissions: :class:`~discord.Permissions` + The permissions you're requesting. If not given then you won't be requesting any + permissions. + guild: :class:`~discord.abc.Snowflake` + The guild to pre-select in the authorization screen, if available. + redirect_uri: :class:`str` + An optional valid redirect URI. + scopes: Iterable[:class:`str`] + An optional valid list of scopes. Defaults to ``('bot',)``. + + .. versionadded:: 1.7 + disable_guild_select: :class:`bool` + Whether to disallow the user from changing the guild dropdown. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`str` + The OAuth2 URL for inviting the bot into guilds. + """ + url = f"https://discord.com/oauth2/authorize?client_id={client_id}" + url += f"&scope={'+'.join(scopes or ('bot',))}" + if permissions is not MISSING: + url += f"&permissions={permissions.value}" + if guild is not MISSING: + url += f"&guild_id={guild.id}" + if redirect_uri is not MISSING: + from urllib.parse import urlencode # noqa: PLC0415 + + url += f"&response_type=code&{urlencode({'redirect_uri': redirect_uri})}" + if disable_guild_select: + url += "&disable_guild_select=true" + return url + + +TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"] + + +def format_dt(dt: datetime.datetime | datetime.time, /, style: TimestampStyle | None = None) -> str: + """A helper function to format a :class:`datetime.datetime` for presentation within Discord. + + This allows for a locale-independent way of presenting data using Discord specific Markdown. + + +-------------+----------------------------+-----------------+ + | Style | Example Output | Description | + +=============+============================+=================+ + | t | 22:57 | Short Time | + +-------------+----------------------------+-----------------+ + | T | 22:57:58 | Long Time | + +-------------+----------------------------+-----------------+ + | d | 17/05/2016 | Short Date | + +-------------+----------------------------+-----------------+ + | D | 17 May 2016 | Long Date | + +-------------+----------------------------+-----------------+ + | f (default) | 17 May 2016 22:57 | Short Date Time | + +-------------+----------------------------+-----------------+ + | F | Tuesday, 17 May 2016 22:57 | Long Date Time | + +-------------+----------------------------+-----------------+ + | R | 5 years ago | Relative Time | + +-------------+----------------------------+-----------------+ + + Note that the exact output depends on the user's locale setting in the client. The example output + presented is using the ``en-GB`` locale. + + .. versionadded:: 2.0 + + Parameters + ---------- + dt: Union[:class:`datetime.datetime`, :class:`datetime.time`] + The datetime to format. + style: :class:`str`R + The style to format the datetime with. + + Returns + ------- + :class:`str` + The formatted string. + """ + if isinstance(dt, datetime.time): + dt = datetime.datetime.combine(datetime.datetime.now(), dt) + if style is None: + return f"" + return f"" + + +MENTION_PATTERN = re.compile(r"@(everyone|here|[!&]?[0-9]{17,20})") + + +def escape_mentions(text: str) -> str: + """A helper function that escapes everyone, here, role, and user mentions. + + .. note:: + + This does not include channel mentions. + + .. note:: + + For more granular control over what mentions should be escaped + within messages, refer to the :class:`~discord.AllowedMentions` + class. + + Parameters + ---------- + text: :class:`str` + The text to escape mentions from. + + Returns + ------- + :class:`str` + The text with the mentions removed. + """ + return MENTION_PATTERN.sub("@\u200b\\1", text) + + +RAW_MENTION_PATTERN = re.compile(r"<@!?([0-9]+)>") + + +def raw_mentions(text: str) -> list[int]: + """Returns a list of user IDs matching ``<@user_id>`` in the string. + + .. versionadded:: 2.2 + + Parameters + ---------- + text: :class:`str` + The string to get user mentions from. + + Returns + ------- + List[:class:`int`] + A list of user IDs found in the string. + """ + return [int(x) for x in RAW_MENTION_PATTERN.findall(text)] + + +RAW_CHANNEL_PATTERN = re.compile(r"<#([0-9]+)>") + + +def raw_channel_mentions(text: str) -> list[int]: + """Returns a list of channel IDs matching ``<@#channel_id>`` in the string. + + .. versionadded:: 2.2 + + Parameters + ---------- + text: :class:`str` + The string to get channel mentions from. + + Returns + ------- + List[:class:`int`] + A list of channel IDs found in the string. + """ + return [int(x) for x in RAW_CHANNEL_PATTERN.findall(text)] + + +RAW_ROLE_PATTERN = re.compile(r"<@&([0-9]+)>") + + +def raw_role_mentions(text: str) -> list[int]: + """Returns a list of role IDs matching ``<@&role_id>`` in the string. + + .. versionadded:: 2.2 + + Parameters + ---------- + text: :class:`str` + The string to get role mentions from. + + Returns + ------- + List[:class:`int`] + A list of role IDs found in the string. + """ + return [int(x) for x in RAW_ROLE_PATTERN.findall(text)] + + +_MARKDOWN_ESCAPE_SUBREGEX = "|".join(r"\{0}(?=([\s\S]*((?(?:>>)?\s|{_MARKDOWN_ESCAPE_LINKS}" + +_MARKDOWN_ESCAPE_REGEX = re.compile( + rf"(?P{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})", + re.MULTILINE | re.X, +) + +_URL_REGEX = r"(?P<[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])" + +_MARKDOWN_STOCK_REGEX = rf"(?P[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})" + + +def remove_markdown(text: str, *, ignore_links: bool = True) -> str: + """A helper function that removes markdown characters. + + .. versionadded:: 1.7 + + .. note:: + This function is not markdown aware and may remove meaning from the original text. For example, + if the input contains ``10 * 5`` then it will be converted into ``10 5``. + + Parameters + ---------- + text: :class:`str` + The text to remove markdown from. + ignore_links: :class:`bool` + Whether to leave links alone when removing markdown. For example, + if a URL in the text contains characters such as ``_`` then it will + be left alone. Defaults to ``True``. + + Returns + ------- + :class:`str` + The text with the markdown special characters removed. + """ + + def replacement(match): + groupdict = match.groupdict() + return groupdict.get("url", "") + + regex = _MARKDOWN_STOCK_REGEX + if ignore_links: + regex = f"(?:{_URL_REGEX}|{regex})" + return re.sub(regex, replacement, text, 0, re.MULTILINE) + + +def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str: + r"""A helper function that escapes Discord's markdown. + + Parameters + ----------- + text: :class:`str` + The text to escape markdown from. + as_needed: :class:`bool` + Whether to escape the markdown characters as needed. This + means that it does not escape extraneous characters if it's + not necessary, e.g. ``**hello**`` is escaped into ``\*\*hello**`` + instead of ``\*\*hello\*\*``. Note however that this can open + you up to some clever syntax abuse. Defaults to ``False``. + ignore_links: :class:`bool` + Whether to leave links alone when escaping markdown. For example, + if a URL in the text contains characters such as ``_`` then it will + be left alone. This option is not supported with ``as_needed``. + Defaults to ``True``. + + Returns + -------- + :class:`str` + The text with the markdown special characters escaped with a slash. + """ + + if not as_needed: + + def replacement(match): + groupdict = match.groupdict() + is_url = groupdict.get("url") + if is_url: + return is_url + return f"\\{groupdict['markdown']}" + + regex = _MARKDOWN_STOCK_REGEX + if ignore_links: + regex = f"(?:{_URL_REGEX}|{regex})" + return re.sub(regex, replacement, text, 0, re.MULTILINE | re.X) + else: + text = re.sub(r"\\", r"\\\\", text) + return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text) + + +def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> T | None: + """A helper to return the first element found in the sequence + that meets the predicate. For example: :: + + member = discord.utils.find(lambda m: m.name == "Mighty", channel.guild.members) + + would find the first :class:`~discord.Member` whose name is 'Mighty' and return it. + If an entry is not found, then ``None`` is returned. + + This is different from :func:`py:filter` due to the fact it stops the moment it finds + a valid entry. + + Parameters + ---------- + predicate + A function that returns a boolean-like result. + seq: :class:`collections.abc.Iterable` + The iterable to search through. + """ + + for element in seq: + if predicate(element): + return element + return None diff --git a/discord/voice_client.py b/discord/voice_client.py index 09f166658b..2a79d11a12 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -56,6 +56,7 @@ from .player import AudioPlayer, AudioSource from .sinks import RawData, RecordingException, Sink from .utils import MISSING +from .utils.private import sane_wait_for if TYPE_CHECKING: from . import abc @@ -325,11 +326,7 @@ async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: ) return - self.endpoint, _, _ = endpoint.rpartition(":") - if self.endpoint.startswith("wss://"): - # Just in case, strip it off since we're going to add it later - self.endpoint = self.endpoint[6:] - + self.endpoint = endpoint.removeprefix("wss://") # This gets set later self.endpoint_ip = MISSING @@ -392,7 +389,7 @@ async def connect(self, *, reconnect: bool, timeout: float) -> None: await self.voice_connect() try: - await utils.sane_wait_for(futures, timeout=timeout) + await sane_wait_for(futures, timeout=timeout) except asyncio.TimeoutError: await self.disconnect(force=True) raise @@ -467,8 +464,8 @@ async def poll_voice_ws(self, reconnect: bool) -> None: # The following close codes are undocumented, so I will document them here. # 1000 - normal closure (obviously) # 4014 - voice channel has been deleted. - # 4015 - voice server has crashed - if exc.code in (1000, 4015): + # 4015 - voice server has crashed, we should resume + if exc.code == 1000: _log.info( "Disconnecting from voice normally, close code %d.", exc.code, @@ -484,6 +481,19 @@ async def poll_voice_ws(self, reconnect: bool) -> None: _log.info("Reconnect was unsuccessful, disconnecting from voice normally...") await self.disconnect() break + if exc.code == 4015: + _log.info("Disconnected from voice, trying to resume...") + + try: + await self.ws.resume() + except asyncio.TimeoutError: + _log.info("Could not resume the voice connection... Disconnection...") + if self._connected.is_set(): + await self.disconnect(force=True) + else: + _log.info("Successfully resumed voice connection") + continue + if not reconnect: await self.disconnect() raise diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index 8256592d77..2f1fd4a780 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -26,6 +26,7 @@ from __future__ import annotations import asyncio +from collections.abc import Sequence import json import logging import re @@ -36,6 +37,9 @@ import aiohttp +from ..components import AnyComponent + +from ..utils.private import bytes_to_base64_data, get_as_snowflake, parse_ratelimit_header, to_json from .. import utils from ..asset import Asset from ..channel import ForumChannel, PartialMessageable @@ -80,7 +84,6 @@ from ..types.message import Message as MessagePayload from ..types.webhook import FollowerWebhook as FollowerWebhookPayload from ..types.webhook import Webhook as WebhookPayload - from ..ui.view import View MISSING = utils.MISSING @@ -134,7 +137,7 @@ async def request( if payload is not None: headers["Content-Type"] = "application/json" - to_send = utils._to_json(payload) + to_send = to_json(payload) if auth_token is not None: headers["Authorization"] = f"Bot {auth_token}" @@ -181,7 +184,7 @@ async def request( remaining = response.headers.get("X-Ratelimit-Remaining") if remaining == "0" and response.status != 429: - delta = utils._parse_ratelimit_header(response) + delta = parse_ratelimit_header(response) _log.debug( ("Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds"), webhook_id, @@ -327,12 +330,20 @@ def execute_webhook( multipart: list[dict[str, Any]] | None = None, files: list[File] | None = None, thread_id: int | None = None, + thread_name: str | None = None, + with_components: bool | None = None, wait: bool = False, ) -> Response[MessagePayload | None]: params = {"wait": int(wait)} if thread_id: params["thread_id"] = thread_id + if thread_name: + payload["thread_name"] = thread_name + + if with_components is not None: + params["with_components"] = int(with_components) + route = Route( "POST", "/webhooks/{webhook_id}/{webhook_token}", @@ -388,12 +399,16 @@ def edit_webhook_message( payload: dict[str, Any] | None = None, multipart: list[dict[str, Any]] | None = None, files: list[File] | None = None, + with_components: bool | None = None, ) -> Response[WebhookMessage]: params = {} if thread_id: params["thread_id"] = thread_id + if with_components is not None: + params["with_components"] = int(with_components) + route = Route( "PATCH", "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}", @@ -508,7 +523,7 @@ def create_interaction_response( ) if attachments: payload["data"]["attachments"] = attachments - form[0]["value"] = utils._to_json(payload) + form[0]["value"] = to_json(payload) route = Route( "POST", @@ -607,7 +622,7 @@ def handle_message_parameters( attachments: list[Attachment] | utils.Undefined = MISSING, embed: Embed | None | utils.Undefined = MISSING, embeds: list[Embed] | utils.Undefined = MISSING, - view: View | None | utils.Undefined = MISSING, + components: Sequence[AnyComponent] | None | utils.Undefined = MISSING, poll: Poll | None | utils.Undefined = MISSING, applied_tags: list[Snowflake] | utils.Undefined = MISSING, allowed_mentions: AllowedMentions | None | utils.Undefined = MISSING, @@ -634,8 +649,18 @@ def handle_message_parameters( if attachments is not MISSING: _attachments = [a.to_dict() for a in attachments] - if view is not MISSING: - payload["components"] = view.to_components() if view is not None else [] + flags = MessageFlags( + suppress_embeds=suppress, + ephemeral=ephemeral, + ) + + if components is not MISSING: + payload["components"] = [] + if components: + for c in components: + payload["components"].append(c.to_dict()) + if c.any_is_v2()(): + flags.is_components_v2 = True if poll is not MISSING: payload["poll"] = poll.to_dict() payload["tts"] = tts @@ -644,11 +669,6 @@ def handle_message_parameters( if username: payload["username"] = username - flags = MessageFlags( - suppress_embeds=suppress, - ephemeral=ephemeral, - ) - if applied_tags is not MISSING: payload["applied_tags"] = applied_tags @@ -700,7 +720,7 @@ def handle_message_parameters( payload["thread_name"] = thread_name if multipart_files: - multipart.append({"name": "payload_json", "value": utils._to_json(payload)}) + multipart.append({"name": "payload_json", "value": to_json(payload)}) payload = None multipart += multipart_files @@ -844,7 +864,7 @@ async def edit( file: File | utils.Undefined = MISSING, files: list[File] | utils.Undefined = MISSING, attachments: list[Attachment] | utils.Undefined = MISSING, - view: View | None | utils.Undefined = MISSING, + components: Sequence[AnyComponent] | None | utils.Undefined = MISSING, allowed_mentions: AllowedMentions | None = None, suppress: bool | None | utils.Undefined = MISSING, ) -> WebhookMessage: @@ -883,11 +903,12 @@ async def edit( allowed_mentions: :class:`AllowedMentions` Controls the mentions being processed in this message. See :meth:`.abc.Messageable.send` for more information. - view: Optional[:class:`~discord.ui.View`] - The updated view to update this message with. If ``None`` is passed then - the view is removed. + components: Optional[Sequence[AnyComponent]] + A sequence of components to edit the message with. + If ``None`` is passed, then the components are cleared. + + .. versionadded:: 3.0 - .. versionadded:: 2.0 suppress: Optional[:class:`bool`] Whether to suppress embeds for the message. @@ -931,7 +952,7 @@ async def edit( file=file, files=files, attachments=attachments, - view=view, + components=components, allowed_mentions=allowed_mentions, thread=thread, suppress=suppress, @@ -1006,8 +1027,8 @@ def __init__( def _update(self, data: WebhookPayload | FollowerWebhookPayload): self.id = int(data["id"]) self.type = try_enum(WebhookType, int(data["type"])) - self.channel_id = utils._get_as_snowflake(data, "channel_id") - self.guild_id = utils._get_as_snowflake(data, "guild_id") + self.channel_id = get_as_snowflake(data, "channel_id") + self.guild_id = get_as_snowflake(data, "guild_id") self.name = data.get("name") self._avatar = data.get("avatar") self.token = data.get("token") @@ -1493,7 +1514,7 @@ async def edit( payload["name"] = str(name) if name is not None else None if avatar is not MISSING: - payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None + payload["avatar"] = bytes_to_base64_data(avatar) if avatar is not None else None adapter = async_context.get() @@ -1568,7 +1589,7 @@ async def send( embed: Embed | utils.Undefined = MISSING, embeds: list[Embed] | utils.Undefined = MISSING, allowed_mentions: AllowedMentions | utils.Undefined = MISSING, - view: View | utils.Undefined = MISSING, + components: Sequence[AnyComponent] | None | utils.Undefined = MISSING, poll: Poll | utils.Undefined = MISSING, thread: Snowflake | utils.Undefined = MISSING, thread_name: str | None = None, @@ -1591,7 +1612,7 @@ async def send( embed: Embed | utils.Undefined = MISSING, embeds: list[Embed] | utils.Undefined = MISSING, allowed_mentions: AllowedMentions | utils.Undefined = MISSING, - view: View | utils.Undefined = MISSING, + components: Sequence[AnyComponent] | None | utils.Undefined = MISSING, poll: Poll | utils.Undefined = MISSING, thread: Snowflake | utils.Undefined = MISSING, thread_name: str | None | utils.Undefined = None, @@ -1613,7 +1634,7 @@ async def send( embed: Embed | utils.Undefined = MISSING, embeds: list[Embed] | utils.Undefined = MISSING, allowed_mentions: AllowedMentions | utils.Undefined = MISSING, - view: View | utils.Undefined = MISSING, + components: Sequence[AnyComponent] | None | utils.Undefined = MISSING, poll: Poll | utils.Undefined = MISSING, thread: Snowflake | utils.Undefined = MISSING, thread_name: str | None = None, @@ -1655,8 +1676,6 @@ async def send( ephemeral: :class:`bool` Indicates if the message should only be visible to the user. This is only available to :attr:`WebhookType.application` webhooks. - If a view is sent with an ephemeral message, and it has no timeout set - then the timeout is set to 15 minutes. .. versionadded:: 2.0 file: :class:`File` @@ -1674,13 +1693,10 @@ async def send( Controls the mentions being processed in this message. .. versionadded:: 1.4 - view: :class:`discord.ui.View` - The view to send with the message. You can only send a view - if this webhook is not partial and has state attached. A - webhook has state attached if the webhook is managed by the - library. + components: Optional[Sequence[AnyComponent]] + A sequence of components to send with the message. - .. versionadded:: 2.0 + .. versionadded:: 3.0 thread: :class:`~discord.abc.Snowflake` The thread to send this webhook to. @@ -1721,8 +1737,8 @@ async def send( InvalidArgument Either there was no token associated with this webhook, ``ephemeral`` was passed with the improper webhook type, there was no state attached with this webhook when - giving it a view, you specified both ``thread_name`` and ``thread``, or ``applied_tags`` - was passed with neither ``thread_name`` nor ``thread`` specified. + giving it dispatchable components, you specified both ``thread_name`` and ``thread``, + or ``applied_tags`` was passed with neither ``thread_name`` nor ``thread`` specified. """ if self.token is None: @@ -1745,11 +1761,13 @@ async def send( if application_webhook: wait = True - if view is not MISSING: - if isinstance(self._state, _WebhookState): - raise InvalidArgument("Webhook views require an associated state with the webhook") - if ephemeral is True and view.timeout is None: - view.timeout = 15 * 60.0 + with_components = False + + if components is not MISSING: + if isinstance(self._state, _WebhookState) and components and any(c.any_is_dispatchable() for c in components): + raise InvalidArgument("Dispatchable Webhook components require an associated state with the webhook") + if not application_webhook: + with_components = True if poll is None: poll = MISSING @@ -1764,7 +1782,7 @@ async def send( embed=embed, embeds=embeds, ephemeral=ephemeral, - view=view, + components=components, poll=poll, applied_tags=applied_tags, allowed_mentions=allowed_mentions, @@ -1787,16 +1805,13 @@ async def send( files=params.files, thread_id=thread_id, wait=wait, + with_components=with_components, ) msg = None if wait: msg = self._create_message(data) - if view is not MISSING and not view.is_finished(): - message_id = None if msg is None else msg.id - view.message = None if msg is None else msg - self._state.store_view(view, message_id) if delete_after is not None: @@ -1868,7 +1883,7 @@ async def edit_message( file: File | utils.Undefined = MISSING, files: list[File] | utils.Undefined = MISSING, attachments: list[Attachment] | utils.Undefined = MISSING, - view: View | None | utils.Undefined = MISSING, + components: Sequence[AnyComponent] | None | utils.Undefined = MISSING, allowed_mentions: AllowedMentions | None = None, thread: Snowflake | None | utils.Undefined = MISSING, suppress: bool = False, @@ -1911,12 +1926,9 @@ async def edit_message( allowed_mentions: :class:`AllowedMentions` Controls the mentions being processed in this message. See :meth:`.abc.Messageable.send` for more information. - view: Optional[:class:`~discord.ui.View`] - The updated view to update this message with. If ``None`` is passed then - the view is removed. The webhook must have state attached, similar to - :meth:`send`. - - .. versionadded:: 2.0 + components: Optional[Sequence[AnyComponent]] + # TODO: docstring + .. versionadded:: 3.0 thread: Optional[:class:`~discord.abc.Snowflake`] The thread that contains the message. suppress: :class:`bool` @@ -1945,11 +1957,14 @@ async def edit_message( if self.token is None: raise InvalidArgument("This webhook does not have a token associated with it") - if view is not MISSING: - if isinstance(self._state, _WebhookState): - raise InvalidArgument("This webhook does not have state associated with it") + with_components = False + + if components is not MISSING: + if isinstance(self._state, _WebhookState) and components and any(c.any_is_dispatchable() for c in components): + raise InvalidArgument("Dispatchable Webhook components require an associated state with the webhook") - self._state.prevent_view_updates_for(message_id) + if self.type is not WebhookType.application: + with_components = True previous_mentions: AllowedMentions | None = getattr(self._state, "allowed_mentions", None) params = handle_message_parameters( @@ -1959,7 +1974,7 @@ async def edit_message( attachments=attachments, embed=embed, embeds=embeds, - view=view, + components=components, allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, suppress=suppress, @@ -1981,12 +1996,10 @@ async def edit_message( payload=params.payload, multipart=params.multipart, files=params.files, + with_components=with_components, ) message = self._create_message(data) - if view and not view.is_finished(): - view.message = message - self._state.store_view(view, message_id) return message async def delete_message(self, message_id: int, *, thread_id: int | None = None) -> None: diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index 6808f745ac..57ee45657c 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -41,6 +41,7 @@ from urllib.parse import quote as urlquote from .. import utils +from ..utils.private import parse_ratelimit_header, bytes_to_base64_data, to_json from ..channel import PartialMessageable from ..errors import ( DiscordServerError, @@ -124,7 +125,7 @@ def request( if payload is not None: headers["Content-Type"] = "application/json" - to_send = utils._to_json(payload) + to_send = to_json(payload) if auth_token is not None: headers["Authorization"] = f"Bot {auth_token}" @@ -183,7 +184,7 @@ def request( remaining = response.headers.get("X-Ratelimit-Remaining") if remaining == "0" and response.status_code != 429: - delta = utils._parse_ratelimit_header(response) + delta = parse_ratelimit_header(response) _log.debug( ("Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds"), webhook_id, @@ -846,7 +847,7 @@ def edit( payload["name"] = str(name) if name is not None else None if avatar is not MISSING: - payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None + payload["avatar"] = bytes_to_base64_data(avatar) if avatar is not None else None adapter: WebhookAdapter = _get_webhook_adapter() diff --git a/discord/welcome_screen.py b/discord/welcome_screen.py index f7d630d4aa..eda968becc 100644 --- a/discord/welcome_screen.py +++ b/discord/welcome_screen.py @@ -28,7 +28,8 @@ from typing import TYPE_CHECKING, overload from .partial_emoji import _EmojiTag -from .utils import _get_as_snowflake, get +from . import utils +from .utils.private import get_as_snowflake if TYPE_CHECKING: from .abc import Snowflake @@ -96,13 +97,13 @@ def to_dict(self) -> WelcomeScreenChannelPayload: @classmethod def _from_dict(cls, data: WelcomeScreenChannelPayload, guild: Guild) -> WelcomeScreenChannel: - channel_id = _get_as_snowflake(data, "channel_id") + channel_id = get_as_snowflake(data, "channel_id") channel = guild.get_channel(channel_id) description = data.get("description") - _emoji_id = _get_as_snowflake(data, "emoji_id") + _emoji_id = get_as_snowflake(data, "emoji_id") _emoji_name = data.get("emoji_name") - emoji = get(guild.emojis, id=_emoji_id) if _emoji_id else _emoji_name + emoji = utils.find(lambda e: e.id == _emoji_id, guild.emojis) if _emoji_id else _emoji_name return cls(channel=channel, description=description, emoji=emoji) # type: ignore @@ -192,7 +193,7 @@ async def edit(self, **options): rules_channel = guild.get_channel(12345678) announcements_channel = guild.get_channel(87654321) - custom_emoji = utils.get(guild.emojis, name="loudspeaker") + custom_emoji = utils.find(lambda e: e.name == "loudspeaker", guild.emojis) await welcome_screen.edit( description="This is a very cool community server!", welcome_channels=[ diff --git a/discord/widget.py b/discord/widget.py index 004b87d66e..eb54e606b3 100644 --- a/discord/widget.py +++ b/discord/widget.py @@ -31,7 +31,8 @@ from .enums import Status, try_enum from .invite import Invite from .user import BaseUser -from .utils import _get_as_snowflake, resolve_invite, snowflake_time +from .utils import snowflake_time +from .utils.private import resolve_invite, get_as_snowflake if TYPE_CHECKING: import datetime @@ -265,7 +266,7 @@ def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None: self.members: list[WidgetMember] = [] channels = {channel.id: channel for channel in self.channels} for member in data.get("members", []): - connected_channel = _get_as_snowflake(member, "channel_id") + connected_channel = get_as_snowflake(member, "channel_id") if connected_channel in channels: connected_channel = channels[connected_channel] # type: ignore elif connected_channel: diff --git a/docs/_static/css/custom.css b/docs/_static/css/custom.css index 57ef457b9e..cd2b41d9a2 100644 --- a/docs/_static/css/custom.css +++ b/docs/_static/css/custom.css @@ -8,9 +8,8 @@ font-display: swap; src: url(https://fonts.gstatic.com/s/saira/v8/memWYa2wxmKQyPMrZX79wwYZQMhsyuShhKMjjbU9uXuA773FksAxljYm.woff2) format("woff2"); - unicode-range: - U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, - U+1EA0-1EF9, U+20AB; + unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, + U+01AF-01B0, U+1EA0-1EF9, U+20AB; } /* latin-ext */ @font-face { @@ -21,9 +20,8 @@ font-display: swap; src: url(https://fonts.gstatic.com/s/saira/v8/memWYa2wxmKQyPMrZX79wwYZQMhsyuShhKMjjbU9uXuA773FksExljYm.woff2) format("woff2"); - unicode-range: - U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, - U+2C60-2C7F, U+A720-A7FF; + unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, + U+2113, U+2C60-2C7F, U+A720-A7FF; } /* latin */ @font-face { @@ -34,9 +32,8 @@ font-display: swap; src: url(https://fonts.gstatic.com/s/saira/v8/memWYa2wxmKQyPMrZX79wwYZQMhsyuShhKMjjbU9uXuA773Fks8xlg.woff2) format("woff2"); - unicode-range: - U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, - U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, + U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; } /* latin */ @font-face { @@ -46,9 +43,8 @@ font-display: swap; src: url(https://fonts.gstatic.com/s/outfit/v4/QGYyz_MVcBeNP4NjuGObqx1XmO1I4W61O4a0Ew.woff2) format("woff2"); - unicode-range: - U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, - U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, + U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; } /* latin */ @font-face { @@ -58,9 +54,8 @@ font-display: swap; src: url(https://fonts.gstatic.com/s/outfit/v4/QGYyz_MVcBeNP4NjuGObqx1XmO1I4deyO4a0Ew.woff2) format("woff2"); - unicode-range: - U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, - U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, + U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; } /* attribute tables */ diff --git a/docs/api/async_iter.rst b/docs/api/async_iter.rst index 22967bab28..a958b94cb3 100644 --- a/docs/api/async_iter.rst +++ b/docs/api/async_iter.rst @@ -33,17 +33,6 @@ Certain utilities make working with async iterators easier, detailed below. Advances the iterator by one, if possible. If no more items are found then this raises :exc:`NoMoreItems`. - .. method:: get(**attrs) - :async: - - |coro| - - Similar to :func:`utils.get` except run over the async iterator. - - Getting the last message by a user named 'Dave' or ``None``: :: - - msg = await channel.history().get(author__name='Dave') - .. method:: find(predicate) :async: diff --git a/docs/api/data_classes.rst b/docs/api/data_classes.rst index 1d891b90cd..2e2e814289 100644 --- a/docs/api/data_classes.rst +++ b/docs/api/data_classes.rst @@ -26,6 +26,16 @@ dynamic attributes in mind. .. autoclass:: SelectOption :members: +.. attributetable:: MediaGalleryItem + +.. autoclass:: SelectOption + :members: + +.. attributetable:: UnfurledMediaItem + +.. autoclass:: UnfurledMediaItem + :members: + .. attributetable:: Intents .. autoclass:: Intents diff --git a/docs/api/enums.rst b/docs/api/enums.rst index 4d278e3758..1210990717 100644 --- a/docs/api/enums.rst +++ b/docs/api/enums.rst @@ -2519,3 +2519,18 @@ of :class:`enum.Enum`. .. attribute:: inactive The subscription is inactive and the subscription owner is not being charged. + + +.. class:: SeparatorSpacingSize + + Represents the padding size around a separator component. + + .. versionadded:: 2.7 + + .. attribute:: small + + The separator uses small padding. + + .. attribute:: large + + The separator uses large padding. diff --git a/docs/api/models.rst b/docs/api/models.rst index cb702b2c38..ed8452ec1c 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -368,6 +368,9 @@ Interactions .. autoclass:: AuthorizingIntegrationOwners() :members: +Message Components +------------ + .. attributetable:: Component .. autoclass:: Component() @@ -390,6 +393,48 @@ Interactions :members: :inherited-members: +.. attributetable:: Section + +.. autoclass:: Section() + :members: + :inherited-members: + +.. attributetable:: TextDisplay + +.. autoclass:: TextDisplay() + :members: + :inherited-members: + +.. attributetable:: Thumbnail + +.. autoclass:: Thumbnail() + :members: + :inherited-members: + +.. attributetable:: MediaGallery + +.. autoclass:: MediaGallery() + :members: + :inherited-members: + +.. attributetable:: FileComponent + +.. autoclass:: FileComponent() + :members: + :inherited-members: + +.. attributetable:: Separator + +.. autoclass:: Separator() + :members: + :inherited-members: + +.. attributetable:: Container + +.. autoclass:: Container() + :members: + :inherited-members: + Emoji ----- diff --git a/docs/api/ui_kit.rst b/docs/api/ui_kit.rst index 18bb9de89b..ad2769eb03 100644 --- a/docs/api/ui_kit.rst +++ b/docs/api/ui_kit.rst @@ -55,6 +55,48 @@ Objects :members: :inherited-members: +.. attributetable:: discord.ui.Section + +.. autoclass:: discord.ui.Section + :members: + :inherited-members: + +.. attributetable:: discord.ui.TextDisplay + +.. autoclass:: discord.ui.TextDisplay + :members: + :inherited-members: + +.. attributetable:: discord.ui.Thumbnail + +.. autoclass:: discord.ui.Thumbnail + :members: + :inherited-members: + +.. attributetable:: discord.ui.MediaGallery + +.. autoclass:: discord.ui.MediaGallery + :members: + :inherited-members: + +.. attributetable:: discord.ui.File + +.. autoclass:: discord.ui.File + :members: + :inherited-members: + +.. attributetable:: discord.ui.Separator + +.. autoclass:: discord.ui.Separator + :members: + :inherited-members: + +.. attributetable:: discord.ui.Container + +.. autoclass:: discord.ui.Container + :members: + :inherited-members: + .. attributetable:: discord.ui.Modal .. autoclass:: discord.ui.Modal diff --git a/docs/api/utils.rst b/docs/api/utils.rst index e426c2f2d0..db6930cf0d 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -7,7 +7,6 @@ Utility Functions .. autofunction:: discord.utils.find -.. autofunction:: discord.utils.get .. autofunction:: discord.utils.get_or_fetch @@ -25,30 +24,12 @@ Utility Functions .. autofunction:: discord.utils.raw_role_mentions -.. autofunction:: discord.utils.resolve_invite - -.. autofunction:: discord.utils.resolve_template - -.. autofunction:: discord.utils.sleep_until - .. autofunction:: discord.utils.utcnow .. autofunction:: discord.utils.snowflake_time -.. autofunction:: discord.utils.parse_time - .. autofunction:: discord.utils.format_dt -.. autofunction:: discord.utils.time_snowflake - .. autofunction:: discord.utils.generate_snowflake .. autofunction:: discord.utils.basic_autocomplete - -.. autofunction:: discord.utils.as_chunks - -.. autofunction:: discord.utils.filter_params - -.. autofunction:: discord.utils.warn_deprecated - -.. autofunction:: discord.utils.deprecated diff --git a/docs/faq.rst b/docs/faq.rst index bf8bfe82d0..c1ba99f9ce 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -214,7 +214,7 @@ Quick example: :: await message.add_reaction(emoji) # no ID, do a lookup - emoji = discord.utils.get(guild.emojis, name='LUL') + emoji = discord.utils.find(lambda e: e.name == "LUL", guild.emojis) if emoji: await message.add_reaction(emoji) @@ -279,13 +279,13 @@ The following use an HTTP request: - :meth:`Guild.fetch_member` -If the functions above do not help you, then use of :func:`utils.find` or :func:`utils.get` would serve some use in finding +If the functions above do not help you, then use of :func:`utils.find` would serve some use in finding specific models. Quick example: :: # find a guild by name - guild = discord.utils.get(client.guilds, name='My Server') + guild = discord.utils.find(lambda g: g.name == "My Server", client.guilds) # make sure to check if it's found if guild is not None: diff --git a/docs/locales/en/LC_MESSAGES/api/utils.po b/docs/locales/en/LC_MESSAGES/api/utils.po index 54425aad81..c88c180521 100644 --- a/docs/locales/en/LC_MESSAGES/api/utils.po +++ b/docs/locales/en/LC_MESSAGES/api/utils.po @@ -613,7 +613,7 @@ msgstr "" msgid "The formatted string." msgstr "" -#: 3dbe398f94684d7d81111c3a9d78ddee discord.utils.time_snowflake:1 of +#: 3dbe398f94684d7d81111c3a9d78ddee discord.utils.:1 of msgid "Returns a numeric snowflake pretending to be created at the given date." msgstr "" diff --git a/examples/views/new_components.py b/examples/views/new_components.py new file mode 100644 index 0000000000..5f323c7bca --- /dev/null +++ b/examples/views/new_components.py @@ -0,0 +1,81 @@ +from io import BytesIO + +from discord import ( + ApplicationContext, + Bot, + ButtonStyle, + Color, + File, + Interaction, + SeparatorSpacingSize, + User, +) +from discord.ui import ( + Button, + Container, + MediaGallery, + Section, + Select, + Separator, + TextDisplay, + Thumbnail, + View, + button, +) + + +class MyView(View): + def __init__(self, user: User): + super().__init__(timeout=30) + text1 = TextDisplay("### This is a sample `TextDisplay` in a `Section`.") + text2 = TextDisplay( + "This section is contained in a `Container`.\nTo the right, you can see a `Thumbnail`." + ) + thumbnail = Thumbnail(user.display_avatar.url) + + section = Section(text1, text2, accessory=thumbnail) + section.add_text("-# Small text") + + container = Container( + section, + TextDisplay("Another `TextDisplay` separate from the `Section`."), + color=Color.blue(), + ) + container.add_separator(divider=True, spacing=SeparatorSpacingSize.large) + container.add_item(Separator()) + container.add_file("attachment://sample.png") + container.add_text("Above is two `Separator`s followed by a `File`.") + + gallery = MediaGallery() + gallery.add_item(user.default_avatar.url) + gallery.add_item(user.avatar.url) + + self.add_item(container) + self.add_item(gallery) + self.add_item( + TextDisplay("Above is a `MediaGallery` containing two `MediaGalleryItem`s.") + ) + + @button(label="Delete Message", style=ButtonStyle.red, id=200) + async def delete_button(self, button: Button, interaction: Interaction): + await interaction.response.defer(invisible=True) + await interaction.message.delete() + + async def on_timeout(self): + self.get_item(200).disabled = True + await self.message.edit(view=self) + + +bot = Bot() + + +@bot.command() +async def show_view(ctx: ApplicationContext): + """Display a sample View showcasing various new components.""" + + f = await ctx.author.display_avatar.read() + file = File(BytesIO(f), filename="sample.png") + await ctx.respond(view=MyView(ctx.author), files=[file]) + + +bot.run("TOKEN") diff --git a/pyproject.toml b/pyproject.toml index 5193a6a9d5..c45ea8f86d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,13 @@ authors = [ {name = "Pycord Development"} ] description = "A Python wrapper for the Discord API" -readme = "README.rst" +readme = {content-type = "text/x-rst", file = "README.rst"} requires-python = ">=3.10" -license = {text = "MIT"} +license = "MIT" +license-files = ["LICENSE"] classifiers = [ "Development Status :: 5 - Production/Stable", - "License :: OSI Approved :: MIT License", +# "License-Expression :: MIT", "Intended Audience :: Developers", "Natural Language :: English", "Operating System :: OS Independent", @@ -31,7 +32,7 @@ dynamic = ["version"] dependencies = [ "aiohttp>=3.6.0,<4.0", "colorlog~=6.9.0", - "typing-extensions>=4,<5", + "typing-extensions>=4.5.0,<5", ] [project.urls] diff --git a/tests/test_utils.py b/tests/test_utils.py index d0f94acb93..61931451a0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -30,16 +30,9 @@ from discord.utils import ( MISSING, - _cached_property, - _parse_ratelimit_header, - _unique, - async_all, - copy_doc, find, - get, - maybe_coroutine, + generate_snowflake, snowflake_time, - time_snowflake, utcnow, ) @@ -79,23 +72,6 @@ def test_temporary(): # assert not MISSING # assert repr(MISSING) == '...' # -# -# def test_cached_property() -> None: -# class Test: -# def __init__(self, x: int): -# self.x = x -# -# @_cached_property -# def foo(self) -> int: -# self.x += 1 -# return self.x -# -# t = Test(0) -# assert isinstance(_cached_property.__get__(_cached_property(None), None, None), _cached_property) -# assert t.foo == 1 -# assert t.foo == 1 -# -# # def test_find_get() -> None: # class Obj: # def __init__(self, value: int): diff --git a/thing.py b/thing.py new file mode 100644 index 0000000000..fec2189e65 --- /dev/null +++ b/thing.py @@ -0,0 +1,16 @@ +import discord +import logging +import os + +from dotenv import load_dotenv + +load_dotenv() + +logging.basicConfig(level=logging.INFO) + +bot = discord.Bot(intents=discord.Intents.default()) + + +@bot.command() +async def ping(ctx: discord.ApplicationContex) -> None: + await ctx.respond("slurp") diff --git a/uv.lock b/uv.lock index de60d64387..a5bf27e7c2 100644 --- a/uv.lock +++ b/uv.lock @@ -1336,7 +1336,7 @@ requires-dist = [ { name = "colorlog", specifier = "~=6.9.0" }, { name = "msgspec", marker = "extra == 'speed'", specifier = "~=0.19.0" }, { name = "pynacl", marker = "extra == 'voice'", specifier = ">=1.3.0,<1.6" }, - { name = "typing-extensions", specifier = ">=4,<5" }, + { name = "typing-extensions", specifier = ">=4.5.0,<5" }, ] provides-extras = ["speed", "voice"]