diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e43197fe9e..fa31ed1af3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: end-of-file-fixer exclude: \.(po|pot|yml|yaml)$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.9 + rev: v0.12.0 hooks: - id: ruff args: [ --fix ] diff --git a/CHANGELOG-V3.md b/CHANGELOG-V3.md index eb94407d2e..525f1c475a 100644 --- a/CHANGELOG-V3.md +++ b/CHANGELOG-V3.md @@ -12,3 +12,17 @@ release. ### Deprecated ### Removed + +- `utils.filter_params` +- `utils.sleep_until` use `asyncio.sleep` combined with `datetime.datetime` instead +- `utils.compute_timedelta` use the `datetime` module instead +- `utils.resolve_invite` +- `utils.resolve_template` +- `utils.parse_time` use `datetime.datetime.fromisoformat` instead +- `utils.time_snowflake` use `utils.generate_snowflake` instead +- `utils.warn_deprecated` +- `utils.deprecated` +- `utils.get` use `utils.find` with `lambda i: i.attr == val` instead +- `AsyncIterator.get` use `AsyncIterator.find` with `lambda i: i.attr == val` instead +- `utils.as_chunks` use `itertools.batched` on Python 3.12+ or your own implementation + instead diff --git a/discord/__version.py b/discord/__version.py index 88b3df1b5c..f7020d1ad8 100644 --- a/discord/__version.py +++ b/discord/__version.py @@ -35,7 +35,7 @@ from typing import Literal, NamedTuple -from .utils import deprecated +from .utils.private import deprecated from ._version import __version__, __version_tuple__ diff --git a/discord/abc.py b/discord/abc.py index c59a6276be..1eeb00d7d3 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -41,6 +41,7 @@ runtime_checkable, ) +from .utils.private import warn_deprecated from . import utils from .context_managers import Typing from .enums import ChannelType @@ -724,7 +725,7 @@ def permissions_for(self, obj: Member | Role, /) -> Permissions: if obj.is_default(): return base - overwrite = utils.get(self._overwrites, type=_Overwrites.ROLE, id=obj.id) + overwrite = utils.find(lambda o: o.type == _Overwrites.ROLE and o.id == obj.id, self._overwrites) if overwrite is not None: base.handle_overwrite(overwrite.allow, overwrite.deny) @@ -1529,7 +1530,7 @@ async def send( from .message import MessageReference # noqa: PLC0415 if not isinstance(reference, MessageReference): - utils.warn_deprecated( + warn_deprecated( f"Passing {type(reference).__name__} to reference", "MessageReference", "2.7", diff --git a/discord/activity.py b/discord/activity.py index 9cda6bc3f2..3377892aa8 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -32,7 +32,7 @@ from .colour import Colour from .enums import ActivityType, try_enum from .partial_emoji import PartialEmoji -from .utils import _get_as_snowflake +from .utils.private import get_as_snowflake __all__ = ( "BaseActivity", @@ -226,7 +226,7 @@ def __init__(self, **kwargs): self.timestamps: ActivityTimestamps = kwargs.pop("timestamps", {}) self.assets: ActivityAssets = kwargs.pop("assets", {}) self.party: ActivityParty = kwargs.pop("party", {}) - self.application_id: int | None = _get_as_snowflake(kwargs, "application_id") + self.application_id: int | None = get_as_snowflake(kwargs, "application_id") self.url: str | None = kwargs.pop("url", None) self.flags: int = kwargs.pop("flags", 0) self.sync_id: str | None = kwargs.pop("sync_id", None) diff --git a/discord/appinfo.py b/discord/appinfo.py index bfe7bc77c5..ab697c25a3 100644 --- a/discord/appinfo.py +++ b/discord/appinfo.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING +from .utils.private import warn_deprecated, get_as_snowflake from . import utils from .asset import Asset from .permissions import Permissions @@ -200,9 +201,9 @@ def __init__(self, state: ConnectionState, data: AppInfoPayload): self._summary: str = data["summary"] self.verify_key: str = data["verify_key"] - self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") + self.guild_id: int | None = get_as_snowflake(data, "guild_id") - self.primary_sku_id: int | None = utils._get_as_snowflake(data, "primary_sku_id") + self.primary_sku_id: int | None = get_as_snowflake(data, "primary_sku_id") self.slug: str | None = data.get("slug") self._cover_image: str | None = data.get("cover_image") self.terms_of_service_url: str | None = data.get("terms_of_service_url") @@ -261,7 +262,7 @@ def summary(self) -> str | None: .. versionadded:: 1.3 .. deprecated:: 2.7 """ - utils.warn_deprecated( + warn_deprecated( "summary", "description", reference="https://discord.com/developers/docs/resources/application#application-object-application-structure", diff --git a/discord/asset.py b/discord/asset.py index 3bc04faa7c..931e7bbcaf 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -47,6 +47,11 @@ MISSING = utils.MISSING +def _valid_icon_size(size: int) -> bool: + """Icons must be power of 2 within [16, 4096].""" + return not size & (size - 1) and 4096 >= size >= 16 + + class AssetMixin: url: str _state: Any | None @@ -371,7 +376,7 @@ def replace( url = url.with_path(f"{path}.{static_format}") if size is not MISSING: - if not utils.valid_icon_size(size): + if not _valid_icon_size(size): raise InvalidArgument("size must be a power of 2 between 16 and 4096") url = url.with_query(size=size) else: @@ -398,7 +403,7 @@ def with_size(self, size: int, /) -> Asset: InvalidArgument The asset had an invalid size. """ - if not utils.valid_icon_size(size): + if not _valid_icon_size(size): raise InvalidArgument("size must be a power of 2 between 16 and 4096") url = str(yarl.URL(self._url).with_query(size=size)) diff --git a/discord/audit_logs.py b/discord/audit_logs.py index 65cbbfe6ec..068323f770 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -26,7 +26,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generator, TypeVar +from functools import cached_property +from .utils.private import get_as_snowflake from . import enums, utils from .asset import Asset from .automod import AutoModAction, AutoModTriggerMetadata @@ -559,8 +561,8 @@ def _from_data(self, data: AuditLogEntryPayload) -> None: # into meaningful data when requested self._changes = data.get("changes", []) - self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore - self._target_id = utils._get_as_snowflake(data, "target_id") + self.user = self._get_member(get_as_snowflake(data, "user_id")) # type: ignore + self._target_id = get_as_snowflake(data, "target_id") def _get_member(self, user_id: int) -> Member | User | None: return self.guild.get_member(user_id) or self._users.get(user_id) @@ -568,12 +570,12 @@ def _get_member(self, user_id: int) -> Member | User | None: def __repr__(self) -> str: return f"" - @utils.cached_property + @cached_property def created_at(self) -> datetime.datetime: """Returns the entry's creation time in UTC.""" return utils.snowflake_time(self.id) - @utils.cached_property + @cached_property def target( self, ) -> ( @@ -597,24 +599,24 @@ def target( else: return converter(self._target_id) - @utils.cached_property + @property def category(self) -> enums.AuditLogActionCategory: """The category of the action, if applicable.""" return self.action.category - @utils.cached_property + @cached_property def changes(self) -> AuditLogChanges: """The list of changes this entry has.""" obj = AuditLogChanges(self, self._changes, state=self._state) del self._changes return obj - @utils.cached_property + @property def before(self) -> AuditLogDiff: """The target's prior state.""" return self.changes.before - @utils.cached_property + @property def after(self) -> AuditLogDiff: """The target's subsequent state.""" return self.changes.after diff --git a/discord/bot.py b/discord/bot.py index f22362350a..56ea3bb66c 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -63,7 +63,8 @@ from .shard import AutoShardedClient from .types import interactions from .user import User -from .utils import MISSING, async_all, find, get +from .utils import MISSING, find +from .utils.private import async_all if TYPE_CHECKING: from .member import Member @@ -216,13 +217,13 @@ def get_application_command( return command elif (names := name.split())[0] == command.name and isinstance(command, SlashCommandGroup): while len(names) > 1: - command = get(commands, name=names.pop(0)) + command = find(lambda c: c.name == names.pop(0), commands) if not isinstance(command, SlashCommandGroup) or ( guild_ids is not None and command.guild_ids != guild_ids ): return commands = command.subcommands - command = get(commands, name=names.pop()) + command = find(lambda c: c.name == names.pop(), commands) if not isinstance(command, type) or (guild_ids is not None and command.guild_ids != guild_ids): return return command @@ -357,7 +358,7 @@ def _check_command(cmd: ApplicationCommand, match: Mapping[str, Any]) -> bool: # Now let's see if there are any commands on discord that we need to delete for cmd, value_ in registered_commands_dict.items(): - match = get(pending, name=value_["name"]) + match = find(lambda c: c.name == value_["name"], pending) if match is None: # We have this command registered but not in our list return_value.append( @@ -516,7 +517,7 @@ def register( ) continue # We can assume the command item is a command, since it's only a string if action is delete - match = get(pending, name=cmd["command"].name, type=cmd["command"].type) + match = find(lambda c: c.name == cmd["command"].name and c.type == cmd["command"].type, pending) if match is None: continue if cmd["action"] == "edit": @@ -605,10 +606,9 @@ def register( registered = await register("bulk", data, guild_id=guild_id) for i in registered: - cmd = get( + cmd = find( + lambda c: c.name == i["name"] and c.type == i.get("type"), self.pending_application_commands, - name=i["name"], - type=i.get("type"), ) if not cmd: raise ValueError(f"Registered command {i['name']}, type {i.get('type')} not found in pending commands") @@ -712,11 +712,9 @@ async def on_connect(): registered_guild_commands[guild_id] = app_cmds for i in registered_commands: - cmd = get( + cmd = find( + lambda c: c.name == i["name"] and c.guild_ids is None and c.type == i.get("type"), self.pending_application_commands, - name=i["name"], - guild_ids=None, - type=i.get("type"), ) if cmd: cmd.id = i["id"] diff --git a/discord/channel.py b/discord/channel.py index cff03c1614..7e4be01969 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -30,6 +30,7 @@ import discord.abc +from .utils.private import bytes_to_base64_data, get_as_snowflake, copy_doc from . import utils from .asset import Asset from .emoji import GuildEmoji @@ -157,7 +158,7 @@ def from_data(cls, *, state: ConnectionState, data: ForumTagPayload) -> ForumTag self.moderated = data.get("moderated", False) emoji_name = data["emoji_name"] or "" - emoji_id = utils._get_as_snowflake(data, "emoji_id") or None + emoji_id = get_as_snowflake(data, "emoji_id") or None self.emoji = PartialEmoji.with_state(state=state, name=emoji_name, id=emoji_id) return self @@ -219,7 +220,7 @@ def _update(self, guild: Guild, data: TextChannelPayload | ForumChannelPayload) # This data will always exist self.guild: Guild = guild self.name: str = data["name"] - self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") + self.category_id: int | None = get_as_snowflake(data, "parent_id") self._type: int = data["type"] # This data may be missing depending on how this object is being created/updated if not data.pop("_invoke_flag", False): @@ -230,7 +231,7 @@ def _update(self, guild: Guild, data: TextChannelPayload | ForumChannelPayload) self.slowmode_delay: int = data.get("rate_limit_per_user", 0) self.default_auto_archive_duration: ThreadArchiveDuration = data.get("default_auto_archive_duration", 1440) self.default_thread_slowmode_delay: int | None = data.get("default_thread_rate_limit_per_user") - self.last_message_id: int | None = utils._get_as_snowflake(data, "last_message_id") + self.last_message_id: int | None = get_as_snowflake(data, "last_message_id") self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) self._fill_overwrites(data) @@ -243,7 +244,7 @@ def type(self) -> ChannelType: def _sorting_bucket(self) -> int: return ChannelType.text.value - @utils.copy_doc(discord.abc.GuildChannel.permissions_for) + @copy_doc(discord.abc.GuildChannel.permissions_for) def permissions_for(self, obj: Member | Role, /) -> Permissions: base = super().permissions_for(obj) @@ -294,7 +295,7 @@ async def edit(self, **options) -> _TextChannel: """Edits the channel.""" raise NotImplementedError - @utils.copy_doc(discord.abc.GuildChannel.clone) + @copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: str | None = None, reason: str | None = None) -> TextChannel: return await self._clone_impl( { @@ -498,7 +499,7 @@ async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason from .webhook import Webhook # noqa: PLC0415 if avatar is not None: - avatar = utils._bytes_to_base64_data(avatar) # type: ignore + avatar = bytes_to_base64_data(avatar) # type: ignore data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) @@ -1001,9 +1002,7 @@ def _update(self, guild: Guild, data: ForumChannelPayload) -> None: if emoji_name is not None: self.default_reaction_emoji = reaction_emoji_ctx["emoji_name"] else: - self.default_reaction_emoji = self._state.get_emoji( - utils._get_as_snowflake(reaction_emoji_ctx, "emoji_id") - ) + self.default_reaction_emoji = self._state.get_emoji(get_as_snowflake(reaction_emoji_ctx, "emoji_id")) @property def guidelines(self) -> str | None: @@ -1026,7 +1025,7 @@ def get_tag(self, id: int, /) -> ForumTag | None: .. versionadded:: 2.3 """ - return utils.get(self.available_tags, id=id) + return utils.find(lambda t: t.id == id, self.available_tags) @overload async def edit( @@ -1529,14 +1528,14 @@ def _update(self, guild: Guild, data: VoiceChannelPayload | StageChannelPayload) # This data will always exist self.guild = guild self.name: str = data["name"] - self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") + self.category_id: int | None = get_as_snowflake(data, "parent_id") # This data may be missing depending on how this object is being created/updated if not data.pop("_invoke_flag", False): rtc = data.get("rtc_region") self.rtc_region: VoiceRegion | None = try_enum(VoiceRegion, rtc) if rtc is not None else None self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1)) - self.last_message_id: int | None = utils._get_as_snowflake(data, "last_message_id") + self.last_message_id: int | None = get_as_snowflake(data, "last_message_id") self.position: int = data.get("position") self.slowmode_delay = data.get("rate_limit_per_user", 0) self.bitrate: int = data.get("bitrate") @@ -1582,7 +1581,7 @@ def voice_states(self) -> dict[int, VoiceState]: if value.channel and value.channel.id == self.id } - @utils.copy_doc(discord.abc.GuildChannel.permissions_for) + @copy_doc(discord.abc.GuildChannel.permissions_for) def permissions_for(self, obj: Member | Role, /) -> Permissions: base = super().permissions_for(obj) @@ -1936,7 +1935,7 @@ async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason from .webhook import Webhook # noqa: PLC0415 if avatar is not None: - avatar = utils._bytes_to_base64_data(avatar) # type: ignore + avatar = bytes_to_base64_data(avatar) # type: ignore data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) @@ -1946,7 +1945,7 @@ def type(self) -> ChannelType: """The channel's Discord type.""" return ChannelType.voice - @utils.copy_doc(discord.abc.GuildChannel.clone) + @copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: str | None = None, reason: str | None = None) -> VoiceChannel: return await self._clone_impl( {"bitrate": self.bitrate, "user_limit": self.user_limit}, @@ -2463,7 +2462,7 @@ async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason from .webhook import Webhook # noqa: PLC0415 if avatar is not None: - avatar = utils._bytes_to_base64_data(avatar) # type: ignore + avatar = bytes_to_base64_data(avatar) # type: ignore data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) @@ -2482,7 +2481,7 @@ def type(self) -> ChannelType: """The channel's Discord type.""" return ChannelType.stage_voice - @utils.copy_doc(discord.abc.GuildChannel.clone) + @copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: str | None = None, reason: str | None = None) -> StageChannel: return await self._clone_impl({}, name=name, reason=reason) @@ -2492,7 +2491,7 @@ def instance(self) -> StageInstance | None: .. versionadded:: 2.0 """ - return utils.get(self.guild.stage_instances, channel_id=self.id) + return utils.find(lambda s: s.channel_id == self.id, self.guild.stage_instances) async def create_instance( self, @@ -2723,7 +2722,7 @@ def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: # This data will always exist self.guild: Guild = guild self.name: str = data["name"] - self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") + self.category_id: int | None = get_as_snowflake(data, "parent_id") # This data may be missing depending on how this object is being created/updated if not data.pop("_invoke_flag", False): @@ -2745,7 +2744,7 @@ def is_nsfw(self) -> bool: """Checks if the category is NSFW.""" return self.nsfw - @utils.copy_doc(discord.abc.GuildChannel.clone) + @copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: str | None = None, reason: str | None = None) -> CategoryChannel: return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason) @@ -2811,7 +2810,7 @@ async def edit(self, *, reason=None, **options): # the payload will always be the proper channel payload return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore - @utils.copy_doc(discord.abc.GuildChannel.move) + @copy_doc(discord.abc.GuildChannel.move) async def move(self, **kwargs): kwargs.pop("category", None) await super().move(**kwargs) @@ -3113,7 +3112,7 @@ def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannel self._update_group(data) def _update_group(self, data: GroupChannelPayload) -> None: - self.owner_id: int | None = utils._get_as_snowflake(data, "owner_id") + self.owner_id: int | None = get_as_snowflake(data, "owner_id") self._icon: str | None = data.get("icon") self.name: str | None = data.get("name") self.recipients: list[User] = [self._state.store_user(u) for u in data.get("recipients", [])] diff --git a/discord/client.py b/discord/client.py index 4180182ace..9c4f754a4e 100644 --- a/discord/client.py +++ b/discord/client.py @@ -38,6 +38,7 @@ from discord.banners import print_banner, start_logging from . import utils +from .utils.private import resolve_invite, resolve_template, bytes_to_base64_data, SequenceProxy from .activity import ActivityTypes, BaseActivity, create_activity from .appinfo import AppInfo, PartialAppInfo from .application_role_connection import ApplicationRoleConnectionMetadata @@ -384,7 +385,7 @@ def cached_messages(self) -> Sequence[Message]: .. versionadded:: 1.1 """ - return utils.SequenceProxy(self._connection._messages or []) + return SequenceProxy(self._connection._messages or []) @property def private_channels(self) -> list[PrivateChannel]: @@ -1542,7 +1543,7 @@ async def fetch_template(self, code: Template | str) -> Template: :exc:`HTTPException` Getting the template failed. """ - code = utils.resolve_template(code) + code = resolve_template(code) data = await self.http.get_template(code) return Template(data=data, state=self._connection) # type: ignore @@ -1626,7 +1627,7 @@ async def create_guild( Invalid icon image format given. Must be PNG or JPG. """ if icon is not MISSING: - icon_base64 = utils._bytes_to_base64_data(icon) + icon_base64 = bytes_to_base64_data(icon) else: icon_base64 = None @@ -1718,7 +1719,7 @@ async def fetch_invite( Getting the invite failed. """ - invite_id = utils.resolve_invite(url) + invite_id = resolve_invite(url) data = await self.http.get_invite( invite_id, with_counts=with_counts, @@ -1750,7 +1751,7 @@ async def delete_invite(self, invite: Invite | str) -> None: Revoking the invite failed. """ - invite_id = utils.resolve_invite(invite) + invite_id = resolve_invite(invite) await self.http.delete_invite(invite_id) # Miscellaneous stuff @@ -2226,7 +2227,7 @@ async def create_emoji( The created emoji. """ - img = utils._bytes_to_base64_data(image) + img = bytes_to_base64_data(image) data = await self._connection.http.create_application_emoji(self.application_id, name, img) return self._connection.maybe_store_app_emoji(self.application_id, data) diff --git a/discord/commands/context.py b/discord/commands/context.py index a046afdde6..d4ddc8552a 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -53,7 +53,7 @@ from typing import Callable, Awaitable -from ..utils import cached_property +from ..utils.private import copy_doc T = TypeVar("T") CogT = TypeVar("CogT", bound="Cog") @@ -136,53 +136,53 @@ async def invoke( """ return await command(self, *args, **kwargs) - @cached_property + @property def channel(self) -> InteractionChannel | None: """Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]: Returns the channel associated with this context's command. Shorthand for :attr:`.Interaction.channel`. """ return self.interaction.channel - @cached_property + @property def channel_id(self) -> int | None: """Returns the ID of the channel associated with this context's command. Shorthand for :attr:`.Interaction.channel_id`. """ return self.interaction.channel_id - @cached_property + @property def guild(self) -> Guild | None: """Returns the guild associated with this context's command. Shorthand for :attr:`.Interaction.guild`. """ return self.interaction.guild - @cached_property + @property def guild_id(self) -> int | None: """Returns the ID of the guild associated with this context's command. Shorthand for :attr:`.Interaction.guild_id`. """ return self.interaction.guild_id - @cached_property + @property def locale(self) -> str | None: """Returns the locale of the guild associated with this context's command. Shorthand for :attr:`.Interaction.locale`. """ return self.interaction.locale - @cached_property + @property def guild_locale(self) -> str | None: """Returns the locale of the guild associated with this context's command. Shorthand for :attr:`.Interaction.guild_locale`. """ return self.interaction.guild_locale - @cached_property + @property def app_permissions(self) -> Permissions: return self.interaction.app_permissions - @cached_property + @property def me(self) -> Member | ClientUser | None: """Union[:class:`.Member`, :class:`.ClientUser`]: Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message @@ -190,14 +190,14 @@ def me(self) -> Member | ClientUser | None: """ return self.interaction.guild.me if self.interaction.guild is not None else self.bot.user - @cached_property + @property def message(self) -> Message | None: """Returns the message sent with this context's command. Shorthand for :attr:`.Interaction.message`, if applicable. """ return self.interaction.message - @cached_property + @property def user(self) -> Member | User: """Returns the user that sent this context's command. Shorthand for :attr:`.Interaction.user`. @@ -216,7 +216,7 @@ def voice_client(self) -> VoiceClient | None: return self.interaction.guild.voice_client - @cached_property + @property def response(self) -> InteractionResponse: """Returns the response object associated with this context's command. Shorthand for :attr:`.Interaction.response`. @@ -258,17 +258,17 @@ def unselected_options(self) -> list[Option] | None: return None @property - @discord.utils.copy_doc(InteractionResponse.send_modal) + @copy_doc(InteractionResponse.send_modal) def send_modal(self) -> Callable[..., Awaitable[Interaction]]: return self.interaction.response.send_modal @property - @discord.utils.copy_doc(Interaction.respond) + @copy_doc(Interaction.respond) def respond(self, *args, **kwargs) -> Callable[..., Awaitable[Interaction | WebhookMessage]]: return self.interaction.respond @property - @discord.utils.copy_doc(InteractionResponse.send_message) + @copy_doc(InteractionResponse.send_message) def send_response(self) -> Callable[..., Awaitable[Interaction]]: if not self.interaction.response.is_done(): return self.interaction.response.send_message @@ -278,7 +278,7 @@ def send_response(self) -> Callable[..., Awaitable[Interaction]]: ) @property - @discord.utils.copy_doc(Webhook.send) + @copy_doc(Webhook.send) def send_followup(self) -> Callable[..., Awaitable[WebhookMessage]]: if self.interaction.response.is_done(): return self.followup.send @@ -288,7 +288,7 @@ def send_followup(self) -> Callable[..., Awaitable[WebhookMessage]]: ) @property - @discord.utils.copy_doc(InteractionResponse.defer) + @copy_doc(InteractionResponse.defer) def defer(self) -> Callable[..., Awaitable[None]]: return self.interaction.response.defer @@ -322,7 +322,7 @@ async def delete(self, *, delay: float | None = None) -> None: return await self.interaction.delete_original_response(delay=delay) @property - @discord.utils.copy_doc(Interaction.edit_original_response) + @copy_doc(Interaction.edit_original_response) def edit(self) -> Callable[..., Awaitable[InteractionMessage]]: return self.interaction.edit_original_response diff --git a/discord/commands/core.py b/discord/commands/core.py index 1920610826..af405bcecd 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -68,7 +68,8 @@ from ..role import Role from ..threads import Thread from ..user import User -from ..utils import MISSING, async_all, find, maybe_coroutine, utcnow, warn_deprecated +from ..utils import MISSING, find, utcnow +from ..utils.private import warn_deprecated, async_all, maybe_awaitable from .context import ApplicationContext, AutocompleteContext from .options import Option, OptionChoice @@ -432,7 +433,7 @@ async def can_run(self, ctx: ApplicationContext) -> bool: if cog is not None: local_check = cog._get_overridden_method(cog.cog_check) if local_check is not None: - ret = await maybe_coroutine(local_check, ctx) + ret = await maybe_awaitable(local_check, ctx) if not ret: return False diff --git a/discord/components.py b/discord/components.py index 0d9a927a66..b59517c828 100644 --- a/discord/components.py +++ b/discord/components.py @@ -29,7 +29,8 @@ from .enums import ButtonStyle, ChannelType, ComponentType, InputTextStyle, try_enum from .partial_emoji import PartialEmoji, _EmojiTag -from .utils import MISSING, Undefined, get_slots +from .utils import MISSING, Undefined +from .utils.private import get_slots if TYPE_CHECKING: from .emoji import AppEmoji, GuildEmoji diff --git a/discord/embeds.py b/discord/embeds.py index cab91a5176..b3194447d5 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -28,6 +28,7 @@ import datetime from typing import TYPE_CHECKING, Any, Mapping, TypeVar +from .utils.private import parse_time from . import utils from .colour import Colour @@ -437,7 +438,7 @@ def from_dict(cls: type[E], data: Mapping[str, Any]) -> E: pass try: - self._timestamp = utils.parse_time(data["timestamp"]) + self._timestamp = parse_time(data["timestamp"]) except KeyError: pass diff --git a/discord/emoji.py b/discord/emoji.py index 5dd1eafd3d..40df9ca252 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -30,7 +30,8 @@ from .asset import Asset, AssetMixin from .partial_emoji import PartialEmoji, _EmojiTag from .user import User -from .utils import MISSING, SnowflakeList, Undefined, snowflake_time +from .utils import MISSING, Undefined, snowflake_time +from .utils.private import SnowflakeList __all__ = ( "Emoji", diff --git a/discord/ext/bridge/core.py b/discord/ext/bridge/core.py index 1c3c5d3d1b..6f319f07a7 100644 --- a/discord/ext/bridge/core.py +++ b/discord/ext/bridge/core.py @@ -40,7 +40,8 @@ SlashCommandOptionType, ) -from ...utils import MISSING, find, get, warn_deprecated +from discord.utils import MISSING, find +from discord.utils.private import warn_deprecated from ..commands import ( BadArgument, ) @@ -608,7 +609,7 @@ async def convert(self, ctx, argument: str) -> Any: if self.choices: choices_names: list[str | int | float] = [choice.name for choice in self.choices] - if converted in choices_names and (choice := get(self.choices, name=converted)): + if converted in choices_names and (choice := find(lambda c: c.name == converted, self.choices)): converted = choice.value else: choices = [choice.value for choice in self.choices] diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 2e038fbc92..09d7f7a22d 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -33,6 +33,7 @@ import discord from discord.utils import Undefined +from discord.utils.private import copy_doc, maybe_awaitable, async_all from . import errors from .context import Context @@ -132,7 +133,7 @@ def __init__( self.help_command = DefaultHelpCommand() if help_command is MISSING else help_command self.strip_after_prefix = options.get("strip_after_prefix", False) - @discord.utils.copy_doc(discord.Client.close) + @copy_doc(discord.Client.close) async def close(self) -> None: for extension in tuple(self.__extensions): try: @@ -179,7 +180,7 @@ async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: return True # type-checker doesn't distinguish between functions and methods - return await discord.utils.async_all(f(ctx) for f in data) # type: ignore + return await async_all(f(ctx) for f in data) # type: ignore # help command stuff @@ -223,7 +224,7 @@ async def get_prefix(self, message: Message) -> list[str] | str: """ prefix = ret = self.command_prefix if callable(prefix): - ret = await discord.utils.maybe_coroutine(prefix, self, message) + ret = await maybe_awaitable(prefix, self, message) if not isinstance(ret, str): try: diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 79e2cfba6e..55e376c9c8 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -30,8 +30,9 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union import discord.abc -import discord.utils from discord.message import Message +from discord.utils.private import copy_doc +from discord.utils import Undefined, MISSING if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -50,8 +51,6 @@ __all__ = ("Context",) -MISSING: Any = discord.utils.MISSING - T = TypeVar("T") BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") @@ -125,12 +124,12 @@ def __init__( message: Message, bot: BotT, view: StringView, - args: list[Any] | discord.utils.Undefined = MISSING, - kwargs: dict[str, Any] | discord.utils.Undefined = MISSING, + args: list[Any] | Undefined = MISSING, + kwargs: dict[str, Any] | Undefined = MISSING, prefix: str | None = None, command: Command | None = None, invoked_with: str | None = None, - invoked_parents: list[str] | discord.utils.Undefined = MISSING, + invoked_parents: list[str] | Undefined = MISSING, invoked_subcommand: Command | None = None, subcommand_passed: str | None = None, command_failed: bool = False, @@ -281,28 +280,28 @@ def cog(self) -> Cog | None: return None return self.command.cog - @discord.utils.cached_property + @property def guild(self) -> Guild | None: """Returns the guild associated with this context's command. None if not available. """ return self.message.guild - @discord.utils.cached_property + @property def channel(self) -> MessageableChannel: """Returns the channel associated with this context's command. Shorthand for :attr:`.Message.channel`. """ return self.message.channel - @discord.utils.cached_property + @property def author(self) -> User | Member: """Union[:class:`~discord.User`, :class:`.Member`]: Returns the author associated with this context's command. Shorthand for :attr:`.Message.author` """ return self.message.author - @discord.utils.cached_property + @property def me(self) -> Member | ClientUser: """Union[:class:`.Member`, :class:`.ClientUser`]: Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message @@ -398,10 +397,10 @@ async def send_help(self, *args: Any) -> Any: except CommandError as e: await cmd.on_help_command_error(self, e) - @discord.utils.copy_doc(Message.reply) + @copy_doc(Message.reply) async def reply(self, content: str | None = None, **kwargs: Any) -> Message: return await self.message.reply(content, **kwargs) - @discord.utils.copy_doc(Message.forward_to) + @copy_doc(Message.forward_to) async def forward_to(self, channel: discord.abc.Messageable, **kwargs: Any) -> Message: return await self.message.forward_to(channel, **kwargs) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index e9352e4fce..d34815001d 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -89,7 +89,6 @@ def _get_from_guilds(bot, getter, argument): return result -_utils_get = discord.utils.get T = TypeVar("T") T_co = TypeVar("T_co", covariant=True) CT = TypeVar("CT", bound=discord.abc.GuildChannel) @@ -194,7 +193,7 @@ async def query_member_named(self, guild, argument): if len(argument) > 5 and argument[-5] == "#": username, _, discriminator = argument.rpartition("#") members = await guild.query_members(username, limit=100, cache=cache) - return discord.utils.get(members, name=username, discriminator=discriminator) + return discord.utils.find(lambda m: m.name == username and m.discriminator == discriminator, members) members = await guild.query_members(argument, limit=100, cache=cache) return discord.utils.find( lambda m: argument in (m.nick, m.name, m.global_name), @@ -239,7 +238,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.Member: if guild: result = guild.get_member(user_id) if ctx.message is not None and result is None: - result = _utils_get(ctx.message.mentions, id=user_id) + result = discord.utils.find(lambda e: e.id == user_id, ctx.message.mentions) else: result = _get_from_guilds(bot, "get_member", user_id) @@ -287,7 +286,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.User: user_id = int(match.group(1)) result = ctx.bot.get_user(user_id) if ctx.message is not None and result is None: - result = _utils_get(ctx.message.mentions, id=user_id) + result = discord.utils.find(lambda e: e.id == user_id, ctx.message.mentions) if result is None: try: result = await ctx.bot.fetch_user(user_id) @@ -441,7 +440,7 @@ def _resolve_channel(ctx: Context, argument: str, attribute: str, type: type[CT] # not a mention if guild: iterable: Iterable[CT] = getattr(guild, attribute) - result: CT | None = discord.utils.get(iterable, name=argument) + result: CT | None = discord.utils.find(lambda e: e.name == argument, iterable) else: def check(c): @@ -470,7 +469,7 @@ def _resolve_thread(ctx: Context, argument: str, attribute: str, type: type[TT]) # not a mention if guild: iterable: Iterable[TT] = getattr(guild, attribute) - result: TT | None = discord.utils.get(iterable, name=argument) + result: TT | None = discord.utils.find(lambda e: e.name == argument, iterable) else: thread_id = int(match.group(1)) if guild: @@ -709,7 +708,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.Role: if match: result = guild.get_role(int(match.group(1))) else: - result = discord.utils.get(guild._roles.values(), name=argument) + result = discord.utils.find(lambda e: e.name == argument, guild._roles.values()) if result is None: raise RoleNotFound(argument) @@ -760,7 +759,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.Guild: result = ctx.bot.get_guild(guild_id) if result is None: - result = discord.utils.get(ctx.bot.guilds, name=argument) + result = discord.utils.find(lambda e: e.name == argument, ctx.bot.guilds) if result is None: raise GuildNotFound(argument) @@ -792,10 +791,10 @@ async def convert(self, ctx: Context, argument: str) -> discord.GuildEmoji: if match is None: # Try to get the emoji by name. Try local guild first. if guild: - result = discord.utils.get(guild.emojis, name=argument) + result = discord.utils.find(lambda e: e.name == argument, guild.emojis) if result is None: - result = discord.utils.get(bot.emojis, name=argument) + result = discord.utils.find(lambda e: e.name == argument, bot.emojis) else: emoji_id = int(match.group(1)) @@ -858,10 +857,10 @@ async def convert(self, ctx: Context, argument: str) -> discord.GuildSticker: if match is None: # Try to get the sticker by name. Try local guild first. if guild: - result = discord.utils.get(guild.stickers, name=argument) + result = discord.utils.find(lambda s: s.name == argument, guild.stickers) if result is None: - result = discord.utils.get(bot.stickers, name=argument) + result = discord.utils.find(lambda s: s.name == argument, bot.stickers) else: sticker_id = int(match.group(1)) @@ -913,17 +912,23 @@ async def convert(self, ctx: Context, argument: str) -> str: if ctx.guild: def resolve_member(id: int) -> str: - m = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.guild.get_member(id) + m = ( + None if msg is None else discord.utils.find(lambda e: e.id == id, msg.mentions) + ) or ctx.guild.get_member(id) return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: - r = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.guild.get_role(id) + r = ( + None if msg is None else discord.utils.find(lambda e: e.id == id, msg.mentions) + ) or ctx.guild.get_role(id) return f"@{r.name}" if r else "@deleted-role" else: def resolve_member(id: int) -> str: - m = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.bot.get_user(id) + m = ( + None if msg is None else discord.utils.find(lambda e: e.id == id, msg.mentions) + ) or ctx.bot.get_user(id) return f"@{m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index ac2c7a1073..fd90615f5b 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -43,6 +43,7 @@ ) import discord +from discord.utils.private import evaluate_annotation, async_all, maybe_awaitable from discord import utils from discord.utils import Undefined @@ -138,7 +139,6 @@ def get_signature_parameters(function: Callable[..., Any], globalns: dict[str, A signature = inspect.signature(function) params = {} cache: dict[str, Any] = {} - eval_annotation = discord.utils.evaluate_annotation for name, parameter in signature.parameters.items(): annotation = parameter.annotation if annotation is parameter.empty: @@ -148,7 +148,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: dict[str, A params[name] = parameter.replace(annotation=type(None)) continue - annotation = eval_annotation(annotation, globalns, globalns, cache) + annotation = evaluate_annotation(annotation, globalns, globalns, cache) if annotation is Greedy: raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") @@ -1146,7 +1146,7 @@ async def can_run(self, ctx: Context) -> bool: if cog is not None: local_check = Cog._get_overridden_method(cog.cog_check) if local_check is not None: - ret = await discord.utils.maybe_coroutine(local_check, ctx) + ret = await maybe_awaitable(local_check, ctx) if not ret: return False @@ -1155,7 +1155,7 @@ async def can_run(self, ctx: Context) -> bool: # since we have no checks, then we just return True. return True - return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore + return await async_all(predicate(ctx) for predicate in predicates) # type: ignore finally: ctx.command = original @@ -1865,9 +1865,9 @@ def predicate(ctx: Context) -> bool: # ctx.guild is None doesn't narrow ctx.author to Member if isinstance(item, int): - role = discord.utils.get(ctx.author.roles, id=item) # type: ignore + role = discord.utils.find(lambda r: r.id == item, ctx.author.roles) # type: ignore else: - role = discord.utils.get(ctx.author.roles, name=item) # type: ignore + role = discord.utils.find(lambda r: r.name == item, ctx.author.roles) # type: ignore if role is None: raise MissingRole(item) return True @@ -1912,9 +1912,14 @@ def predicate(ctx): raise NoPrivateMessage() # ctx.guild is None doesn't narrow ctx.author to Member - getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore + getter = functools.partial(discord.utils.find, seq=ctx.author.roles) # type: ignore if any( - (getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None) for item in items + ( + getter(lambda e: e.id == item) is not None + if isinstance(item, int) + else getter(lambda e: e.name == item) is not None + ) + for item in items ): return True raise MissingAnyRole(list(items)) @@ -1942,9 +1947,9 @@ def predicate(ctx): me = ctx.me if isinstance(item, int): - role = discord.utils.get(me.roles, id=item) + role = discord.utils.find(lambda r: r.id == item, me.roles) else: - role = discord.utils.get(me.roles, name=item) + role = discord.utils.find(lambda r: r.name == item, me.roles) if role is None: raise BotMissingRole(item) return True @@ -1971,9 +1976,14 @@ def predicate(ctx): raise NoPrivateMessage() me = ctx.me - getter = functools.partial(discord.utils.get, me.roles) + getter = functools.partial(discord.utils.find, seq=me.roles) if any( - (getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None) for item in items + ( + getter(lambda e: e.id == item) is not None + if isinstance(item, int) + else getter(lambda e: e.name == item) is not None + ) + for item in items ): return True raise BotMissingAnyRole(list(items)) diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index d1d8b23fce..da95adf7cb 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -32,7 +32,8 @@ from typing import TYPE_CHECKING, Any, Iterator, Literal, Pattern, TypeVar, Union from discord import utils -from discord.utils import MISSING, Undefined, maybe_coroutine, resolve_annotation +from discord.utils import MISSING, Undefined +from discord.utils.private import resolve_annotation, maybe_awaitable from .converter import run_converters from .errors import ( @@ -489,7 +490,7 @@ async def _construct_default(cls: type[F], ctx: Context) -> F: flags = cls.__commands_flags__ for flag in flags.values(): if callable(flag.default): - default = await maybe_coroutine(flag.default, ctx) + default = await maybe_awaitable(flag.default, ctx) setattr(self, flag.attribute, default) else: setattr(self, flag.attribute, flag.default) @@ -588,7 +589,7 @@ async def convert(cls: type[F], ctx: Context, argument: str) -> F: raise MissingRequiredFlag(flag) else: if callable(flag.default): - default = await maybe_coroutine(flag.default, ctx) + default = await maybe_awaitable(flag.default, ctx) setattr(self, flag.attribute, default) else: setattr(self, flag.attribute, flag.default) diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index be6221fe80..94e5c96755 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -31,7 +31,8 @@ import re from typing import TYPE_CHECKING, Any -import discord.utils +from discord import utils +from discord.utils.private import string_width, maybe_awaitable from .core import Command, Group from .errors import CommandError @@ -335,7 +336,7 @@ def __init__(self, **options): self.command_attrs = attrs = options.pop("command_attrs", {}) attrs.setdefault("name", "help") attrs.setdefault("help", "Shows this message") - self.context: Context = discord.utils.MISSING + self.context: Context = utils.MISSING self._command_impl = _HelpCommandImpl(self, **self.command_attrs) def copy(self): @@ -570,12 +571,14 @@ async def filter_commands(self, commands, *, sort=False, key=None, exclude: tupl key = lambda c: c.name # Ignore Application Commands because they don't have hidden/docs + from discord import ApplicationCommand # noqa: PLC0415 + new_commands = [ command for command in commands if not isinstance( command, - (discord.commands.ApplicationCommand, *(exclude if exclude else ())), + (ApplicationCommand, *(exclude if exclude else ())), ) ] iterator = new_commands if self.show_hidden else filter(lambda c: not c.hidden, new_commands) @@ -620,7 +623,7 @@ def get_max_size(self, commands): The maximum width of the commands. """ - as_lengths = (discord.utils._string_width(c.name) for c in commands) + as_lengths = (string_width(c.name) for c in commands) return max(as_lengths, default=0) def get_destination(self): @@ -858,8 +861,6 @@ async def command_callback(self, ctx, *, command=None): if cog is not None: return await self.send_cog_help(cog) - maybe_coro = discord.utils.maybe_coroutine - # If it's not a cog then it's a command. # Since we want to have detailed errors when someone # passes an invalid subcommand, we need to walk through @@ -867,18 +868,18 @@ async def command_callback(self, ctx, *, command=None): keys = command.split(" ") cmd = bot.all_commands.get(keys[0]) if cmd is None: - string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0])) + string = await maybe_awaitable(self.command_not_found, self.remove_mentions(keys[0])) return await self.send_error_message(string) for key in keys[1:]: try: found = cmd.all_commands.get(key) except AttributeError: - string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) + string = await maybe_awaitable(self.subcommand_not_found, cmd, self.remove_mentions(key)) return await self.send_error_message(string) else: if found is None: - string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) + string = await maybe_awaitable(self.subcommand_not_found, cmd, self.remove_mentions(key)) return await self.send_error_message(string) cmd = found @@ -984,10 +985,9 @@ def add_indented_commands(self, commands, *, heading, max_size=None): self.paginator.add_line(heading) max_size = max_size or self.get_max_size(commands) - get_width = discord.utils._string_width for command in commands: name = command.name - width = max_size - (get_width(name) - len(name)) + width = max_size - (string_width(name) - len(name)) entry = f"{self.indent * ' '}{name:<{width}} {command.short_doc}" self.paginator.add_line(self.shorten_text(entry)) diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 8fecd9a85f..bbff6f3750 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -49,6 +49,13 @@ ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]]) +def compute_timedelta(dt: datetime.datetime): + if dt.tzinfo is None: + dt = dt.astimezone() + now = datetime.datetime.now(datetime.timezone.utc) + return max((dt - now).total_seconds(), 0) + + class SleepHandle: __slots__ = ("future", "loop", "handle") diff --git a/discord/gateway.py b/discord/gateway.py index 05f1277f7d..572a88d9d5 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -42,6 +42,7 @@ from .activity import BaseActivity from .enums import SpeakingState from .errors import ConnectionClosed, InvalidArgument +from .utils.private import from_json, to_json _log = logging.getLogger(__name__) @@ -450,7 +451,7 @@ async def received_message(self, msg, /): self._buffer = bytearray() self.log_receive(msg) - msg = utils._from_json(msg) + msg = from_json(msg) _log.debug("For Shard ID %s: WebSocket Event: %s", self.shard_id, msg) event = msg.get("t") @@ -637,7 +638,7 @@ async def send(self, data, /): async def send_as_json(self, data): try: - await self.send(utils._to_json(data)) + await self.send(to_json(data)) except RuntimeError as exc: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc @@ -645,7 +646,7 @@ async def send_as_json(self, data): async def send_heartbeat(self, data): # This bypasses the rate limit handling code since it has a higher priority try: - await self.socket.send_str(utils._to_json(data)) + await self.socket.send_str(to_json(data)) except RuntimeError as exc: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc @@ -671,7 +672,7 @@ async def change_presence(self, *, activity=None, status=None, since=0.0): }, } - sent = utils._to_json(payload) + sent = to_json(payload) _log.debug('Sending "%s" to change status', sent) await self.send(sent) @@ -774,7 +775,7 @@ async def _hook(self, *args): async def send_as_json(self, data): _log.debug("Sending voice websocket frame: %s.", data) - await self.ws.send_str(utils._to_json(data)) + await self.ws.send_str(to_json(data)) send_heartbeat = send_as_json @@ -931,7 +932,7 @@ async def poll_event(self): # This exception is handled up the chain msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) if msg.type is aiohttp.WSMsgType.TEXT: - await self.received_message(utils._from_json(msg.data)) + await self.received_message(from_json(msg.data)) elif msg.type is aiohttp.WSMsgType.ERROR: _log.debug("Received %s", msg) raise ConnectionClosed(self.ws, shard_id=None) from msg.data diff --git a/discord/guild.py b/discord/guild.py index 3b8ffcad37..96db69454d 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -40,6 +40,7 @@ overload, ) +from .utils.private import get_as_snowflake, bytes_to_base64_data from . import abc, utils from .asset import Asset from .automod import AutoModAction, AutoModRule, AutoModTriggerMetadata @@ -472,7 +473,7 @@ def _from_data(self, guild: GuildPayload) -> None: ) self.features: list[GuildFeature] = guild.get("features", []) self._splash: str | None = guild.get("splash") - self._system_channel_id: int | None = utils._get_as_snowflake(guild, "system_channel_id") + self._system_channel_id: int | None = get_as_snowflake(guild, "system_channel_id") self.description: str | None = guild.get("description") self.max_presences: int | None = guild.get("max_presences") self.max_members: int | None = guild.get("max_members") @@ -483,8 +484,8 @@ def _from_data(self, guild: GuildPayload) -> None: self._system_channel_flags: int = guild.get("system_channel_flags", 0) self.preferred_locale: str | None = guild.get("preferred_locale") self._discovery_splash: str | None = guild.get("discovery_splash") - self._rules_channel_id: int | None = utils._get_as_snowflake(guild, "rules_channel_id") - self._public_updates_channel_id: int | None = utils._get_as_snowflake(guild, "public_updates_channel_id") + self._rules_channel_id: int | None = get_as_snowflake(guild, "rules_channel_id") + self._public_updates_channel_id: int | None = get_as_snowflake(guild, "public_updates_channel_id") self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get("nsfw_level", 0)) self.approximate_presence_count = guild.get("approximate_presence_count") self.approximate_member_count = guild.get("approximate_member_count") @@ -510,8 +511,8 @@ def _from_data(self, guild: GuildPayload) -> None: self._sync(guild) self._large: bool | None = None if self._member_count is None else self._member_count >= 250 - self.owner_id: int | None = utils._get_as_snowflake(guild, "owner_id") - self.afk_channel: VoiceChannel | None = self.get_channel(utils._get_as_snowflake(guild, "afk_channel_id")) # type: ignore + self.owner_id: int | None = get_as_snowflake(guild, "owner_id") + self.afk_channel: VoiceChannel | None = self.get_channel(get_as_snowflake(guild, "afk_channel_id")) # type: ignore for obj in guild.get("voice_states", []): self._update_voice_state(obj, int(obj["channel_id"])) @@ -1019,7 +1020,7 @@ def get_member_named(self, name: str, /) -> Member | None: # do the actual lookup and return if found # if it isn't found then we'll do a full name lookup below. - result = utils.get(members, name=name[:-5], discriminator=potential_discriminator) + result = utils.find(lambda m: m.name == name[:-5] and discriminator == potential_discriminator, members) if result is not None: return result @@ -1713,24 +1714,24 @@ async def edit( fields["afk_timeout"] = afk_timeout if icon is not MISSING: - fields["icon"] = icon if icon is None else utils._bytes_to_base64_data(icon) + fields["icon"] = icon if icon is None else bytes_to_base64_data(icon) if banner is not MISSING: if banner is None: fields["banner"] = banner else: - fields["banner"] = utils._bytes_to_base64_data(banner) + fields["banner"] = bytes_to_base64_data(banner) if splash is not MISSING: if splash is None: fields["splash"] = splash else: - fields["splash"] = utils._bytes_to_base64_data(splash) + fields["splash"] = bytes_to_base64_data(splash) if discovery_splash is not MISSING: if discovery_splash is None: fields["discovery_splash"] = discovery_splash else: - fields["discovery_splash"] = utils._bytes_to_base64_data(discovery_splash) + fields["discovery_splash"] = bytes_to_base64_data(discovery_splash) if default_notifications is not MISSING: if not isinstance(default_notifications, NotificationLevel): @@ -2648,7 +2649,7 @@ async def create_custom_emoji( The created emoji. """ - img = utils._bytes_to_base64_data(image) + img = bytes_to_base64_data(image) role_ids = [role.id for role in roles] if roles else [] data = await self._state.http.create_custom_emoji(self.id, name, img, roles=role_ids, reason=reason) return self._state.store_emoji(self, data) @@ -2875,7 +2876,7 @@ async def create_role( if icon is None: fields["icon"] = None else: - fields["icon"] = utils._bytes_to_base64_data(icon) + fields["icon"] = bytes_to_base64_data(icon) fields["unicode_emoji"] = None if unicode_emoji is not MISSING: @@ -3681,7 +3682,7 @@ async def create_scheduled_event( payload["scheduled_end_time"] = end_time.isoformat() if image is not MISSING: - payload["image"] = utils._bytes_to_base64_data(image) + payload["image"] = bytes_to_base64_data(image) data = await self._state.http.create_scheduled_event(guild_id=self.id, reason=reason, **payload) event = ScheduledEvent(state=self._state, guild=self, creator=self.me, data=data) diff --git a/discord/http.py b/discord/http.py index 4600c045c6..dfddad0f2d 100644 --- a/discord/http.py +++ b/discord/http.py @@ -34,6 +34,7 @@ import aiohttp +from .utils.private import get_mime_type_for_image, to_json, from_json from . import __version__, utils from .errors import ( DiscordServerError, @@ -46,7 +47,8 @@ ) from .file import VoiceMessage from .gateway import DiscordClientWebSocketResponse -from .utils import MISSING, warn_deprecated +from .utils import MISSING +from .utils.private import warn_deprecated _log = logging.getLogger(__name__) @@ -96,7 +98,7 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str text = await response.text(encoding="utf-8") try: if response.headers["content-type"] == "application/json": - return utils._from_json(text) + return from_json(text) except KeyError: # Thanks Cloudflare pass @@ -259,7 +261,7 @@ async def request( # some checking if it's a JSON request if "json" in kwargs: headers["Content-Type"] = "application/json" - kwargs["data"] = utils._to_json(kwargs.pop("json")) + kwargs["data"] = to_json(kwargs.pop("json")) try: reason = kwargs.pop("reason") @@ -567,7 +569,7 @@ def send_multipart_helper( } ) payload["attachments"] = attachments - form[0]["value"] = utils._to_json(payload) + form[0]["value"] = to_json(payload) return self.request(route, form=form, files=files) def send_files( @@ -640,7 +642,7 @@ def edit_multipart_helper( payload["attachments"] = attachments else: payload["attachments"].extend(attachments) - form[0]["value"] = utils._to_json(payload) + form[0]["value"] = to_json(payload) return self.request(route, form=form, files=files) @@ -1240,7 +1242,7 @@ def start_forum_thread( ) payload["attachments"] = attachments - form[0]["value"] = utils._to_json(payload) + form[0]["value"] = to_json(payload) return self.request(route, form=form, reason=reason) return self.request(route, json=payload, reason=reason) @@ -1649,7 +1651,7 @@ def create_guild_sticker( initial_bytes = file.fp.read(16) try: - mime_type = utils._get_mime_type_for_image(initial_bytes) + mime_type = get_mime_type_for_image(initial_bytes) except InvalidArgument: if initial_bytes.startswith(b"{"): mime_type = "application/json" @@ -2587,7 +2589,7 @@ def _edit_webhook_helper( form: list[dict[str, Any]] = [ { "name": "payload_json", - "value": utils._to_json(payload), + "value": to_json(payload), } ] diff --git a/discord/integrations.py b/discord/integrations.py index 8bb1bd80e7..709ad1b586 100644 --- a/discord/integrations.py +++ b/discord/integrations.py @@ -31,7 +31,8 @@ from .enums import ExpireBehaviour, try_enum from .errors import InvalidArgument from .user import User -from .utils import MISSING, _get_as_snowflake, parse_time +from .utils import MISSING +from .utils.private import get_as_snowflake, parse_time from discord import utils __all__ = ( @@ -207,7 +208,7 @@ def _from_data(self, data: StreamIntegrationPayload) -> None: self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data["expire_behavior"]) self.expire_grace_period: int = data["expire_grace_period"] self.synced_at: datetime.datetime = parse_time(data["synced_at"]) - self._role_id: int | None = _get_as_snowflake(data, "role_id") + self._role_id: int | None = get_as_snowflake(data, "role_id") self.syncing: bool = data["syncing"] self.enable_emoticons: bool = data["enable_emoticons"] self.subscriber_count: int = data["subscriber_count"] diff --git a/discord/interactions.py b/discord/interactions.py index 48ed74e8fe..4a661b88a8 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -29,6 +29,7 @@ import datetime from typing import TYPE_CHECKING, Any, Coroutine, Union +from .utils.private import get_as_snowflake, deprecated, delay_task, cached_slot_property from . import utils from .channel import ChannelType, PartialMessageable, _threaded_channel_factory from .enums import ( @@ -200,8 +201,8 @@ def _from_data(self, data: InteractionPayload): self.data: InteractionData | None = data.get("data") self.token: str = data["token"] self.version: int = data["version"] - self.channel_id: int | None = utils._get_as_snowflake(data, "channel_id") - self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") + self.channel_id: int | None = get_as_snowflake(data, "channel_id") + self.guild_id: int | None = get_as_snowflake(data, "guild_id") self.application_id: int = int(data["application_id"]) self.locale: str | None = data.get("locale") self.guild_locale: str | None = data.get("guild_locale") @@ -296,8 +297,8 @@ def is_component(self) -> bool: """Indicates whether the interaction is a message component.""" return self.type == InteractionType.component - @utils.cached_slot_property("_cs_channel") - @utils.deprecated("Interaction.channel", "2.7", stacklevel=4) + @cached_slot_property("_cs_channel") + @deprecated("Interaction.channel", "2.7", stacklevel=4) def cached_channel(self) -> InteractionChannel | None: """The cached channel from which the interaction was sent. DM channels are not resolved. These are :class:`PartialMessageable` instead. @@ -321,12 +322,12 @@ def permissions(self) -> Permissions: """ return Permissions(self._permissions) - @utils.cached_slot_property("_cs_app_permissions") + @cached_slot_property("_cs_app_permissions") def app_permissions(self) -> Permissions: """The resolved permissions of the application in the channel, including overwrites.""" return Permissions(self._app_permissions) - @utils.cached_slot_property("_cs_response") + @cached_slot_property("_cs_response") def response(self) -> InteractionResponse: """Returns an object responsible for handling responding to the interaction. @@ -335,7 +336,7 @@ def response(self) -> InteractionResponse: """ return InteractionResponse(self) - @utils.cached_slot_property("_cs_followup") + @cached_slot_property("_cs_followup") def followup(self) -> Webhook: """Returns the followup webhook for followup interactions.""" payload = { @@ -433,7 +434,7 @@ async def original_response(self) -> InteractionMessage: self._original_response = message return message - @utils.deprecated("Interaction.original_response", "2.2") + @deprecated("Interaction.original_response", "2.2") async def original_message(self): """An alias for :meth:`original_response`. @@ -560,7 +561,7 @@ async def edit_original_response( return message - @utils.deprecated("Interaction.edit_original_response", "2.2") + @deprecated("Interaction.edit_original_response", "2.2") async def edit_original_message(self, **kwargs): """An alias for :meth:`edit_original_response`. @@ -614,11 +615,11 @@ async def delete_original_response(self, *, delay: float | None = None) -> None: ) if delay is not None: - utils.delay_task(delay, func) + delay_task(delay, func) else: await func - @utils.deprecated("Interaction.delete_original_response", "2.2") + @deprecated("Interaction.delete_original_response", "2.2") async def delete_original_message(self, **kwargs): """An alias for :meth:`delete_original_response`. @@ -1243,7 +1244,7 @@ async def send_modal(self, modal: Modal) -> Interaction: self._parent._state.store_modal(modal, self._parent.user.id) return self._parent - @utils.deprecated("a button with type ButtonType.premium", "2.6") + @deprecated("a button with type ButtonType.premium", "2.6") async def premium_required(self) -> Interaction: """|coro| @@ -1532,8 +1533,8 @@ def __init__(self, *, data: InteractionMetadataPayload, state: ConnectionState): self.authorizing_integration_owners: AuthorizingIntegrationOwners = AuthorizingIntegrationOwners( data["authorizing_integration_owners"], state ) - self.original_response_message_id: int | None = utils._get_as_snowflake(data, "original_response_message_id") - self.interacted_message_id: int | None = utils._get_as_snowflake(data, "interacted_message_id") + self.original_response_message_id: int | None = get_as_snowflake(data, "original_response_message_id") + self.interacted_message_id: int | None = get_as_snowflake(data, "interacted_message_id") self.triggering_interaction_metadata: InteractionMetadata | None = None if tim := data.get("triggering_interaction_metadata"): self.triggering_interaction_metadata = InteractionMetadata(data=tim, state=state) @@ -1541,7 +1542,7 @@ def __init__(self, *, data: InteractionMetadataPayload, state: ConnectionState): def __repr__(self): return f"" - @utils.cached_slot_property("_cs_original_response_message") + @cached_slot_property("_cs_original_response_message") def original_response_message(self) -> Message | None: """Optional[:class:`Message`]: The original response message. Returns ``None`` if the message is not in cache, or if :attr:`original_response_message_id` is ``None``. @@ -1550,7 +1551,7 @@ def original_response_message(self) -> Message | None: return None return self._state._get_message(self.original_response_message_id) - @utils.cached_slot_property("_cs_interacted_message") + @cached_slot_property("_cs_interacted_message") def interacted_message(self) -> Message | None: """Optional[:class:`Message`]: The message that triggered the interaction. Returns ``None`` if the message is not in cache, or if :attr:`interacted_message_id` is ``None``. @@ -1596,7 +1597,7 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) - @utils.cached_slot_property("_cs_user") + @cached_slot_property("_cs_user") def user(self) -> User | None: """Optional[:class:`User`]: The user that authorized the integration. Returns ``None`` if the user is not in cache, or if :attr:`user_id` is ``None``. @@ -1605,7 +1606,7 @@ def user(self) -> User | None: return None return self._state.get_user(self.user_id) - @utils.cached_slot_property("_cs_guild") + @cached_slot_property("_cs_guild") def guild(self) -> Guild | None: """Optional[:class:`Guild`]: The guild that authorized the integration. Returns ``None`` if the guild is not in cache, or if :attr:`guild_id` is ``0`` or ``None``. diff --git a/discord/invite.py b/discord/invite.py index b301a5c3e0..b7ef549cfa 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -32,7 +32,8 @@ from .enums import ChannelType, InviteTarget, VerificationLevel, try_enum from .mixins import Hashable from .object import Object -from .utils import _get_as_snowflake, parse_time, snowflake_time +from .utils import snowflake_time +from .utils.private import get_as_snowflake, parse_time __all__ = ( "PartialInviteChannel", @@ -413,7 +414,7 @@ def from_incomplete(cls: type[I], *, state: ConnectionState, data: InvitePayload @classmethod def from_gateway(cls: type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I: - guild_id: int | None = _get_as_snowflake(data, "guild_id") + guild_id: int | None = get_as_snowflake(data, "guild_id") guild: Guild | Object | None = state._get_guild(guild_id) channel_id = int(data["channel_id"]) if guild is not None: diff --git a/discord/iterators.py b/discord/iterators.py index a5dc53286b..d04827c5ab 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -41,7 +41,8 @@ from .audit_logs import AuditLogEntry from .errors import NoMoreItems from .object import Object -from .utils import maybe_coroutine, snowflake_time, time_snowflake +from .utils import generate_snowflake, snowflake_time +from .utils.private import maybe_awaitable __all__ = ( "ReactionIterator", @@ -84,20 +85,6 @@ class _AsyncIterator(AsyncIterator[T]): async def next(self) -> T: raise NotImplementedError - def get(self, **attrs: Any) -> Awaitable[T | None]: - def predicate(elem: T): - for attr, val in attrs.items(): - nested = attr.split("__") - obj = elem - for attribute in nested: - obj = getattr(obj, attribute) - - if obj != val: - return False - return True - - return self.find(predicate) - async def find(self, predicate: _Func[T, bool]) -> T | None: while True: try: @@ -105,7 +92,7 @@ async def find(self, predicate: _Func[T, bool]) -> T | None: except NoMoreItems: return None - ret = await maybe_coroutine(predicate, elem) + ret = await maybe_awaitable(predicate, elem) if ret: return elem @@ -163,7 +150,7 @@ def __init__(self, iterator, func): async def next(self) -> T: # this raises NoMoreItems and will propagate appropriately item = await self.iterator.next() - return await maybe_coroutine(self.func, item) + return await maybe_awaitable(self.func, item) class _FilteredAsyncIterator(_AsyncIterator[T]): @@ -181,7 +168,7 @@ async def next(self) -> T: while True: # propagate NoMoreItems similar to _MappedAsyncIterator item = await getter() - ret = await maybe_coroutine(pred, item) + ret = await maybe_awaitable(pred, item) if ret: return item @@ -341,11 +328,11 @@ def __init__( oldest_first=None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) if isinstance(around, datetime.datetime): - around = Object(id=time_snowflake(around)) + around = Object(id=generate_snowflake(around)) self.reverse = after is not None if oldest_first is None else oldest_first self.messageable = messageable @@ -467,9 +454,9 @@ def __init__( action_type=None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.guild = guild self.loop = guild._state.loop @@ -578,9 +565,9 @@ class GuildIterator(_AsyncIterator["Guild"]): def __init__(self, bot, limit, before=None, after=None, with_counts=True): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.bot = bot self.limit = limit @@ -665,7 +652,7 @@ async def _retrieve_guilds_after_strategy(self, retrieve): class MemberIterator(_AsyncIterator["Member"]): def __init__(self, guild, limit=1000, after=None): if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.guild = guild self.limit = limit @@ -797,7 +784,7 @@ def __init__( self.before = None elif isinstance(before, datetime.datetime): if joined: - self.before = str(time_snowflake(before, high=False)) + self.before = str(generate_snowflake(before, high=False)) else: self.before = before.isoformat() else: @@ -873,9 +860,9 @@ def __init__( after: datetime.datetime | int | None = None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.event = event self.limit = limit @@ -967,9 +954,9 @@ def __init__( self.sku_ids = sku_ids if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.before = before self.after = after @@ -1081,9 +1068,9 @@ def __init__( user_id: int | None = None, ): if isinstance(before, datetime.datetime): - before = Object(id=time_snowflake(before, high=False)) + before = Object(id=generate_snowflake(before, high=False)) if isinstance(after, datetime.datetime): - after = Object(id=time_snowflake(after, high=True)) + after = Object(id=generate_snowflake(after, high=True)) self.state = state self.sku_id = sku_id diff --git a/discord/member.py b/discord/member.py index 0775f2670f..3d20c2a4fc 100644 --- a/discord/member.py +++ b/discord/member.py @@ -44,6 +44,7 @@ from .permissions import Permissions from .user import BaseUser, User, _UserTag from .utils import MISSING +from .utils.private import parse_time, SnowflakeList, copy_doc __all__ = ( "VoiceState", @@ -148,7 +149,7 @@ def _update( self.mute: bool = data.get("mute", False) self.deaf: bool = data.get("deaf", False) self.suppress: bool = data.get("suppress", False) - self.requested_to_speak_at: datetime.datetime | None = utils.parse_time(data.get("request_to_speak_timestamp")) + self.requested_to_speak_at: datetime.datetime | None = parse_time(data.get("request_to_speak_timestamp")) self.channel: VocalGuildChannel | None = channel def __repr__(self) -> str: @@ -200,7 +201,7 @@ def general(self, *args, **kwargs): return general func = generate_function(attr) - func = utils.copy_doc(value)(func) + func = copy_doc(value)(func) setattr(cls, attr, func) return cls @@ -309,16 +310,16 @@ def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: Connecti self._state: ConnectionState = state self._user: User = state.store_user(data["user"]) self.guild: Guild = guild - self.joined_at: datetime.datetime | None = utils.parse_time(data.get("joined_at")) - self.premium_since: datetime.datetime | None = utils.parse_time(data.get("premium_since")) - self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"])) + self.joined_at: datetime.datetime | None = parse_time(data.get("joined_at")) + self.premium_since: datetime.datetime | None = parse_time(data.get("premium_since")) + self._roles: SnowflakeList = SnowflakeList(map(int, data["roles"])) self._client_status: dict[str | None, str] = {None: "offline"} self.activities: tuple[ActivityTypes, ...] = () self.nick: str | None = data.get("nick", None) self.pending: bool = data.get("pending", False) self._avatar: str | None = data.get("avatar") self._banner: str | None = data.get("banner") - self.communication_disabled_until: datetime.datetime | None = utils.parse_time( + self.communication_disabled_until: datetime.datetime | None = parse_time( data.get("communication_disabled_until") ) self.flags: MemberFlags = MemberFlags._from_value(data.get("flags", 0)) @@ -359,9 +360,9 @@ def _from_message(cls: type[M], *, message: Message, data: MemberPayload) -> M: return cls(data=data, guild=message.guild, state=message._state) # type: ignore def _update_from_message(self, data: MemberPayload) -> None: - self.joined_at = utils.parse_time(data.get("joined_at")) - self.premium_since = utils.parse_time(data.get("premium_since")) - self._roles = utils.SnowflakeList(map(int, data["roles"])) + self.joined_at = parse_time(data.get("joined_at")) + self.premium_since = parse_time(data.get("premium_since")) + self._roles = SnowflakeList(map(int, data["roles"])) self.nick = data.get("nick", None) self.pending = data.get("pending", False) @@ -386,7 +387,7 @@ def _try_upgrade( def _copy(cls: type[M], member: M) -> M: self: M = cls.__new__(cls) # to bypass __init__ - self._roles = utils.SnowflakeList(member._roles, is_sorted=True) + self._roles = SnowflakeList(member._roles, is_sorted=True) self.joined_at = member.joined_at self.premium_since = member.premium_since self._client_status = member._client_status.copy() @@ -422,11 +423,11 @@ def _update(self, data: MemberPayload) -> None: except KeyError: pass - self.premium_since = utils.parse_time(data.get("premium_since")) - self._roles = utils.SnowflakeList(map(int, data["roles"])) + self.premium_since = parse_time(data.get("premium_since")) + self._roles = SnowflakeList(map(int, data["roles"])) self._avatar = data.get("avatar") self._banner = data.get("banner") - self.communication_disabled_until = utils.parse_time(data.get("communication_disabled_until")) + self.communication_disabled_until = parse_time(data.get("communication_disabled_until")) self.flags = MemberFlags._from_value(data.get("flags", 0)) def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> tuple[User, User] | None: @@ -1047,7 +1048,7 @@ async def add_roles(self, *roles: Snowflake, reason: str | None = None, atomic: """ if not atomic: - new_roles = utils._unique(Object(id=r.id) for s in (self.roles[1:], roles) for r in s) + new_roles = list({Object(id=r.id) for s in (self.roles[1:], roles) for r in s}) await self.edit(roles=new_roles, reason=reason) else: req = self._state.http.add_role diff --git a/discord/message.py b/discord/message.py index 30bcc9e4dc..cc7032c6ed 100644 --- a/discord/message.py +++ b/discord/message.py @@ -41,6 +41,7 @@ ) from urllib.parse import parse_qs, urlparse +from .utils.private import get_as_snowflake, parse_time, warn_deprecated, delay_task, cached_slot_property from . import utils from .channel import PartialMessageable from .components import _component_factory @@ -537,9 +538,9 @@ def __init__( def with_state(cls: type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR: self = cls.__new__(cls) self.type = try_enum(MessageReferenceType, data.get("type")) or MessageReferenceType.default - self.message_id = utils._get_as_snowflake(data, "message_id") - self.channel_id = utils._get_as_snowflake(data, "channel_id") - self.guild_id = utils._get_as_snowflake(data, "guild_id") + self.message_id = get_as_snowflake(data, "message_id") + self.channel_id = get_as_snowflake(data, "channel_id") + self.guild_id = get_as_snowflake(data, "guild_id") self.fail_if_not_exists = data.get("fail_if_not_exists", True) self._state = state self.resolved = None @@ -630,7 +631,7 @@ class MessageCall: def __init__(self, state: ConnectionState, data: MessageCallPayload): self._state: ConnectionState = state self._participants: SnowflakeList = data.get("participants", []) - self._ended_timestamp: datetime.datetime | None = utils.parse_time(data["ended_timestamp"]) + self._ended_timestamp: datetime.datetime | None = parse_time(data["ended_timestamp"]) @property def participants(self) -> list[User | Object]: @@ -706,7 +707,7 @@ def __init__( self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0)) self.stickers: list[StickerItem] = [StickerItem(data=d, state=state) for d in data.get("sticker_items", [])] self.components: list[Component] = [_component_factory(d) for d in data.get("components", [])] - self._edited_timestamp: datetime.datetime | None = utils.parse_time(data["edited_timestamp"]) + self._edited_timestamp: datetime.datetime | None = parse_time(data["edited_timestamp"]) @property def created_at(self) -> datetime.datetime: @@ -968,14 +969,14 @@ def __init__( self._state: ConnectionState = state self._raw_data: MessagePayload = data self.id: int = int(data["id"]) - self.webhook_id: int | None = utils._get_as_snowflake(data, "webhook_id") + self.webhook_id: int | None = get_as_snowflake(data, "webhook_id") self.reactions: list[Reaction] = [Reaction(message=self, data=d) for d in data.get("reactions", [])] self.attachments: list[Attachment] = [Attachment(data=a, state=self._state) for a in data["attachments"]] self.embeds: list[Embed] = [Embed.from_dict(a) for a in data["embeds"]] self.application: MessageApplicationPayload | None = data.get("application") self.activity: MessageActivityPayload | None = data.get("activity") self.channel: MessageableChannel = channel - self._edited_timestamp: datetime.datetime | None = utils.parse_time(data["edited_timestamp"]) + self._edited_timestamp: datetime.datetime | None = parse_time(data["edited_timestamp"]) self.type: MessageType = try_enum(MessageType, data["type"]) self.pinned: bool = data["pinned"] self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0)) @@ -990,7 +991,7 @@ def __init__( # if the channel doesn't have a guild attribute, we handle that self.guild = channel.guild # type: ignore except AttributeError: - self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id")) + self.guild = state._get_guild(get_as_snowflake(data, "guild_id")) try: ref = data["message_reference"] @@ -1149,7 +1150,7 @@ def _update(self, data): pass def _handle_edited_timestamp(self, value: str) -> None: - self._edited_timestamp = utils.parse_time(value) + self._edited_timestamp = parse_time(value) def _handle_pinned(self, value: bool) -> None: self.pinned = value @@ -1244,7 +1245,7 @@ def _rebind_cached_references(self, new_guild: Guild, new_channel: TextChannel | @property def interaction(self) -> MessageInteraction | None: - utils.warn_deprecated( + warn_deprecated( "interaction", "interaction_metadata", "2.6", @@ -1254,7 +1255,7 @@ def interaction(self) -> MessageInteraction | None: @interaction.setter def interaction(self, value: MessageInteraction | None) -> None: - utils.warn_deprecated( + warn_deprecated( "interaction", "interaction_metadata", "2.6", @@ -1262,7 +1263,7 @@ def interaction(self, value: MessageInteraction | None) -> None: ) self._interaction = value - @utils.cached_slot_property("_cs_raw_mentions") + @cached_slot_property("_cs_raw_mentions") def raw_mentions(self) -> list[int]: """A property that returns an array of user IDs matched with the syntax of ``<@user_id>`` in the message content. @@ -1272,28 +1273,28 @@ def raw_mentions(self) -> list[int]: """ return [int(x) for x in re.findall(r"<@!?([0-9]{15,20})>", self.content)] - @utils.cached_slot_property("_cs_raw_channel_mentions") + @cached_slot_property("_cs_raw_channel_mentions") def raw_channel_mentions(self) -> list[int]: """A property that returns an array of channel IDs matched with the syntax of ``<#channel_id>`` in the message content. """ return [int(x) for x in re.findall(r"<#([0-9]{15,20})>", self.content)] - @utils.cached_slot_property("_cs_raw_role_mentions") + @cached_slot_property("_cs_raw_role_mentions") def raw_role_mentions(self) -> list[int]: """A property that returns an array of role IDs matched with the syntax of ``<@&role_id>`` in the message content. """ return [int(x) for x in re.findall(r"<@&([0-9]{15,20})>", self.content)] - @utils.cached_slot_property("_cs_channel_mentions") + @cached_slot_property("_cs_channel_mentions") def channel_mentions(self) -> list[GuildChannel]: if self.guild is None: return [] it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions)) - return utils._unique(it) + return list(dict.fromkeys(it)) # using dict.fromkeys and not set to preserve order - @utils.cached_slot_property("_cs_clean_content") + @cached_slot_property("_cs_clean_content") def clean_content(self) -> str: """A property that returns the content in a "cleaned up" manner. This basically means that mentions are transformed @@ -1370,7 +1371,7 @@ def is_system(self) -> bool: MessageType.thread_starter_message, ) - @utils.cached_slot_property("_cs_system_content") + @cached_slot_property("_cs_system_content") def system_content(self) -> str: r"""A property that returns the content that is rendered regardless of the :attr:`Message.type`. @@ -1538,7 +1539,7 @@ async def delete(self, *, delay: float | None = None, reason: str | None = None) """ del_func = self._state.http.delete_message(self.channel.id, self.id, reason=reason) if delay is not None: - utils.delay_task(delay, del_func) + delay_task(delay, del_func) else: await del_func @@ -2184,7 +2185,7 @@ def created_at(self) -> datetime.datetime: def poll(self) -> Poll | None: return self._state._polls.get(self.id) - @utils.cached_slot_property("_cs_guild") + @cached_slot_property("_cs_guild") def guild(self) -> Guild | None: """The guild that the partial message belongs to, if applicable.""" return getattr(self.channel, "guild", None) diff --git a/discord/monetization.py b/discord/monetization.py index 8ca4f23c40..18b55894d0 100644 --- a/discord/monetization.py +++ b/discord/monetization.py @@ -31,7 +31,8 @@ from .flags import SKUFlags from .iterators import SubscriptionIterator from .mixins import Hashable -from .utils import MISSING, _get_as_snowflake, parse_time +from .utils import MISSING +from .utils.private import get_as_snowflake, parse_time if TYPE_CHECKING: from datetime import datetime @@ -226,12 +227,12 @@ def __init__(self, *, data: EntitlementPayload, state: ConnectionState) -> None: self.id: int = int(data["id"]) self.sku_id: int = int(data["sku_id"]) self.application_id: int = int(data["application_id"]) - self.user_id: int | MISSING = _get_as_snowflake(data, "user_id") or MISSING + self.user_id: int | MISSING = get_as_snowflake(data, "user_id") or MISSING self.type: EntitlementType = try_enum(EntitlementType, data["type"]) self.deleted: bool = data["deleted"] self.starts_at: datetime | MISSING = parse_time(data.get("starts_at")) or MISSING self.ends_at: datetime | MISSING | None = parse_time(ea) if (ea := data.get("ends_at")) is not None else MISSING - self.guild_id: int | MISSING = _get_as_snowflake(data, "guild_id") or MISSING + self.guild_id: int | MISSING = get_as_snowflake(data, "guild_id") or MISSING self.consumed: bool = data.get("consumed", False) def __repr__(self) -> str: diff --git a/discord/onboarding.py b/discord/onboarding.py index 868a5c8ae3..06d865128d 100644 --- a/discord/onboarding.py +++ b/discord/onboarding.py @@ -25,11 +25,13 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any +from functools import cached_property + +from discord import utils from .enums import OnboardingMode, PromptType, try_enum from .partial_emoji import PartialEmoji -from .utils import MISSING, cached_property, generate_snowflake, get -from discord import utils +from .utils import MISSING, generate_snowflake, find if TYPE_CHECKING: from .abc import Snowflake @@ -81,7 +83,7 @@ def __init__( id: int | None = None, ): # ID is required when making edits, but it can be any snowflake that isn't already used by another prompt during edits - self.id: int = int(id) if id else generate_snowflake() + self.id: int = int(id) if id else generate_snowflake(mode="realistic") self.title: str = title self.channels: list[Snowflake] = channels or [] self.roles: list[Snowflake] = roles or [] @@ -127,7 +129,7 @@ def _from_dict(cls, data: PromptOptionPayload, guild: Guild) -> PromptOption: # Emoji object is {'id': None, 'name': None, 'animated': False} ... emoji = PartialEmoji.from_dict(_emoji) if emoji.id: - emoji = get(guild.emojis, id=emoji.id) or emoji + emoji = find(lambda e: e.id == emoji.id, guild.emojis) or emoji else: emoji = None @@ -169,7 +171,7 @@ def __init__( id: int | None = None, # Currently optional as users can manually create these ): # ID is required when making edits, but it can be any snowflake that isn't already used by another prompt during edits - self.id: int = int(id) if id else generate_snowflake() + self.id: int = int(id) if id else generate_snowflake(mode="realistic") self.type: PromptType = type if isinstance(self.type, int): @@ -431,7 +433,7 @@ def get_prompt( The matching prompt, or None if it didn't exist. """ - return get(self.prompts, id=id) + return find(lambda p: p.id == id, self.prompts) async def delete_prompt( self, diff --git a/discord/partial_emoji.py b/discord/partial_emoji.py index 9bf304bc75..aacb0d8778 100644 --- a/discord/partial_emoji.py +++ b/discord/partial_emoji.py @@ -29,6 +29,7 @@ from typing import TYPE_CHECKING, Any, TypedDict, TypeVar from . import utils +from .utils.private import get_as_snowflake from .asset import Asset, AssetMixin from .errors import InvalidArgument @@ -108,7 +109,7 @@ def __init__(self, *, name: str | None, animated: bool = False, id: int | None = def from_dict(cls: type[PE], data: PartialEmojiPayload | dict[str, Any]) -> PE: return cls( animated=data.get("animated", False), - id=utils._get_as_snowflake(data, "id"), + id=get_as_snowflake(data, "id"), name=data.get("name") or "", ) diff --git a/discord/poll.py b/discord/poll.py index 5bddce29dd..7d2d426bcb 100644 --- a/discord/poll.py +++ b/discord/poll.py @@ -26,7 +26,9 @@ import datetime from typing import TYPE_CHECKING, Any +from functools import cached_property +from .utils.private import parse_time from . import utils from .enums import PollLayoutType, try_enum from .iterators import VoteIterator @@ -143,7 +145,7 @@ def count(self) -> int | None: return None if self._poll.results is None: return None # Unknown vote count. - _count = self._poll.results and utils.get(self._poll.results.answer_counts, id=self.id) + _count = self._poll.results and utils.find(lambda p: p.id == id, self._poll.results.answer_counts) if _count: return _count.count return 0 # If an answer isn't in answer_counts, it has 0 votes. @@ -341,10 +343,10 @@ def __init__( self._expiry = None self._message = None - @utils.cached_property + @cached_property def expiry(self) -> datetime.datetime | None: """An aware datetime object that specifies the date and time in UTC when the poll will end.""" - return utils.parse_time(self._expiry) + return parse_time(self._expiry) def to_dict(self) -> PollPayload: dict_ = { @@ -425,7 +427,7 @@ def get_answer(self, id) -> PollAnswer | None: Optional[:class:`.PollAnswer`] The returned answer or ``None`` if not found. """ - return utils.get(self.answers, id=id) + return utils.find(lambda a: a.id == id, self.answers) def add_answer( self, diff --git a/discord/role.py b/discord/role.py index b39d5d3ae6..c9af02d7e8 100644 --- a/discord/role.py +++ b/discord/role.py @@ -34,7 +34,8 @@ from .flags import RoleFlags from .mixins import Hashable from .permissions import Permissions -from .utils import MISSING, _bytes_to_base64_data, _get_as_snowflake, snowflake_time +from .utils import MISSING, snowflake_time +from .utils.private import get_as_snowflake, bytes_to_base64_data __all__ = ( "RoleTags", @@ -89,9 +90,9 @@ class RoleTags: ) def __init__(self, data: RoleTagPayload): - self.bot_id: int | None = _get_as_snowflake(data, "bot_id") - self.integration_id: int | None = _get_as_snowflake(data, "integration_id") - self.subscription_listing_id: int | None = _get_as_snowflake(data, "subscription_listing_id") + self.bot_id: int | None = get_as_snowflake(data, "bot_id") + self.integration_id: int | None = get_as_snowflake(data, "integration_id") + self.subscription_listing_id: int | None = get_as_snowflake(data, "subscription_listing_id") # NOTE: The API returns "null" for each of the following tags if they are True, and omits them if False. # However, "null" corresponds to None. # This is different from other fields where "null" means "not there". @@ -526,7 +527,7 @@ async def edit( if icon is None: payload["icon"] = None else: - payload["icon"] = _bytes_to_base64_data(icon) + payload["icon"] = bytes_to_base64_data(icon) payload["unicode_emoji"] = None if unicode_emoji is not MISSING: diff --git a/discord/scheduled_events.py b/discord/scheduled_events.py index b6b3847682..2ce40b2ab2 100644 --- a/discord/scheduled_events.py +++ b/discord/scheduled_events.py @@ -39,7 +39,7 @@ from .iterators import ScheduledEventSubscribersIterator from .mixins import Hashable from .object import Object -from .utils import warn_deprecated +from .utils.private import warn_deprecated, get_as_snowflake, bytes_to_base64_data __all__ = ( "ScheduledEvent", @@ -206,7 +206,7 @@ def __init__( self.end_time: datetime.datetime | None = end_time self.status: ScheduledEventStatus = try_enum(ScheduledEventStatus, data.get("status")) self.subscriber_count: int | None = data.get("user_count", None) - self.creator_id: int | None = utils._get_as_snowflake(data, "creator_id") + self.creator_id: int | None = get_as_snowflake(data, "creator_id") self.creator: Member | None = creator entity_metadata = data.get("entity_metadata") @@ -361,7 +361,7 @@ async def edit( if image is None: payload["image"] = None else: - payload["image"] = utils._bytes_to_base64_data(image) + payload["image"] = bytes_to_base64_data(image) if location is not MISSING: if not isinstance(location, (ScheduledEventLocation, utils.Undefined)): diff --git a/discord/stage_instance.py b/discord/stage_instance.py index d5703f69c2..8fdc508f41 100644 --- a/discord/stage_instance.py +++ b/discord/stage_instance.py @@ -30,7 +30,8 @@ from .enums import StagePrivacyLevel, try_enum from .errors import InvalidArgument from .mixins import Hashable -from .utils import MISSING, Undefined, cached_slot_property +from .utils import MISSING, Undefined +from .utils.private import cached_slot_property __all__ = ("StageInstance",) diff --git a/discord/state.py b/discord/state.py index 7b9841c447..7b59f885e1 100644 --- a/discord/state.py +++ b/discord/state.py @@ -44,6 +44,7 @@ ) from . import utils +from .utils.private import get_as_snowflake, parse_time, sane_wait_for from .activity import BaseActivity from .audit_logs import AuditLogEntry from .automod import AutoModRule @@ -180,7 +181,7 @@ def __init__( self.hooks: dict[str, Callable] = hooks self.shard_count: int | None = None self._ready_task: asyncio.Task | None = None - self.application_id: int | None = utils._get_as_snowflake(options, "application_id") + self.application_id: int | None = get_as_snowflake(options, "application_id") self.heartbeat_timeout: float = options.get("heartbeat_timeout", 60.0) self.guild_ready_timeout: float = options.get("guild_ready_timeout", 2.0) if self.guild_ready_timeout < 0: @@ -636,7 +637,7 @@ def parse_ready(self, data) -> None: except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, "id") + self.application_id = get_as_snowflake(application, "id") # flags will always be present here self.application_flags = ApplicationFlags._from_value(application["flags"]) # type: ignore @@ -755,7 +756,7 @@ def parse_message_update(self, data) -> None: def parse_message_reaction_add(self, data) -> None: emoji = data["emoji"] - emoji_id = utils._get_as_snowflake(emoji, "id") + emoji_id = get_as_snowflake(emoji, "id") emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get("animated", False), name=emoji["name"]) raw = RawReactionActionEvent(data, emoji, "REACTION_ADD") @@ -792,7 +793,7 @@ def parse_message_reaction_remove_all(self, data) -> None: def parse_message_reaction_remove(self, data) -> None: emoji = data["emoji"] - emoji_id = utils._get_as_snowflake(emoji, "id") + emoji_id = get_as_snowflake(emoji, "id") emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji["name"]) raw = RawReactionActionEvent(data, emoji, "REACTION_REMOVE") @@ -822,7 +823,7 @@ def parse_message_reaction_remove(self, data) -> None: def parse_message_reaction_remove_emoji(self, data) -> None: emoji = data["emoji"] - emoji_id = utils._get_as_snowflake(emoji, "id") + emoji_id = get_as_snowflake(emoji, "id") emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji["name"]) raw = RawReactionClearEmojiEvent(data, emoji) self.dispatch("raw_reaction_clear_emoji", raw) @@ -901,7 +902,7 @@ def parse_interaction_create(self, data) -> None: self.dispatch("interaction", interaction) def parse_presence_update(self, data) -> None: - guild_id = utils._get_as_snowflake(data, "guild_id") + guild_id = get_as_snowflake(data, "guild_id") # guild_id won't be None here guild = self._get_guild(guild_id) if guild is None: @@ -945,7 +946,7 @@ def parse_invite_delete(self, data) -> None: self.dispatch("invite_delete", invite) def parse_channel_delete(self, data) -> None: - guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) + guild = self._get_guild(get_as_snowflake(data, "guild_id")) channel_id = int(data["id"]) if guild is not None: channel = guild.get_channel(channel_id) @@ -964,7 +965,7 @@ def parse_channel_update(self, data) -> None: self.dispatch("private_channel_update", old_channel, channel) return - guild_id = utils._get_as_snowflake(data, "guild_id") + guild_id = get_as_snowflake(data, "guild_id") guild = self._get_guild(guild_id) if guild is not None: channel = guild.get_channel(channel_id) @@ -992,7 +993,7 @@ def parse_channel_create(self, data) -> None: ) return - guild_id = utils._get_as_snowflake(data, "guild_id") + guild_id = get_as_snowflake(data, "guild_id") guild = self._get_guild(guild_id) if guild is not None: # the factory can't be a DMChannel or GroupChannel here @@ -1023,7 +1024,7 @@ def parse_channel_pins_update(self, data) -> None: ) return - last_pin = utils.parse_time(data["last_pin_timestamp"]) if data["last_pin_timestamp"] else None + last_pin = parse_time(data["last_pin_timestamp"]) if data["last_pin_timestamp"] else None if guild is None: self.dispatch("private_channel_pins_update", channel, last_pin) @@ -1737,8 +1738,8 @@ def parse_stage_instance_delete(self, data) -> None: ) def parse_voice_state_update(self, data) -> None: - guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) - channel_id = utils._get_as_snowflake(data, "channel_id") + guild = self._get_guild(get_as_snowflake(data, "guild_id")) + channel_id = get_as_snowflake(data, "channel_id") flags = self.member_cache_flags # self.user is *always* cached when this is called self_id = self.user.id # type: ignore @@ -1837,7 +1838,7 @@ def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> User return self.get_user(user_id) def get_reaction_emoji(self, data) -> GuildEmoji | AppEmoji | PartialEmoji: - emoji_id = utils._get_as_snowflake(data, "id") + emoji_id = get_as_snowflake(data, "id") if not emoji_id: return data["name"] @@ -1935,7 +1936,7 @@ async def _delay_ready(self) -> None: ) if len(current_bucket) >= max_concurrency: try: - await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) + await sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) except asyncio.TimeoutError: fmt = "Shard ID %s failed to wait for chunks from a sub-bucket with length %d" _log.warning(fmt, guild.shard_id, len(current_bucket)) @@ -1957,7 +1958,7 @@ async def _delay_ready(self) -> None: # 110 reqs/minute w/ 1 req/guild plus some buffer timeout = 61 * (len(children) / 110) try: - await utils.sane_wait_for(futures, timeout=timeout) + await sane_wait_for(futures, timeout=timeout) except asyncio.TimeoutError: _log.warning( ("Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds"), @@ -2005,7 +2006,7 @@ def parse_ready(self, data) -> None: except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, "id") + self.application_id = get_as_snowflake(application, "id") self.application_flags = ApplicationFlags._from_value(application["flags"]) for guild_data in data["guilds"]: diff --git a/discord/sticker.py b/discord/sticker.py index 66d8d95ef9..6fb3e5ba38 100644 --- a/discord/sticker.py +++ b/discord/sticker.py @@ -32,7 +32,8 @@ from .enums import StickerFormatType, StickerType, try_enum from .errors import InvalidData from .mixins import Hashable -from .utils import MISSING, Undefined, cached_slot_property, find, get, snowflake_time +from .utils import MISSING, Undefined, find, snowflake_time +from .utils.private import cached_slot_property __all__ = ( "StickerPack", @@ -119,7 +120,7 @@ def _from_data(self, data: StickerPackPayload) -> None: self.name: str = data["name"] self.sku_id: int = int(data["sku_id"]) self.cover_sticker_id: int = int(data["cover_sticker_id"]) - self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore + self.cover_sticker: StandardSticker = find(lambda s: s.id == self.cover_sticker_id, self.stickers) # type: ignore self.description: str = data["description"] self._banner: int = int(data["banner_asset_id"]) diff --git a/discord/team.py b/discord/team.py index d04e47e3ff..a7d890e15d 100644 --- a/discord/team.py +++ b/discord/team.py @@ -28,6 +28,7 @@ from typing import TYPE_CHECKING from . import utils +from .utils.private import get_as_snowflake from .asset import Asset from .enums import TeamMembershipState, try_enum from .user import BaseUser @@ -68,7 +69,7 @@ def __init__(self, state: ConnectionState, data: TeamPayload): self.id: int = int(data["id"]) self.name: str = data["name"] self._icon: str | None = data["icon"] - self.owner_id: int | None = utils._get_as_snowflake(data, "owner_user_id") + self.owner_id: int | None = get_as_snowflake(data, "owner_user_id") self.members: list[TeamMember] = [TeamMember(self, self._state, member) for member in data["members"]] def __repr__(self) -> str: @@ -84,7 +85,7 @@ def icon(self) -> Asset | None: @property def owner(self) -> TeamMember | None: """The team's owner.""" - return utils.get(self.members, id=self.owner_id) + return utils.find(lambda m: m.id == self.owner_id, self.members) class TeamMember(BaseUser): diff --git a/discord/template.py b/discord/template.py index 4eaa71e9b7..e5336fd5ed 100644 --- a/discord/template.py +++ b/discord/template.py @@ -28,7 +28,8 @@ from typing import TYPE_CHECKING, Any from .guild import Guild -from .utils import MISSING, Undefined, _bytes_to_base64_data, parse_time +from .utils import MISSING, Undefined +from .utils.private import bytes_to_base64_data, parse_time __all__ = ("Template",) @@ -195,7 +196,7 @@ async def create_guild(self, name: str, icon: Any = None) -> Guild: Invalid icon image format given. Must be PNG or JPG. """ if icon is not None: - icon = _bytes_to_base64_data(icon) + icon = bytes_to_base64_data(icon) data = await self._state.http.create_from_template(self.code, name, icon) return Guild(data=data, state=self._state) diff --git a/discord/threads.py b/discord/threads.py index a0270e54d9..7442b0ec01 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -32,7 +32,8 @@ from .errors import ClientException from .flags import ChannelFlags from .mixins import Hashable -from .utils import MISSING, _get_as_snowflake, parse_time +from .utils import MISSING +from .utils.private import get_as_snowflake, parse_time from discord import utils __all__ = ( @@ -189,7 +190,7 @@ def _from_data(self, data: ThreadPayload): # This data may be missing depending on how this object is being created self.owner_id = int(data.get("owner_id")) if data.get("owner_id", None) is not None else None - self.last_message_id = _get_as_snowflake(data, "last_message_id") + self.last_message_id = get_as_snowflake(data, "last_message_id") self.slowmode_delay = data.get("rate_limit_per_user", 0) self.message_count = data.get("message_count", None) self.member_count = data.get("member_count", None) @@ -283,7 +284,7 @@ def applied_tags(self) -> list[ForumTag]: This is only available for threads in forum or media channels. """ - from .channel import ForumChannel # to prevent circular import # noqa: PLC0415 + from .channel import ForumChannel # noqa: PLC0415 # to prevent circular import if isinstance(self.parent, ForumChannel): return [tag for tag_id in self._applied_tags if (tag := self.parent.get_tag(tag_id)) is not None] diff --git a/discord/ui/view.py b/discord/ui/view.py index 1b45077fb8..d532d2bb22 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -39,7 +39,7 @@ from ..components import Component from ..components import SelectMenu as SelectComponent from ..components import _component_factory -from ..utils import get +from .. import utils from .item import Item, ItemCallbackType __all__ = ("View",) @@ -313,7 +313,7 @@ def clear_items(self) -> None: self.__weights.clear() def get_item(self, custom_id: str) -> Item[V] | None: - """Get an item from the view with the given custom ID. Alias for `utils.get(view.children, custom_id=custom_id)`. + """Get an item from the view with the given custom ID. Alias for `utils.find(lambda i: i.custom_id == custom_id, self.children)`. Parameters ---------- @@ -325,7 +325,7 @@ def get_item(self, custom_id: str) -> Item[V] | None: Optional[:class:`Item`] The item with the matching ``custom_id`` if it exists. """ - return get(self.children, custom_id=custom_id) + return utils.find(lambda i: i.custom_id == custom_id, self.children) async def interaction_check(self, interaction: Interaction) -> bool: """|coro| diff --git a/discord/user.py b/discord/user.py index 10d2889048..e9aa9a465e 100644 --- a/discord/user.py +++ b/discord/user.py @@ -34,7 +34,8 @@ from .flags import PublicUserFlags from .iterators import EntitlementIterator from .monetization import Entitlement -from .utils import MISSING, Undefined, _bytes_to_base64_data, snowflake_time +from .utils import MISSING, Undefined, snowflake_time +from .utils.private import bytes_to_base64_data if TYPE_CHECKING: from datetime import datetime @@ -470,12 +471,12 @@ async def edit( if avatar is None: payload["avatar"] = None elif avatar is not MISSING: - payload["avatar"] = _bytes_to_base64_data(avatar) + payload["avatar"] = bytes_to_base64_data(avatar) if banner is None: payload["banner"] = None elif banner is not MISSING: - payload["banner"] = _bytes_to_base64_data(banner) + payload["banner"] = bytes_to_base64_data(banner) data: UserPayload = await self._state.http.edit_profile(payload) return ClientUser(state=self._state, data=data) diff --git a/discord/utils.py b/discord/utils.py deleted file mode 100644 index 18a53a9ed0..0000000000 --- a/discord/utils.py +++ /dev/null @@ -1,1403 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -from __future__ import annotations - -import array -import asyncio -import collections.abc -import datetime -from enum import Enum, auto -import functools -import itertools -import json -import re -import sys -import types -import unicodedata -import warnings -from base64 import b64encode -from bisect import bisect_left -from inspect import isawaitable as _isawaitable -from inspect import signature as _signature -from operator import attrgetter -from typing import ( - TYPE_CHECKING, - Any, - AsyncIterator, - Awaitable, - Callable, - Coroutine, - ForwardRef, - Generic, - Iterable, - Iterator, - Literal, - Mapping, - Protocol, - Sequence, - TypeVar, - Union, - overload, -) - -from .errors import HTTPException, InvalidArgument - -try: - import msgspec -except ModuleNotFoundError: - HAS_MSGSPEC = False -else: - HAS_MSGSPEC = True - - -__all__ = ( - "parse_time", - "warn_deprecated", - "deprecated", - "oauth_url", - "snowflake_time", - "time_snowflake", - "find", - "get", - "get_or_fetch", - "sleep_until", - "utcnow", - "resolve_invite", - "resolve_template", - "remove_markdown", - "escape_markdown", - "escape_mentions", - "raw_mentions", - "raw_channel_mentions", - "raw_role_mentions", - "as_chunks", - "format_dt", - "generate_snowflake", - "basic_autocomplete", - "filter_params", -) - -DISCORD_EPOCH = 1420070400000 - - -class Undefined(Enum): - MISSING = auto() - - def __bool__(self) -> Literal[False]: - return False - - -MISSING: Literal[Undefined.MISSING] = Undefined.MISSING - - -class _cached_property: - def __init__(self, function): - self.function = function - self.__doc__ = getattr(function, "__doc__") - - def __get__(self, instance, owner): - if instance is None: - return self - - value = self.function(instance) - setattr(instance, self.function.__name__, value) - - return value - - -if TYPE_CHECKING: - from typing_extensions import ParamSpec - - from .abc import Snowflake - from .commands.context import AutocompleteContext - from .commands.options import OptionChoice - from .invite import Invite - from .permissions import Permissions - from .template import Template - - class _RequestLike(Protocol): - headers: Mapping[str, Any] - - cached_property = property - - P = ParamSpec("P") - -else: - cached_property = _cached_property - AutocompleteContext = Any - OptionChoice = Any - - -T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) -_Iter = Union[Iterator[T], AsyncIterator[T]] - - -class CachedSlotProperty(Generic[T, T_co]): - def __init__(self, name: str, function: Callable[[T], T_co]) -> None: - self.name = name - self.function = function - self.__doc__ = getattr(function, "__doc__") - - @overload - def __get__(self, instance: None, owner: type[T]) -> CachedSlotProperty[T, T_co]: ... - - @overload - def __get__(self, instance: T, owner: type[T]) -> T_co: ... - - def __get__(self, instance: T | None, owner: type[T]) -> Any: - if instance is None: - return self - - try: - return getattr(instance, self.name) - except AttributeError: - value = self.function(instance) - setattr(instance, self.name, value) - return value - - -class classproperty(Generic[T_co]): - def __init__(self, fget: Callable[[Any], T_co]) -> None: - self.fget = fget - - def __get__(self, instance: Any | None, owner: type[Any]) -> T_co: - return self.fget(owner) - - def __set__(self, instance, value) -> None: - raise AttributeError("cannot set attribute") - - -def cached_slot_property( - name: str, -) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]: - def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]: - return CachedSlotProperty(name, func) - - return decorator - - -class SequenceProxy(Generic[T_co], collections.abc.Sequence): - """Read-only proxy of a Sequence.""" - - def __init__(self, proxied: Sequence[T_co]): - self.__proxied = proxied - - def __getitem__(self, idx: int) -> T_co: - return self.__proxied[idx] - - def __len__(self) -> int: - return len(self.__proxied) - - def __contains__(self, item: Any) -> bool: - return item in self.__proxied - - def __iter__(self) -> Iterator[T_co]: - return iter(self.__proxied) - - def __reversed__(self) -> Iterator[T_co]: - return reversed(self.__proxied) - - def index(self, value: Any, *args, **kwargs) -> int: - return self.__proxied.index(value, *args, **kwargs) - - def count(self, value: Any) -> int: - return self.__proxied.count(value) - - -def delay_task(delay: float, func: Coroutine): - async def inner_call(): - await asyncio.sleep(delay) - try: - await func - except HTTPException: - pass - - asyncio.create_task(inner_call()) - - -@overload -def parse_time(timestamp: None) -> None: ... - - -@overload -def parse_time(timestamp: str) -> datetime.datetime: ... - - -@overload -def parse_time(timestamp: str | None) -> datetime.datetime | None: ... - - -def parse_time(timestamp: str | None) -> datetime.datetime | None: - """A helper function to convert an ISO 8601 timestamp to a datetime object. - - Parameters - ---------- - timestamp: Optional[:class:`str`] - The timestamp to convert. - - Returns - ------- - Optional[:class:`datetime.datetime`] - The converted datetime object. - """ - if timestamp: - return datetime.datetime.fromisoformat(timestamp) - return None - - -def copy_doc(original: Callable) -> Callable[[T], T]: - def decorator(overridden: T) -> T: - overridden.__doc__ = original.__doc__ - overridden.__signature__ = _signature(original) # type: ignore - return overridden - - return decorator - - -def warn_deprecated( - name: str, - instead: str | None = None, - since: str | None = None, - removed: str | None = None, - reference: str | None = None, - stacklevel: int = 3, -) -> None: - """Warn about a deprecated function, with the ability to specify details about the deprecation. Emits a - DeprecationWarning. - - Parameters - ---------- - name: str - The name of the deprecated function. - instead: Optional[:class:`str`] - A recommended alternative to the function. - since: Optional[:class:`str`] - The version in which the function was deprecated. This should be in the format ``major.minor(.patch)``, where - the patch version is optional. - removed: Optional[:class:`str`] - The version in which the function is planned to be removed. This should be in the format - ``major.minor(.patch)``, where the patch version is optional. - reference: Optional[:class:`str`] - A reference that explains the deprecation, typically a URL to a page such as a changelog entry or a GitHub - issue/PR. - stacklevel: :class:`int` - The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3. - """ - warnings.simplefilter("always", DeprecationWarning) # turn off filter - message = f"{name} is deprecated" - if since: - message += f" since version {since}" - if removed: - message += f" and will be removed in version {removed}" - if instead: - message += f", consider using {instead} instead" - message += "." - if reference: - message += f" See {reference} for more information." - - warnings.warn(message, stacklevel=stacklevel, category=DeprecationWarning) - warnings.simplefilter("default", DeprecationWarning) # reset filter - - -def deprecated( - instead: str | None = None, - since: str | None = None, - removed: str | None = None, - reference: str | None = None, - stacklevel: int = 3, - *, - use_qualname: bool = True, -) -> Callable[[Callable[[P], T]], Callable[[P], T]]: - """A decorator implementation of :func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when - the decorated function is called. - - Parameters - ---------- - instead: Optional[:class:`str`] - A recommended alternative to the function. - since: Optional[:class:`str`] - The version in which the function was deprecated. This should be in the format ``major.minor(.patch)``, where - the patch version is optional. - removed: Optional[:class:`str`] - The version in which the function is planned to be removed. This should be in the format - ``major.minor(.patch)``, where the patch version is optional. - reference: Optional[:class:`str`] - A reference that explains the deprecation, typically a URL to a page such as a changelog entry or a GitHub - issue/PR. - stacklevel: :class:`int` - The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3. - use_qualname: :class:`bool` - Whether to use the qualified name of the function in the deprecation warning. If ``False``, the short name of - the function will be used instead. For example, __qualname__ will display as ``Client.login`` while __name__ - will display as ``login``. Defaults to ``True``. - """ - - def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]: - @functools.wraps(func) - def decorated(*args: P.args, **kwargs: P.kwargs) -> T: - warn_deprecated( - name=func.__qualname__ if use_qualname else func.__name__, - instead=instead, - since=since, - removed=removed, - reference=reference, - stacklevel=stacklevel, - ) - return func(*args, **kwargs) - - return decorated - - return actual_decorator - - -def oauth_url( - client_id: int | str, - *, - permissions: Permissions | Undefined = MISSING, - guild: Snowflake | Undefined = MISSING, - redirect_uri: str | Undefined = MISSING, - scopes: Iterable[str] | Undefined = MISSING, - disable_guild_select: bool = False, -) -> str: - """A helper function that returns the OAuth2 URL for inviting the bot - into guilds. - - Parameters - ---------- - client_id: Union[:class:`int`, :class:`str`] - The client ID for your bot. - permissions: :class:`~discord.Permissions` - The permissions you're requesting. If not given then you won't be requesting any - permissions. - guild: :class:`~discord.abc.Snowflake` - The guild to pre-select in the authorization screen, if available. - redirect_uri: :class:`str` - An optional valid redirect URI. - scopes: Iterable[:class:`str`] - An optional valid list of scopes. Defaults to ``('bot',)``. - - .. versionadded:: 1.7 - disable_guild_select: :class:`bool` - Whether to disallow the user from changing the guild dropdown. - - .. versionadded:: 2.0 - - Returns - ------- - :class:`str` - The OAuth2 URL for inviting the bot into guilds. - """ - url = f"https://discord.com/oauth2/authorize?client_id={client_id}" - url += f"&scope={'+'.join(scopes or ('bot',))}" - if permissions is not MISSING: - url += f"&permissions={permissions.value}" - if guild is not MISSING: - url += f"&guild_id={guild.id}" - if redirect_uri is not MISSING: - from urllib.parse import urlencode # noqa: PLC0415 - - url += f"&response_type=code&{urlencode({'redirect_uri': redirect_uri})}" - if disable_guild_select: - url += "&disable_guild_select=true" - return url - - -def snowflake_time(id: int) -> datetime.datetime: - """Converts a Discord snowflake ID to a UTC-aware datetime object. - - Parameters - ---------- - id: :class:`int` - The snowflake ID. - - Returns - ------- - :class:`datetime.datetime` - An aware datetime in UTC representing the creation time of the snowflake. - """ - timestamp = ((id >> 22) + DISCORD_EPOCH) / 1000 - return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) - - -def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: - """Returns a numeric snowflake pretending to be created at the given date. - - When using as the lower end of a range, use ``time_snowflake(high=False) - 1`` - to be inclusive, ``high=True`` to be exclusive. - - When using as the higher end of a range, use ``time_snowflake(high=True) + 1`` - to be inclusive, ``high=False`` to be exclusive - - Parameters - ---------- - dt: :class:`datetime.datetime` - A datetime object to convert to a snowflake. - If naive, the timezone is assumed to be local time. - high: :class:`bool` - Whether to set the lower 22 bit to high or low. - - Returns - ------- - :class:`int` - The snowflake representing the time given. - """ - discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) - return (discord_millis << 22) + (2**22 - 1 if high else 0) - - -def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> T | None: - """A helper to return the first element found in the sequence - that meets the predicate. For example: :: - - member = discord.utils.find(lambda m: m.name == "Mighty", channel.guild.members) - - would find the first :class:`~discord.Member` whose name is 'Mighty' and return it. - If an entry is not found, then ``None`` is returned. - - This is different from :func:`py:filter` due to the fact it stops the moment it finds - a valid entry. - - Parameters - ---------- - predicate - A function that returns a boolean-like result. - seq: :class:`collections.abc.Iterable` - The iterable to search through. - """ - - for element in seq: - if predicate(element): - return element - return None - - -def get(iterable: Iterable[T], **attrs: Any) -> T | None: - r"""A helper that returns the first element in the iterable that meets - all the traits passed in ``attrs``. This is an alternative for - :func:`~discord.utils.find`. - - When multiple attributes are specified, they are checked using - logical AND, not logical OR. Meaning they have to meet every - attribute passed in and not one of them. - - To have a nested attribute search (i.e. search by ``x.y``) then - pass in ``x__y`` as the keyword argument. - - If nothing is found that matches the attributes passed, then - ``None`` is returned. - - Examples - --------- - - Basic usage: - - .. code-block:: python3 - - member = discord.utils.get(message.guild.members, name="Foo") - - Multiple attribute matching: - - .. code-block:: python3 - - channel = discord.utils.get(guild.voice_channels, name="Foo", bitrate=64000) - - Nested attribute matching: - - .. code-block:: python3 - - channel = discord.utils.get(client.get_all_channels(), guild__name="Cool", name="general") - - Parameters - ----------- - iterable - An iterable to search through. - \*\*attrs - Keyword arguments that denote attributes to search with. - """ - - # global -> local - _all = all - attrget = attrgetter - - # Special case the single element call - if len(attrs) == 1: - k, v = attrs.popitem() - pred = attrget(k.replace("__", ".")) - for elem in iterable: - if pred(elem) == v: - return elem - return None - - converted = [(attrget(attr.replace("__", ".")), value) for attr, value in attrs.items()] - - for elem in iterable: - if _all(pred(elem) == value for pred, value in converted): - return elem - return None - - -async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any: - """|coro| - - Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it. - If the fetch also fails, an error will be raised. - - Parameters - ---------- - obj: Any - The object to use the get or fetch methods in - attr: :class:`str` - The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute. - id: :class:`int` - The ID of the object - default: Any - The default value to return if the object is not found, instead of raising an error. - - Returns - ------- - Any - The object found or the default value. - - Raises - ------ - :exc:`AttributeError` - The object is missing a ``get_`` or ``fetch_`` method - :exc:`NotFound` - Invalid ID for the object - :exc:`HTTPException` - An error occurred fetching the object - :exc:`Forbidden` - You do not have permission to fetch the object - - Examples - -------- - - Getting a guild from a guild ID: :: - - guild = await utils.get_or_fetch(client, "guild", guild_id) - - Getting a channel from the guild. If the channel is not found, return None: :: - - channel = await utils.get_or_fetch(guild, "channel", channel_id, default=None) - """ - getter = getattr(obj, f"get_{attr}")(id) - if getter is None: - try: - getter = await getattr(obj, f"fetch_{attr}")(id) - except AttributeError: - getter = await getattr(obj, f"_fetch_{attr}")(id) - if getter is None: - raise ValueError(f"Could not find {attr} with id {id} on {obj}") - except (HTTPException, ValueError): - if default is not MISSING: - return default - else: - raise - return getter - - -def _unique(iterable: Iterable[T]) -> list[T]: - return [x for x in dict.fromkeys(iterable)] - - -def _get_as_snowflake(data: Any, key: str) -> int | None: - try: - value = data[key] - except KeyError: - return None - else: - return value and int(value) - - -def _get_mime_type_for_image(data: bytes): - if data.startswith(b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"): - return "image/png" - elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"): - return "image/jpeg" - elif data.startswith((b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61")): - return "image/gif" - elif data.startswith(b"RIFF") and data[8:12] == b"WEBP": - return "image/webp" - else: - raise InvalidArgument("Unsupported image type given") - - -def _bytes_to_base64_data(data: bytes) -> str: - fmt = "data:{mime};base64,{data}" - mime = _get_mime_type_for_image(data) - b64 = b64encode(data).decode("ascii") - return fmt.format(mime=mime, data=b64) - - -if HAS_MSGSPEC: - - def _to_json(obj: Any) -> str: # type: ignore - return msgspec.json.encode(obj).decode("utf-8") - - _from_json = msgspec.json.decode # type: ignore - -else: - - def _to_json(obj: Any) -> str: - return json.dumps(obj, separators=(",", ":"), ensure_ascii=True) - - _from_json = json.loads - - -def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: - reset_after: str | None = request.headers.get("X-Ratelimit-Reset-After") - if not use_clock and reset_after: - return float(reset_after) - utc = datetime.timezone.utc - now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc) - return (reset - now).total_seconds() - - -async def maybe_coroutine(f, *args, **kwargs): - value = f(*args, **kwargs) - if _isawaitable(value): - return await value - else: - return value - - -async def async_all(gen, *, check=_isawaitable): - for elem in gen: - if check(elem): - elem = await elem - if not elem: - return False - return True - - -async def sane_wait_for(futures, *, timeout): - ensured = [asyncio.ensure_future(fut) for fut in futures] - done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) - - if len(pending) != 0: - raise asyncio.TimeoutError() - - return done - - -def get_slots(cls: type[Any]) -> Iterator[str]: - for mro in reversed(cls.__mro__): - try: - yield from mro.__slots__ - except AttributeError: - continue - - -def compute_timedelta(dt: datetime.datetime): - if dt.tzinfo is None: - dt = dt.astimezone() - now = datetime.datetime.now(datetime.timezone.utc) - return max((dt - now).total_seconds(), 0) - - -async def sleep_until(when: datetime.datetime, result: T | None = None) -> T | None: - """|coro| - - Sleep until a specified time. - - If the time supplied is in the past this function will yield instantly. - - .. versionadded:: 1.3 - - Parameters - ---------- - when: :class:`datetime.datetime` - The timestamp in which to sleep until. If the datetime is naive then - it is assumed to be local time. - result: Any - If provided is returned to the caller when the coroutine completes. - """ - delta = compute_timedelta(when) - return await asyncio.sleep(delta, result) - - -def utcnow() -> datetime.datetime: - """A helper function to return an aware UTC datetime representing the current time. - - This should be preferred to :meth:`datetime.datetime.utcnow` since it is an aware - datetime, compared to the naive datetime in the standard library. - - .. versionadded:: 2.0 - - Returns - ------- - :class:`datetime.datetime` - The current aware datetime in UTC. - """ - return datetime.datetime.now(datetime.timezone.utc) - - -def valid_icon_size(size: int) -> bool: - """Icons must be power of 2 within [16, 4096].""" - return not size & (size - 1) and 4096 >= size >= 16 - - -class SnowflakeList(array.array): - """Internal data storage class to efficiently store a list of snowflakes. - - This should have the following characteristics: - - - Low memory usage - - O(n) iteration (obviously) - - O(n log n) initial creation if data is unsorted - - O(log n) search and indexing - - O(n) insertion - """ - - __slots__ = () - - if TYPE_CHECKING: - - def __init__(self, data: Iterable[int], *, is_sorted: bool = False): ... - - def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): - return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore - - def add(self, element: int) -> None: - i = bisect_left(self, element) - self.insert(i, element) - - def get(self, element: int) -> int | None: - i = bisect_left(self, element) - return self[i] if i != len(self) and self[i] == element else None - - def has(self, element: int) -> bool: - i = bisect_left(self, element) - return i != len(self) and self[i] == element - - -_IS_ASCII = re.compile(r"^[\x00-\x7f]+$") - - -def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int: - """Returns string's width.""" - match = _IS_ASCII.match(string) - if match: - return match.endpos - - UNICODE_WIDE_CHAR_TYPE = "WFA" - func = unicodedata.east_asian_width - return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string) - - -def resolve_invite(invite: Invite | str) -> str: - """ - Resolves an invite from a :class:`~discord.Invite`, URL or code. - - Parameters - ---------- - invite: Union[:class:`~discord.Invite`, :class:`str`] - The invite. - - Returns - ------- - :class:`str` - The invite code. - """ - from .invite import Invite # circular import # noqa: PLC0415 - - if isinstance(invite, Invite): - return invite.code - rx = r"(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)" - m = re.match(rx, invite) - if m: - return m.group(1) - return invite - - -def resolve_template(code: Template | str) -> str: - """ - Resolves a template code from a :class:`~discord.Template`, URL or code. - - .. versionadded:: 1.4 - - Parameters - ---------- - code: Union[:class:`~discord.Template`, :class:`str`] - The code. - - Returns - ------- - :class:`str` - The template code. - """ - from .template import Template # circular import # noqa: PLC0415 - - if isinstance(code, Template): - return code.code - rx = r"(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)" - m = re.match(rx, code) - if m: - return m.group(1) - return code - - -_MARKDOWN_ESCAPE_SUBREGEX = "|".join(r"\{0}(?=([\s\S]*((?(?:>>)?\s|{_MARKDOWN_ESCAPE_LINKS}" - -_MARKDOWN_ESCAPE_REGEX = re.compile( - rf"(?P{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})", - re.MULTILINE | re.X, -) - -_URL_REGEX = r"(?P<[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])" - -_MARKDOWN_STOCK_REGEX = rf"(?P[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})" - - -def remove_markdown(text: str, *, ignore_links: bool = True) -> str: - """A helper function that removes markdown characters. - - .. versionadded:: 1.7 - - .. note:: - This function is not markdown aware and may remove meaning from the original text. For example, - if the input contains ``10 * 5`` then it will be converted into ``10 5``. - - Parameters - ---------- - text: :class:`str` - The text to remove markdown from. - ignore_links: :class:`bool` - Whether to leave links alone when removing markdown. For example, - if a URL in the text contains characters such as ``_`` then it will - be left alone. Defaults to ``True``. - - Returns - ------- - :class:`str` - The text with the markdown special characters removed. - """ - - def replacement(match): - groupdict = match.groupdict() - return groupdict.get("url", "") - - regex = _MARKDOWN_STOCK_REGEX - if ignore_links: - regex = f"(?:{_URL_REGEX}|{regex})" - return re.sub(regex, replacement, text, count=0, flags=re.MULTILINE) - - -def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str: - r"""A helper function that escapes Discord's markdown. - - Parameters - ----------- - text: :class:`str` - The text to escape markdown from. - as_needed: :class:`bool` - Whether to escape the markdown characters as needed. This - means that it does not escape extraneous characters if it's - not necessary, e.g. ``**hello**`` is escaped into ``\*\*hello**`` - instead of ``\*\*hello\*\*``. Note however that this can open - you up to some clever syntax abuse. Defaults to ``False``. - ignore_links: :class:`bool` - Whether to leave links alone when escaping markdown. For example, - if a URL in the text contains characters such as ``_`` then it will - be left alone. This option is not supported with ``as_needed``. - Defaults to ``True``. - - Returns - -------- - :class:`str` - The text with the markdown special characters escaped with a slash. - """ - - if not as_needed: - - def replacement(match): - groupdict = match.groupdict() - is_url = groupdict.get("url") - if is_url: - return is_url - return f"\\{groupdict['markdown']}" - - regex = _MARKDOWN_STOCK_REGEX - if ignore_links: - regex = f"(?:{_URL_REGEX}|{regex})" - return re.sub(regex, replacement, text, count=0, flags=re.MULTILINE | re.X) - else: - text = re.sub(r"\\", r"\\\\", text) - return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text) - - -def escape_mentions(text: str) -> str: - """A helper function that escapes everyone, here, role, and user mentions. - - .. note:: - - This does not include channel mentions. - - .. note:: - - For more granular control over what mentions should be escaped - within messages, refer to the :class:`~discord.AllowedMentions` - class. - - Parameters - ---------- - text: :class:`str` - The text to escape mentions from. - - Returns - ------- - :class:`str` - The text with the mentions removed. - """ - return re.sub(r"@(everyone|here|[!&]?[0-9]{17,20})", "@\u200b\\1", text) - - -def raw_mentions(text: str) -> list[int]: - """Returns a list of user IDs matching ``<@user_id>`` in the string. - - .. versionadded:: 2.2 - - Parameters - ---------- - text: :class:`str` - The string to get user mentions from. - - Returns - ------- - List[:class:`int`] - A list of user IDs found in the string. - """ - return [int(x) for x in re.findall(r"<@!?([0-9]+)>", text)] - - -def raw_channel_mentions(text: str) -> list[int]: - """Returns a list of channel IDs matching ``<@#channel_id>`` in the string. - - .. versionadded:: 2.2 - - Parameters - ---------- - text: :class:`str` - The string to get channel mentions from. - - Returns - ------- - List[:class:`int`] - A list of channel IDs found in the string. - """ - return [int(x) for x in re.findall(r"<#([0-9]+)>", text)] - - -def raw_role_mentions(text: str) -> list[int]: - """Returns a list of role IDs matching ``<@&role_id>`` in the string. - - .. versionadded:: 2.2 - - Parameters - ---------- - text: :class:`str` - The string to get role mentions from. - - Returns - ------- - List[:class:`int`] - A list of role IDs found in the string. - """ - return [int(x) for x in re.findall(r"<@&([0-9]+)>", text)] - - -def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[list[T]]: - ret = [] - n = 0 - for item in iterator: - ret.append(item) - n += 1 - if n == max_size: - yield ret - ret = [] - n = 0 - if ret: - yield ret - - -async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[list[T]]: - ret = [] - n = 0 - async for item in iterator: - ret.append(item) - n += 1 - if n == max_size: - yield ret - ret = [] - n = 0 - if ret: - yield ret - - -@overload -def as_chunks(iterator: Iterator[T], max_size: int) -> Iterator[list[T]]: ... - - -@overload -def as_chunks(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[list[T]]: ... - - -def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[list[T]]: - """A helper function that collects an iterator into chunks of a given size. - - .. versionadded:: 2.0 - - .. warning:: - - The last chunk collected may not be as large as ``max_size``. - - Parameters - ---------- - iterator: Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`] - The iterator to chunk, can be sync or async. - max_size: :class:`int` - The maximum chunk size. - - Returns - ------- - Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`] - A new iterator which yields chunks of a given size. - """ - if max_size <= 0: - raise ValueError("Chunk sizes must be greater than 0.") - - if isinstance(iterator, AsyncIterator): - return _achunk(iterator, max_size) - return _chunk(iterator, max_size) - - -PY_310 = sys.version_info >= (3, 10) - - -def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]: - params = [] - literal_cls = type(Literal[0]) - for p in parameters: - if isinstance(p, literal_cls): - params.extend(p.__args__) - else: - params.append(p) - return tuple(params) - - -def normalise_optional_params(parameters: Iterable[Any]) -> tuple[Any, ...]: - none_cls = type(None) - return tuple(p for p in parameters if p is not none_cls) + (none_cls,) - - -def evaluate_annotation( - tp: Any, - globals: dict[str, Any], - locals: dict[str, Any], - cache: dict[str, Any], - *, - implicit_str: bool = True, -): - if isinstance(tp, ForwardRef): - tp = tp.__forward_arg__ - # ForwardRefs always evaluate their internals - implicit_str = True - - if implicit_str and isinstance(tp, str): - if tp in cache: - return cache[tp] - evaluated = eval(tp, globals, locals) - cache[tp] = evaluated - return evaluate_annotation(evaluated, globals, locals, cache) - - if hasattr(tp, "__args__"): - implicit_str = True - is_literal = False - args = tp.__args__ - if not hasattr(tp, "__origin__"): - if PY_310 and tp.__class__ is types.UnionType: # type: ignore - converted = Union[args] # type: ignore - return evaluate_annotation(converted, globals, locals, cache) - - return tp - if tp.__origin__ is Union: - try: - if args.index(type(None)) != len(args) - 1: - args = normalise_optional_params(tp.__args__) - except ValueError: - pass - if tp.__origin__ is Literal: - if not PY_310: - args = flatten_literal_params(tp.__args__) - implicit_str = False - is_literal = True - - evaluated_args = tuple( - evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args - ) - - if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): - raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") - - if evaluated_args == args: - return tp - - try: - return tp.copy_with(evaluated_args) - except AttributeError: - return tp.__origin__[evaluated_args] - - return tp - - -def resolve_annotation( - annotation: Any, - globalns: dict[str, Any], - localns: dict[str, Any] | None, - cache: dict[str, Any] | None, -) -> Any: - if annotation is None: - return type(None) - if isinstance(annotation, str): - annotation = ForwardRef(annotation) - - locals = globalns if localns is None else localns - if cache is None: - cache = {} - return evaluate_annotation(annotation, globalns, locals, cache) - - -TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"] - - -def format_dt(dt: datetime.datetime | datetime.time, /, style: TimestampStyle | None = None) -> str: - """A helper function to format a :class:`datetime.datetime` for presentation within Discord. - - This allows for a locale-independent way of presenting data using Discord specific Markdown. - - +-------------+----------------------------+-----------------+ - | Style | Example Output | Description | - +=============+============================+=================+ - | t | 22:57 | Short Time | - +-------------+----------------------------+-----------------+ - | T | 22:57:58 | Long Time | - +-------------+----------------------------+-----------------+ - | d | 17/05/2016 | Short Date | - +-------------+----------------------------+-----------------+ - | D | 17 May 2016 | Long Date | - +-------------+----------------------------+-----------------+ - | f (default) | 17 May 2016 22:57 | Short Date Time | - +-------------+----------------------------+-----------------+ - | F | Tuesday, 17 May 2016 22:57 | Long Date Time | - +-------------+----------------------------+-----------------+ - | R | 5 years ago | Relative Time | - +-------------+----------------------------+-----------------+ - - Note that the exact output depends on the user's locale setting in the client. The example output - presented is using the ``en-GB`` locale. - - .. versionadded:: 2.0 - - Parameters - ---------- - dt: Union[:class:`datetime.datetime`, :class:`datetime.time`] - The datetime to format. - style: :class:`str` - The style to format the datetime with. - - Returns - ------- - :class:`str` - The formatted string. - """ - if isinstance(dt, datetime.time): - dt = datetime.datetime.combine(datetime.datetime.now(), dt) - if style is None: - return f"" - return f"" - - -def generate_snowflake(dt: datetime.datetime | None = None) -> int: - """Returns a numeric snowflake pretending to be created at the given date but more accurate and random - than :func:`time_snowflake`. If dt is not passed, it makes one from the current time using utcnow. - - Parameters - ---------- - dt: :class:`datetime.datetime` - A datetime object to convert to a snowflake. - If naive, the timezone is assumed to be local time. - - Returns - ------- - :class:`int` - The snowflake representing the time given. - """ - - dt = dt or utcnow() - return int(dt.timestamp() * 1000 - DISCORD_EPOCH) << 22 | 0x3FFFFF - - -V = Union[Iterable[OptionChoice], Iterable[str], Iterable[int], Iterable[float]] -AV = Awaitable[V] -Values = Union[V, Callable[[AutocompleteContext], Union[V, AV]], AV] -AutocompleteFunc = Callable[[AutocompleteContext], AV] -FilterFunc = Callable[[AutocompleteContext, Any], Union[bool, Awaitable[bool]]] - - -def basic_autocomplete(values: Values, *, filter: FilterFunc | None = None) -> AutocompleteFunc: - """A helper function to make a basic autocomplete for slash commands. This is a pretty standard autocomplete and - will return any options that start with the value from the user, case-insensitive. If the ``values`` parameter is - callable, it will be called with the AutocompleteContext. - - This is meant to be passed into the :attr:`discord.Option.autocomplete` attribute. - - Parameters - ---------- - values: Union[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Callable[[:class:`.AutocompleteContext`], Union[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] - Possible values for the option. Accepts an iterable of :class:`str`, a callable (sync or async) that takes a - single argument of :class:`.AutocompleteContext`, or a coroutine. Must resolve to an iterable of :class:`str`. - filter: Optional[Callable[[:class:`.AutocompleteContext`, Any], Union[:class:`bool`, Awaitable[:class:`bool`]]]] - An optional callable (sync or async) used to filter the autocomplete options. It accepts two arguments: - the :class:`.AutocompleteContext` and an item from ``values`` iteration treated as callback parameters. If ``None`` is provided, a default filter is used that includes items whose string representation starts with the user's input value, case-insensitive. - - .. versionadded:: 2.7 - - Returns - ------- - Callable[[:class:`.AutocompleteContext`], Awaitable[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] - A wrapped callback for the autocomplete. - - Examples - -------- - - Basic usage: - - .. code-block:: python3 - - Option(str, "color", autocomplete=basic_autocomplete(("red", "green", "blue"))) - - # or - - - async def autocomplete(ctx): - return "foo", "bar", "baz", ctx.interaction.user.name - - - Option(str, "name", autocomplete=basic_autocomplete(autocomplete)) - - With filter parameter: - - .. code-block:: python3 - - Option( - str, - "color", - autocomplete=basic_autocomplete(("red", "green", "blue"), filter=lambda c, i: str(c.value or "") in i), - ) - - .. versionadded:: 2.0 - - Note - ---- - Autocomplete cannot be used for options that have specified choices. - """ - - async def autocomplete_callback(ctx: AutocompleteContext) -> V: - _values = values # since we reassign later, python considers it local if we don't do this - - if callable(_values): - _values = _values(ctx) - if asyncio.iscoroutine(_values): - _values = await _values - - if filter is None: - - def _filter(ctx: AutocompleteContext, item: Any) -> bool: - item = getattr(item, "name", item) - return str(item).lower().startswith(str(ctx.value or "").lower()) - - gen = (val for val in _values if _filter(ctx, val)) - - elif asyncio.iscoroutinefunction(filter): - gen = (val for val in _values if await filter(ctx, val)) - - elif callable(filter): - gen = (val for val in _values if filter(ctx, val)) - - else: - raise TypeError("``filter`` must be callable.") - - return iter(itertools.islice(gen, 25)) - - return autocomplete_callback - - -def filter_params(params, **kwargs): - """A helper function to filter out and replace certain keyword parameters - - Parameters - ---------- - params: Dict[str, Any] - The initial parameters to filter. - **kwargs: Dict[str, Optional[str]] - Key to value pairs where the key's contents would be moved to the - value, or if the value is None, remove key's contents (see code example). - - Example - ------- - .. code-block:: python3 - - >>> params = {"param1": 12, "param2": 13} - >>> filter_params(params, param1="param3", param2=None) - {'param3': 12} - # values of 'param1' is moved to 'param3' - # and values of 'param2' are completely removed. - """ - for old_param, new_param in kwargs.items(): - if old_param in params: - if new_param is None: - params.pop(old_param) - else: - params[new_param] = params.pop(old_param) - - return params diff --git a/discord/utils/__init__.py b/discord/utils/__init__.py new file mode 100644 index 0000000000..41e39a9a6c --- /dev/null +++ b/discord/utils/__init__.py @@ -0,0 +1,137 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import ( + Any, + AsyncIterator, + Iterator, + Mapping, + Protocol, + TypeVar, + Union, +) + +from ..errors import HTTPException +from .public import ( + basic_autocomplete, + generate_snowflake, + utcnow, + find, + snowflake_time, + oauth_url, + Undefined, + MISSING, + format_dt, + escape_mentions, + raw_mentions, + raw_channel_mentions, + raw_role_mentions, + remove_markdown, + escape_markdown, +) + +DISCORD_EPOCH = 1420070400000 + + +__all__ = ( + "oauth_url", + "snowflake_time", + "find", + "get_or_fetch", + "utcnow", + "remove_markdown", + "escape_markdown", + "escape_mentions", + "raw_mentions", + "raw_channel_mentions", + "raw_role_mentions", + "format_dt", + "generate_snowflake", + "basic_autocomplete", + "Undefined", + "MISSING", +) + + +async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any: + """|coro| + + Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it. + If the fetch also fails, an error will be raised. + + Parameters + ---------- + obj: Any + The object to use the get or fetch methods in + attr: :class:`str` + The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute. + id: :class:`int` + The ID of the object + default: Any + The default value to return if the object is not found, instead of raising an error. + + Returns + ------- + Any + The object found or the default value. + + Raises + ------ + :exc:`AttributeError` + The object is missing a ``get_`` or ``fetch_`` method + :exc:`NotFound` + Invalid ID for the object + :exc:`HTTPException` + An error occurred fetching the object + :exc:`Forbidden` + You do not have permission to fetch the object + + Examples + -------- + + Getting a guild from a guild ID: :: + + guild = await utils.get_or_fetch(client, "guild", guild_id) + + Getting a channel from the guild. If the channel is not found, return None: :: + + channel = await utils.get_or_fetch(guild, "channel", channel_id, default=None) + """ + getter = getattr(obj, f"get_{attr}")(id) + if getter is None: + try: + getter = await getattr(obj, f"fetch_{attr}")(id) + except AttributeError: + getter = await getattr(obj, f"_fetch_{attr}")(id) + if getter is None: + raise ValueError(f"Could not find {attr} with id {id} on {obj}") + except (HTTPException, ValueError): + if default is not MISSING: + return default + else: + raise + return getter diff --git a/discord/utils/private.py b/discord/utils/private.py new file mode 100644 index 0000000000..491bb19e1e --- /dev/null +++ b/discord/utils/private.py @@ -0,0 +1,549 @@ +from __future__ import annotations + +import array +import asyncio +import collections.abc +import datetime +import functools +import re +import sys +import types +import unicodedata +import warnings +from _bisect import bisect_left +from base64 import b64encode +from inspect import isawaitable, signature +from typing import ( + TYPE_CHECKING, + Any, + overload, + Callable, + TypeVar, + ParamSpec, + Iterable, + Literal, + ForwardRef, + Union, + Coroutine, + Awaitable, + reveal_type, + Generic, + Sequence, + Iterator, +) + +from ..errors import InvalidArgument, HTTPException + +if TYPE_CHECKING: + from ..invite import Invite + from ..template import Template + +_IS_ASCII = re.compile(r"^[\x00-\x7f]+$") + +P = ParamSpec("P") +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) + + +def resolve_invite(invite: Invite | str) -> str: + """ + Resolves an invite from a :class:`~discord.Invite`, URL or code. + + Parameters + ---------- + invite: Union[:class:`~discord.Invite`, :class:`str`] + The invite. + + Returns + ------- + :class:`str` + The invite code. + """ + from ..invite import Invite # noqa: PLC0415 # circular import + + if isinstance(invite, Invite): + return invite.code + rx = r"(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)" + m = re.match(rx, invite) + if m: + return m.group(1) + return invite + + +__all__ = ("resolve_invite",) + + +def get_as_snowflake(data: Any, key: str) -> int | None: + try: + value = data[key] + except KeyError: + return None + else: + return value and int(value) + + +def get_mime_type_for_image(data: bytes): + if data.startswith(b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"): + return "image/png" + elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"): + return "image/jpeg" + elif data.startswith((b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61")): + return "image/gif" + elif data.startswith(b"RIFF") and data[8:12] == b"WEBP": + return "image/webp" + else: + raise InvalidArgument("Unsupported image type given") + + +def bytes_to_base64_data(data: bytes) -> str: + fmt = "data:{mime};base64,{data}" + mime = get_mime_type_for_image(data) + b64 = b64encode(data).decode("ascii") + return fmt.format(mime=mime, data=b64) + + +def parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: + reset_after: str | None = request.headers.get("X-Ratelimit-Reset-After") + if not use_clock and reset_after: + return float(reset_after) + utc = datetime.timezone.utc + now = datetime.datetime.now(utc) + reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc) + return (reset - now).total_seconds() + + +def string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int: + """Returns string's width.""" + match = _IS_ASCII.match(string) + if match: + return match.endpos + + UNICODE_WIDE_CHAR_TYPE = "WFA" + func = unicodedata.east_asian_width + return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string) + + +def resolve_template(code: Template | str) -> str: + """ + Resolves a template code from a :class:`~discord.Template`, URL or code. + + .. versionadded:: 1.4 + + Parameters + ---------- + code: Union[:class:`~discord.Template`, :class:`str`] + The code. + + Returns + ------- + :class:`str` + The template code. + """ + from ..template import Template # noqa: PLC0415 # circular import + + if isinstance(code, Template): + return code.code + rx = r"(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)" + m = re.match(rx, code) + if m: + return m.group(1) + return code + + +__all__ = ( + "resolve_invite", + "get_as_snowflake", + "get_mime_type_for_image", + "bytes_to_base64_data", + "parse_ratelimit_header", + "string_width", +) + + +@overload +def parse_time(timestamp: None) -> None: ... + + +@overload +def parse_time(timestamp: str) -> datetime.datetime: ... + + +@overload +def parse_time(timestamp: str | None) -> datetime.datetime | None: ... + + +def parse_time(timestamp: str | None) -> datetime.datetime | None: + """A helper function to convert an ISO 8601 timestamp to a datetime object. + + Parameters + ---------- + timestamp: Optional[:class:`str`] + The timestamp to convert. + + Returns + ------- + Optional[:class:`datetime.datetime`] + The converted datetime object. + """ + if timestamp: + return datetime.datetime.fromisoformat(timestamp) + return None + + +def warn_deprecated( + name: str, + instead: str | None = None, + since: str | None = None, + removed: str | None = None, + reference: str | None = None, + stacklevel: int = 3, +) -> None: + """Warn about a deprecated function, with the ability to specify details about the deprecation. Emits a + DeprecationWarning. + + Parameters + ---------- + name: str + The name of the deprecated function. + instead: Optional[:class:`str`] + A recommended alternative to the function. + since: Optional[:class:`str`] + The version in which the function was deprecated. This should be in the format ``major.minor(.patch)``, where + the patch version is optional. + removed: Optional[:class:`str`] + The version in which the function is planned to be removed. This should be in the format + ``major.minor(.patch)``, where the patch version is optional. + reference: Optional[:class:`str`] + A reference that explains the deprecation, typically a URL to a page such as a changelog entry or a GitHub + issue/PR. + stacklevel: :class:`int` + The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3. + """ + warnings.simplefilter("always", DeprecationWarning) # turn off filter + message = f"{name} is deprecated" + if since: + message += f" since version {since}" + if removed: + message += f" and will be removed in version {removed}" + if instead: + message += f", consider using {instead} instead" + message += "." + if reference: + message += f" See {reference} for more information." + + warnings.warn(message, stacklevel=stacklevel, category=DeprecationWarning) + warnings.simplefilter("default", DeprecationWarning) # reset filter + + +def deprecated( + instead: str | None = None, + since: str | None = None, + removed: str | None = None, + reference: str | None = None, + stacklevel: int = 3, + *, + use_qualname: bool = True, +) -> Callable[[Callable[[P], T]], Callable[[P], T]]: + """A decorator implementation of :func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when + the decorated function is called. + + Parameters + ---------- + instead: Optional[:class:`str`] + A recommended alternative to the function. + since: Optional[:class:`str`] + The version in which the function was deprecated. This should be in the format ``major.minor(.patch)``, where + the patch version is optional. + removed: Optional[:class:`str`] + The version in which the function is planned to be removed. This should be in the format + ``major.minor(.patch)``, where the patch version is optional. + reference: Optional[:class:`str`] + A reference that explains the deprecation, typically a URL to a page such as a changelog entry or a GitHub + issue/PR. + stacklevel: :class:`int` + The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3. + use_qualname: :class:`bool` + Whether to use the qualified name of the function in the deprecation warning. If ``False``, the short name of + the function will be used instead. For example, __qualname__ will display as ``Client.login`` while __name__ + will display as ``login``. Defaults to ``True``. + """ + + def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]: + @functools.wraps(func) + def decorated(*args: P.args, **kwargs: P.kwargs) -> T: + warn_deprecated( + name=func.__qualname__ if use_qualname else func.__name__, + instead=instead, + since=since, + removed=removed, + reference=reference, + stacklevel=stacklevel, + ) + return func(*args, **kwargs) + + return decorated + + return actual_decorator + + +PY_310 = sys.version_info >= (3, 10) + + +def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]: + params = [] + literal_cls = type(Literal[0]) + for p in parameters: + if isinstance(p, literal_cls): + params.extend(p.__args__) + else: + params.append(p) + return tuple(params) + + +def normalise_optional_params(parameters: Iterable[Any]) -> tuple[Any, ...]: + none_cls = type(None) + return tuple(p for p in parameters if p is not none_cls) + (none_cls,) + + +def evaluate_annotation( + tp: Any, + globals: dict[str, Any], + locals: dict[str, Any], + cache: dict[str, Any], + *, + implicit_str: bool = True, +): + if isinstance(tp, ForwardRef): + tp = tp.__forward_arg__ + # ForwardRefs always evaluate their internals + implicit_str = True + + if implicit_str and isinstance(tp, str): + if tp in cache: + return cache[tp] + evaluated = eval(tp, globals, locals) + cache[tp] = evaluated + return evaluate_annotation(evaluated, globals, locals, cache) + + if hasattr(tp, "__args__"): + implicit_str = True + is_literal = False + args = tp.__args__ + if not hasattr(tp, "__origin__"): + if PY_310 and tp.__class__ is types.UnionType: # type: ignore + converted = Union[args] # type: ignore + return evaluate_annotation(converted, globals, locals, cache) + + return tp + if tp.__origin__ is Union: + try: + if args.index(type(None)) != len(args) - 1: + args = normalise_optional_params(tp.__args__) + except ValueError: + pass + if tp.__origin__ is Literal: + if not PY_310: + args = flatten_literal_params(tp.__args__) + implicit_str = False + is_literal = True + + evaluated_args = tuple( + evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args + ) + + if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): + raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") + + if evaluated_args == args: + return tp + + try: + return tp.copy_with(evaluated_args) + except AttributeError: + return tp.__origin__[evaluated_args] + + return tp + + +def resolve_annotation( + annotation: Any, + globalns: dict[str, Any], + localns: dict[str, Any] | None, + cache: dict[str, Any] | None, +) -> Any: + if annotation is None: + return type(None) + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + + locals = globalns if localns is None else localns + if cache is None: + cache = {} + return evaluate_annotation(annotation, globalns, locals, cache) + + +def delay_task(delay: float, func: Coroutine): + async def inner_call(): + await asyncio.sleep(delay) + try: + await func + except HTTPException: + pass + + asyncio.create_task(inner_call()) + + +async def async_all(gen: Iterable[Any]) -> bool: + for elem in gen: + if isawaitable(elem): + elem = await elem + if not elem: + return False + return True + + +async def maybe_awaitable(f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: + value = f(*args, **kwargs) + if isawaitable(value): + reveal_type(f) + return await value + return value + + +async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Future[T]]: + ensured = [asyncio.ensure_future(fut) for fut in futures] + done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) + + if len(pending) != 0: + raise asyncio.TimeoutError() + + return done + + +class SnowflakeList(array.array): + """Internal data storage class to efficiently store a list of snowflakes. + + This should have the following characteristics: + + - Low memory usage + - O(n) iteration (obviously) + - O(n log n) initial creation if data is unsorted + - O(log n) search and indexing + - O(n) insertion + """ + + __slots__ = () + + if TYPE_CHECKING: + + def __init__(self, data: Iterable[int], *, is_sorted: bool = False): ... + + def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): + return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore + + def add(self, element: int) -> None: + i = bisect_left(self, element) + self.insert(i, element) + + def get(self, element: int) -> int | None: + i = bisect_left(self, element) + return self[i] if i != len(self) and self[i] == element else None + + def has(self, element: int) -> bool: + i = bisect_left(self, element) + return i != len(self) and self[i] == element + + +def copy_doc(original: Callable) -> Callable[[T], T]: + def decorator(overridden: T) -> T: + overridden.__doc__ = original.__doc__ + overridden.__signature__ = signature(original) # type: ignore + return overridden + + return decorator + + +class SequenceProxy(collections.abc.Sequence, Generic[T_co]): + """Read-only proxy of a Sequence.""" + + def __init__(self, proxied: Sequence[T_co]): + self.__proxied = proxied + + def __getitem__(self, idx: int) -> T_co: + return self.__proxied[idx] + + def __len__(self) -> int: + return len(self.__proxied) + + def __contains__(self, item: Any) -> bool: + return item in self.__proxied + + def __iter__(self) -> Iterator[T_co]: + return iter(self.__proxied) + + def __reversed__(self) -> Iterator[T_co]: + return reversed(self.__proxied) + + def index(self, value: Any, *args, **kwargs) -> int: + return self.__proxied.index(value, *args, **kwargs) + + def count(self, value: Any) -> int: + return self.__proxied.count(value) + + +class CachedSlotProperty(Generic[T, T_co]): + def __init__(self, name: str, function: Callable[[T], T_co]) -> None: + self.name = name + self.function = function + self.__doc__ = getattr(function, "__doc__") + + @overload + def __get__(self, instance: None, owner: type[T]) -> CachedSlotProperty[T, T_co]: ... + + @overload + def __get__(self, instance: T, owner: type[T]) -> T_co: ... + + def __get__(self, instance: T | None, owner: type[T]) -> Any: + if instance is None: + return self + + try: + return getattr(instance, self.name) + except AttributeError: + value = self.function(instance) + setattr(instance, self.name, value) + return value + + +def get_slots(cls: type[Any]) -> Iterator[str]: + for mro in reversed(cls.__mro__): + try: + yield from mro.__slots__ + except AttributeError: + continue + + +def cached_slot_property( + name: str, +) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]: + def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]: + return CachedSlotProperty(name, func) + + return decorator + + +try: + import msgspec + + def to_json(obj: Any) -> str: # type: ignore + return msgspec.json.encode(obj).decode("utf-8") + + from_json = msgspec.json.decode # type: ignore + +except ModuleNotFoundError: + import json + + def to_json(obj: Any) -> str: + return json.dumps(obj, separators=(",", ":"), ensure_ascii=True) + + from_json = json.loads diff --git a/discord/utils/public.py b/discord/utils/public.py new file mode 100644 index 0000000000..dd476d50d5 --- /dev/null +++ b/discord/utils/public.py @@ -0,0 +1,536 @@ +from __future__ import annotations + +import asyncio +import re +import datetime +from enum import Enum, auto +import itertools +from collections.abc import Awaitable, Callable, Iterable +from typing import TYPE_CHECKING, Any, Literal, TypeVar + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..commands.context import AutocompleteContext + from ..commands.options import OptionChoice + from ..permissions import Permissions + + +T = TypeVar("T") + + +class Undefined(Enum): + MISSING = auto() + + def __bool__(self) -> Literal[False]: + return False + + +MISSING: Literal[Undefined.MISSING] = Undefined.MISSING + +DISCORD_EPOCH = 1420070400000 + + +def utcnow() -> datetime.datetime: + """A helper function to return an aware UTC datetime representing the current time. + + This should be preferred to :meth:`datetime.datetime.utcnow` since it is an aware + datetime, compared to the naive datetime in the standard library. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`datetime.datetime` + The current aware datetime in UTC. + """ + return datetime.datetime.now(datetime.timezone.utc) + + +V = Iterable["OptionChoice"] | Iterable[str] | Iterable[int] | Iterable[float] +AV = Awaitable[V] +Values = V | Callable[["AutocompleteContext"], V | AV] | AV +AutocompleteFunc = Callable[["AutocompleteContext"], AV] +FilterFunc = Callable[["AutocompleteContext", Any], bool | Awaitable[bool]] + + +def basic_autocomplete(values: Values, *, filter: FilterFunc | None = None) -> AutocompleteFunc: + """A helper function to make a basic autocomplete for slash commands. This is a pretty standard autocomplete and + will return any options that start with the value from the user, case-insensitive. If the ``values`` parameter is + callable, it will be called with the AutocompleteContext. + + This is meant to be passed into the :attr:`discord.Option.autocomplete` attribute. + + Parameters + ---------- + values: Union[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Callable[[:class:`.AutocompleteContext`], Union[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] + Possible values for the option. Accepts an iterable of :class:`str`, a callable (sync or async) that takes a + single argument of :class:`.AutocompleteContext`, or a coroutine. Must resolve to an iterable of :class:`str`. + filter: Optional[Callable[[:class:`.AutocompleteContext`, Any], Union[:class:`bool`, Awaitable[:class:`bool`]]]] + An optional callable (sync or async) used to filter the autocomplete options. It accepts two arguments: + the :class:`.AutocompleteContext` and an item from ``values`` iteration treated as callback parameters. If ``None`` is provided, a default filter is used that includes items whose string representation starts with the user's input value, case-insensitive. + + .. versionadded:: 2.7 + + Returns + ------- + Callable[[:class:`.AutocompleteContext`], Awaitable[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] + A wrapped callback for the autocomplete. + + Examples + -------- + + Basic usage: + + .. code-block:: python3 + + Option(str, "color", autocomplete=basic_autocomplete(("red", "green", "blue"))) + + # or + + + async def autocomplete(ctx): + return "foo", "bar", "baz", ctx.interaction.user.name + + + Option(str, "name", autocomplete=basic_autocomplete(autocomplete)) + + With filter parameter: + + .. code-block:: python3 + + Option( + str, + "color", + autocomplete=basic_autocomplete(("red", "green", "blue"), filter=lambda c, i: str(c.value or "") in i), + ) + + .. versionadded:: 2.0 + + Note + ---- + Autocomplete cannot be used for options that have specified choices. + """ + + async def autocomplete_callback(ctx: AutocompleteContext) -> V: + _values = values # since we reassign later, python considers it local if we don't do this + + if callable(_values): + _values = _values(ctx) + if asyncio.iscoroutine(_values): + _values = await _values + + if filter is None: + + def _filter(ctx: AutocompleteContext, item: Any) -> bool: + item = getattr(item, "name", item) + return str(item).lower().startswith(str(ctx.value or "").lower()) + + gen = (val for val in _values if _filter(ctx, val)) + + elif asyncio.iscoroutinefunction(filter): + gen = (val for val in _values if await filter(ctx, val)) + + elif callable(filter): + gen = (val for val in _values if filter(ctx, val)) + + else: + raise TypeError("``filter`` must be callable.") + + return iter(itertools.islice(gen, 25)) + + return autocomplete_callback + + +def generate_snowflake( + dt: datetime.datetime | None = None, + *, + mode: Literal["boundary", "realistic"] = "boundary", + high: bool = False, +) -> int: + """Returns a numeric snowflake pretending to be created at the given date. + + This function can generate both realistic snowflakes (for general use) and + boundary snowflakes (for range queries). + + Parameters + ---------- + dt: :class:`datetime.datetime` + A datetime object to convert to a snowflake. + If naive, the timezone is assumed to be local time. + If None, uses current UTC time. + mode: :class:`str` + The type of snowflake to generate: + - "realistic": Creates a snowflake with random-like lower bits + - "boundary": Creates a snowflake for range queries (default) + high: :class:`bool` + Only used when mode="boundary". Whether to set the lower 22 bits + to high (True) or low (False). Default is False. + + Returns + ------- + :class:`int` + The snowflake representing the time given. + + Examples + -------- + # Generate realistic snowflake + snowflake = generate_snowflake(dt) + + # Generate boundary snowflakes + lower_bound = generate_snowflake(dt, mode="boundary", high=False) + upper_bound = generate_snowflake(dt, mode="boundary", high=True) + + # For inclusive ranges: + # Lower: generate_snowflake(dt, mode="boundary", high=False) - 1 + # Upper: generate_snowflake(dt, mode="boundary", high=True) + 1 + """ + dt = dt or utcnow() + discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) + + if mode == "realistic": + return (discord_millis << 22) | 0x3FFFFF + elif mode == "boundary": + return (discord_millis << 22) + (2**22 - 1 if high else 0) + else: + raise ValueError(f"Invalid mode '{mode}'. Must be 'realistic' or 'boundary'") + + +def snowflake_time(id: int) -> datetime.datetime: + """Converts a Discord snowflake ID to a UTC-aware datetime object. + + Parameters + ---------- + id: :class:`int` + The snowflake ID. + + Returns + ------- + :class:`datetime.datetime` + An aware datetime in UTC representing the creation time of the snowflake. + """ + timestamp = ((id >> 22) + DISCORD_EPOCH) / 1000 + return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) + + +def oauth_url( + client_id: int | str, + *, + permissions: Permissions | Undefined = MISSING, + guild: Snowflake | Undefined = MISSING, + redirect_uri: str | Undefined = MISSING, + scopes: Iterable[str] | Undefined = MISSING, + disable_guild_select: bool = False, +) -> str: + """A helper function that returns the OAuth2 URL for inviting the bot + into guilds. + + Parameters + ---------- + client_id: Union[:class:`int`, :class:`str`] + The client ID for your bot. + permissions: :class:`~discord.Permissions` + The permissions you're requesting. If not given then you won't be requesting any + permissions. + guild: :class:`~discord.abc.Snowflake` + The guild to pre-select in the authorization screen, if available. + redirect_uri: :class:`str` + An optional valid redirect URI. + scopes: Iterable[:class:`str`] + An optional valid list of scopes. Defaults to ``('bot',)``. + + .. versionadded:: 1.7 + disable_guild_select: :class:`bool` + Whether to disallow the user from changing the guild dropdown. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`str` + The OAuth2 URL for inviting the bot into guilds. + """ + url = f"https://discord.com/oauth2/authorize?client_id={client_id}" + url += f"&scope={'+'.join(scopes or ('bot',))}" + if permissions is not MISSING: + url += f"&permissions={permissions.value}" + if guild is not MISSING: + url += f"&guild_id={guild.id}" + if redirect_uri is not MISSING: + from urllib.parse import urlencode # noqa: PLC0415 + + url += f"&response_type=code&{urlencode({'redirect_uri': redirect_uri})}" + if disable_guild_select: + url += "&disable_guild_select=true" + return url + + +TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"] + + +def format_dt(dt: datetime.datetime | datetime.time, /, style: TimestampStyle | None = None) -> str: + """A helper function to format a :class:`datetime.datetime` for presentation within Discord. + + This allows for a locale-independent way of presenting data using Discord specific Markdown. + + +-------------+----------------------------+-----------------+ + | Style | Example Output | Description | + +=============+============================+=================+ + | t | 22:57 | Short Time | + +-------------+----------------------------+-----------------+ + | T | 22:57:58 | Long Time | + +-------------+----------------------------+-----------------+ + | d | 17/05/2016 | Short Date | + +-------------+----------------------------+-----------------+ + | D | 17 May 2016 | Long Date | + +-------------+----------------------------+-----------------+ + | f (default) | 17 May 2016 22:57 | Short Date Time | + +-------------+----------------------------+-----------------+ + | F | Tuesday, 17 May 2016 22:57 | Long Date Time | + +-------------+----------------------------+-----------------+ + | R | 5 years ago | Relative Time | + +-------------+----------------------------+-----------------+ + + Note that the exact output depends on the user's locale setting in the client. The example output + presented is using the ``en-GB`` locale. + + .. versionadded:: 2.0 + + Parameters + ---------- + dt: Union[:class:`datetime.datetime`, :class:`datetime.time`] + The datetime to format. + style: :class:`str`R + The style to format the datetime with. + + Returns + ------- + :class:`str` + The formatted string. + """ + if isinstance(dt, datetime.time): + dt = datetime.datetime.combine(datetime.datetime.now(), dt) + if style is None: + return f"" + return f"" + + +MENTION_PATTERN = re.compile(r"@(everyone|here|[!&]?[0-9]{17,20})") + + +def escape_mentions(text: str) -> str: + """A helper function that escapes everyone, here, role, and user mentions. + + .. note:: + + This does not include channel mentions. + + .. note:: + + For more granular control over what mentions should be escaped + within messages, refer to the :class:`~discord.AllowedMentions` + class. + + Parameters + ---------- + text: :class:`str` + The text to escape mentions from. + + Returns + ------- + :class:`str` + The text with the mentions removed. + """ + return MENTION_PATTERN.sub("@\u200b\\1", text) + + +RAW_MENTION_PATTERN = re.compile(r"<@!?([0-9]+)>") + + +def raw_mentions(text: str) -> list[int]: + """Returns a list of user IDs matching ``<@user_id>`` in the string. + + .. versionadded:: 2.2 + + Parameters + ---------- + text: :class:`str` + The string to get user mentions from. + + Returns + ------- + List[:class:`int`] + A list of user IDs found in the string. + """ + return [int(x) for x in RAW_MENTION_PATTERN.findall(text)] + + +RAW_CHANNEL_PATTERN = re.compile(r"<#([0-9]+)>") + + +def raw_channel_mentions(text: str) -> list[int]: + """Returns a list of channel IDs matching ``<@#channel_id>`` in the string. + + .. versionadded:: 2.2 + + Parameters + ---------- + text: :class:`str` + The string to get channel mentions from. + + Returns + ------- + List[:class:`int`] + A list of channel IDs found in the string. + """ + return [int(x) for x in RAW_CHANNEL_PATTERN.findall(text)] + + +RAW_ROLE_PATTERN = re.compile(r"<@&([0-9]+)>") + + +def raw_role_mentions(text: str) -> list[int]: + """Returns a list of role IDs matching ``<@&role_id>`` in the string. + + .. versionadded:: 2.2 + + Parameters + ---------- + text: :class:`str` + The string to get role mentions from. + + Returns + ------- + List[:class:`int`] + A list of role IDs found in the string. + """ + return [int(x) for x in RAW_ROLE_PATTERN.findall(text)] + + +_MARKDOWN_ESCAPE_SUBREGEX = "|".join(r"\{0}(?=([\s\S]*((?(?:>>)?\s|{_MARKDOWN_ESCAPE_LINKS}" + +_MARKDOWN_ESCAPE_REGEX = re.compile( + rf"(?P{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})", + re.MULTILINE | re.X, +) + +_URL_REGEX = r"(?P<[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])" + +_MARKDOWN_STOCK_REGEX = rf"(?P[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})" + + +def remove_markdown(text: str, *, ignore_links: bool = True) -> str: + """A helper function that removes markdown characters. + + .. versionadded:: 1.7 + + .. note:: + This function is not markdown aware and may remove meaning from the original text. For example, + if the input contains ``10 * 5`` then it will be converted into ``10 5``. + + Parameters + ---------- + text: :class:`str` + The text to remove markdown from. + ignore_links: :class:`bool` + Whether to leave links alone when removing markdown. For example, + if a URL in the text contains characters such as ``_`` then it will + be left alone. Defaults to ``True``. + + Returns + ------- + :class:`str` + The text with the markdown special characters removed. + """ + + def replacement(match): + groupdict = match.groupdict() + return groupdict.get("url", "") + + regex = _MARKDOWN_STOCK_REGEX + if ignore_links: + regex = f"(?:{_URL_REGEX}|{regex})" + return re.sub(regex, replacement, text, 0, re.MULTILINE) + + +def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str: + r"""A helper function that escapes Discord's markdown. + + Parameters + ----------- + text: :class:`str` + The text to escape markdown from. + as_needed: :class:`bool` + Whether to escape the markdown characters as needed. This + means that it does not escape extraneous characters if it's + not necessary, e.g. ``**hello**`` is escaped into ``\*\*hello**`` + instead of ``\*\*hello\*\*``. Note however that this can open + you up to some clever syntax abuse. Defaults to ``False``. + ignore_links: :class:`bool` + Whether to leave links alone when escaping markdown. For example, + if a URL in the text contains characters such as ``_`` then it will + be left alone. This option is not supported with ``as_needed``. + Defaults to ``True``. + + Returns + -------- + :class:`str` + The text with the markdown special characters escaped with a slash. + """ + + if not as_needed: + + def replacement(match): + groupdict = match.groupdict() + is_url = groupdict.get("url") + if is_url: + return is_url + return f"\\{groupdict['markdown']}" + + regex = _MARKDOWN_STOCK_REGEX + if ignore_links: + regex = f"(?:{_URL_REGEX}|{regex})" + return re.sub(regex, replacement, text, 0, re.MULTILINE | re.X) + else: + text = re.sub(r"\\", r"\\\\", text) + return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text) + + +def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> T | None: + """A helper to return the first element found in the sequence + that meets the predicate. For example: :: + + member = discord.utils.find(lambda m: m.name == "Mighty", channel.guild.members) + + would find the first :class:`~discord.Member` whose name is 'Mighty' and return it. + If an entry is not found, then ``None`` is returned. + + This is different from :func:`py:filter` due to the fact it stops the moment it finds + a valid entry. + + Parameters + ---------- + predicate + A function that returns a boolean-like result. + seq: :class:`collections.abc.Iterable` + The iterable to search through. + """ + + for element in seq: + if predicate(element): + return element + return None diff --git a/discord/voice_client.py b/discord/voice_client.py index 09f166658b..a8c554e512 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -56,6 +56,7 @@ from .player import AudioPlayer, AudioSource from .sinks import RawData, RecordingException, Sink from .utils import MISSING +from .utils.private import sane_wait_for if TYPE_CHECKING: from . import abc @@ -392,7 +393,7 @@ async def connect(self, *, reconnect: bool, timeout: float) -> None: await self.voice_connect() try: - await utils.sane_wait_for(futures, timeout=timeout) + await sane_wait_for(futures, timeout=timeout) except asyncio.TimeoutError: await self.disconnect(force=True) raise diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index 8256592d77..0be4163568 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -36,6 +36,7 @@ import aiohttp +from ..utils.private import bytes_to_base64_data, get_as_snowflake, parse_ratelimit_header, to_json from .. import utils from ..asset import Asset from ..channel import ForumChannel, PartialMessageable @@ -134,7 +135,7 @@ async def request( if payload is not None: headers["Content-Type"] = "application/json" - to_send = utils._to_json(payload) + to_send = to_json(payload) if auth_token is not None: headers["Authorization"] = f"Bot {auth_token}" @@ -181,7 +182,7 @@ async def request( remaining = response.headers.get("X-Ratelimit-Remaining") if remaining == "0" and response.status != 429: - delta = utils._parse_ratelimit_header(response) + delta = parse_ratelimit_header(response) _log.debug( ("Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds"), webhook_id, @@ -508,7 +509,7 @@ def create_interaction_response( ) if attachments: payload["data"]["attachments"] = attachments - form[0]["value"] = utils._to_json(payload) + form[0]["value"] = to_json(payload) route = Route( "POST", @@ -700,7 +701,7 @@ def handle_message_parameters( payload["thread_name"] = thread_name if multipart_files: - multipart.append({"name": "payload_json", "value": utils._to_json(payload)}) + multipart.append({"name": "payload_json", "value": to_json(payload)}) payload = None multipart += multipart_files @@ -1006,8 +1007,8 @@ def __init__( def _update(self, data: WebhookPayload | FollowerWebhookPayload): self.id = int(data["id"]) self.type = try_enum(WebhookType, int(data["type"])) - self.channel_id = utils._get_as_snowflake(data, "channel_id") - self.guild_id = utils._get_as_snowflake(data, "guild_id") + self.channel_id = get_as_snowflake(data, "channel_id") + self.guild_id = get_as_snowflake(data, "guild_id") self.name = data.get("name") self._avatar = data.get("avatar") self.token = data.get("token") @@ -1493,7 +1494,7 @@ async def edit( payload["name"] = str(name) if name is not None else None if avatar is not MISSING: - payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None + payload["avatar"] = bytes_to_base64_data(avatar) if avatar is not None else None adapter = async_context.get() diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index 6808f745ac..57ee45657c 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -41,6 +41,7 @@ from urllib.parse import quote as urlquote from .. import utils +from ..utils.private import parse_ratelimit_header, bytes_to_base64_data, to_json from ..channel import PartialMessageable from ..errors import ( DiscordServerError, @@ -124,7 +125,7 @@ def request( if payload is not None: headers["Content-Type"] = "application/json" - to_send = utils._to_json(payload) + to_send = to_json(payload) if auth_token is not None: headers["Authorization"] = f"Bot {auth_token}" @@ -183,7 +184,7 @@ def request( remaining = response.headers.get("X-Ratelimit-Remaining") if remaining == "0" and response.status_code != 429: - delta = utils._parse_ratelimit_header(response) + delta = parse_ratelimit_header(response) _log.debug( ("Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds"), webhook_id, @@ -846,7 +847,7 @@ def edit( payload["name"] = str(name) if name is not None else None if avatar is not MISSING: - payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None + payload["avatar"] = bytes_to_base64_data(avatar) if avatar is not None else None adapter: WebhookAdapter = _get_webhook_adapter() diff --git a/discord/welcome_screen.py b/discord/welcome_screen.py index f7d630d4aa..eda968becc 100644 --- a/discord/welcome_screen.py +++ b/discord/welcome_screen.py @@ -28,7 +28,8 @@ from typing import TYPE_CHECKING, overload from .partial_emoji import _EmojiTag -from .utils import _get_as_snowflake, get +from . import utils +from .utils.private import get_as_snowflake if TYPE_CHECKING: from .abc import Snowflake @@ -96,13 +97,13 @@ def to_dict(self) -> WelcomeScreenChannelPayload: @classmethod def _from_dict(cls, data: WelcomeScreenChannelPayload, guild: Guild) -> WelcomeScreenChannel: - channel_id = _get_as_snowflake(data, "channel_id") + channel_id = get_as_snowflake(data, "channel_id") channel = guild.get_channel(channel_id) description = data.get("description") - _emoji_id = _get_as_snowflake(data, "emoji_id") + _emoji_id = get_as_snowflake(data, "emoji_id") _emoji_name = data.get("emoji_name") - emoji = get(guild.emojis, id=_emoji_id) if _emoji_id else _emoji_name + emoji = utils.find(lambda e: e.id == _emoji_id, guild.emojis) if _emoji_id else _emoji_name return cls(channel=channel, description=description, emoji=emoji) # type: ignore @@ -192,7 +193,7 @@ async def edit(self, **options): rules_channel = guild.get_channel(12345678) announcements_channel = guild.get_channel(87654321) - custom_emoji = utils.get(guild.emojis, name="loudspeaker") + custom_emoji = utils.find(lambda e: e.name == "loudspeaker", guild.emojis) await welcome_screen.edit( description="This is a very cool community server!", welcome_channels=[ diff --git a/discord/widget.py b/discord/widget.py index 004b87d66e..eb54e606b3 100644 --- a/discord/widget.py +++ b/discord/widget.py @@ -31,7 +31,8 @@ from .enums import Status, try_enum from .invite import Invite from .user import BaseUser -from .utils import _get_as_snowflake, resolve_invite, snowflake_time +from .utils import snowflake_time +from .utils.private import resolve_invite, get_as_snowflake if TYPE_CHECKING: import datetime @@ -265,7 +266,7 @@ def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None: self.members: list[WidgetMember] = [] channels = {channel.id: channel for channel in self.channels} for member in data.get("members", []): - connected_channel = _get_as_snowflake(member, "channel_id") + connected_channel = get_as_snowflake(member, "channel_id") if connected_channel in channels: connected_channel = channels[connected_channel] # type: ignore elif connected_channel: diff --git a/docs/_static/css/custom.css b/docs/_static/css/custom.css index 57ef457b9e..cd2b41d9a2 100644 --- a/docs/_static/css/custom.css +++ b/docs/_static/css/custom.css @@ -8,9 +8,8 @@ font-display: swap; src: url(https://fonts.gstatic.com/s/saira/v8/memWYa2wxmKQyPMrZX79wwYZQMhsyuShhKMjjbU9uXuA773FksAxljYm.woff2) format("woff2"); - unicode-range: - U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, - U+1EA0-1EF9, U+20AB; + unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, + U+01AF-01B0, U+1EA0-1EF9, U+20AB; } /* latin-ext */ @font-face { @@ -21,9 +20,8 @@ font-display: swap; src: url(https://fonts.gstatic.com/s/saira/v8/memWYa2wxmKQyPMrZX79wwYZQMhsyuShhKMjjbU9uXuA773FksExljYm.woff2) format("woff2"); - unicode-range: - U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, - U+2C60-2C7F, U+A720-A7FF; + unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, + U+2113, U+2C60-2C7F, U+A720-A7FF; } /* latin */ @font-face { @@ -34,9 +32,8 @@ font-display: swap; src: url(https://fonts.gstatic.com/s/saira/v8/memWYa2wxmKQyPMrZX79wwYZQMhsyuShhKMjjbU9uXuA773Fks8xlg.woff2) format("woff2"); - unicode-range: - U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, - U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, + U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; } /* latin */ @font-face { @@ -46,9 +43,8 @@ font-display: swap; src: url(https://fonts.gstatic.com/s/outfit/v4/QGYyz_MVcBeNP4NjuGObqx1XmO1I4W61O4a0Ew.woff2) format("woff2"); - unicode-range: - U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, - U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, + U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; } /* latin */ @font-face { @@ -58,9 +54,8 @@ font-display: swap; src: url(https://fonts.gstatic.com/s/outfit/v4/QGYyz_MVcBeNP4NjuGObqx1XmO1I4deyO4a0Ew.woff2) format("woff2"); - unicode-range: - U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, - U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, + U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; } /* attribute tables */ diff --git a/docs/api/async_iter.rst b/docs/api/async_iter.rst index 22967bab28..a958b94cb3 100644 --- a/docs/api/async_iter.rst +++ b/docs/api/async_iter.rst @@ -33,17 +33,6 @@ Certain utilities make working with async iterators easier, detailed below. Advances the iterator by one, if possible. If no more items are found then this raises :exc:`NoMoreItems`. - .. method:: get(**attrs) - :async: - - |coro| - - Similar to :func:`utils.get` except run over the async iterator. - - Getting the last message by a user named 'Dave' or ``None``: :: - - msg = await channel.history().get(author__name='Dave') - .. method:: find(predicate) :async: diff --git a/docs/api/utils.rst b/docs/api/utils.rst index e426c2f2d0..db6930cf0d 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -7,7 +7,6 @@ Utility Functions .. autofunction:: discord.utils.find -.. autofunction:: discord.utils.get .. autofunction:: discord.utils.get_or_fetch @@ -25,30 +24,12 @@ Utility Functions .. autofunction:: discord.utils.raw_role_mentions -.. autofunction:: discord.utils.resolve_invite - -.. autofunction:: discord.utils.resolve_template - -.. autofunction:: discord.utils.sleep_until - .. autofunction:: discord.utils.utcnow .. autofunction:: discord.utils.snowflake_time -.. autofunction:: discord.utils.parse_time - .. autofunction:: discord.utils.format_dt -.. autofunction:: discord.utils.time_snowflake - .. autofunction:: discord.utils.generate_snowflake .. autofunction:: discord.utils.basic_autocomplete - -.. autofunction:: discord.utils.as_chunks - -.. autofunction:: discord.utils.filter_params - -.. autofunction:: discord.utils.warn_deprecated - -.. autofunction:: discord.utils.deprecated diff --git a/docs/faq.rst b/docs/faq.rst index bf8bfe82d0..c1ba99f9ce 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -214,7 +214,7 @@ Quick example: :: await message.add_reaction(emoji) # no ID, do a lookup - emoji = discord.utils.get(guild.emojis, name='LUL') + emoji = discord.utils.find(lambda e: e.name == "LUL", guild.emojis) if emoji: await message.add_reaction(emoji) @@ -279,13 +279,13 @@ The following use an HTTP request: - :meth:`Guild.fetch_member` -If the functions above do not help you, then use of :func:`utils.find` or :func:`utils.get` would serve some use in finding +If the functions above do not help you, then use of :func:`utils.find` would serve some use in finding specific models. Quick example: :: # find a guild by name - guild = discord.utils.get(client.guilds, name='My Server') + guild = discord.utils.find(lambda g: g.name == "My Server", client.guilds) # make sure to check if it's found if guild is not None: diff --git a/tests/test_utils.py b/tests/test_utils.py index d0f94acb93..61931451a0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -30,16 +30,9 @@ from discord.utils import ( MISSING, - _cached_property, - _parse_ratelimit_header, - _unique, - async_all, - copy_doc, find, - get, - maybe_coroutine, + generate_snowflake, snowflake_time, - time_snowflake, utcnow, ) @@ -79,23 +72,6 @@ def test_temporary(): # assert not MISSING # assert repr(MISSING) == '...' # -# -# def test_cached_property() -> None: -# class Test: -# def __init__(self, x: int): -# self.x = x -# -# @_cached_property -# def foo(self) -> int: -# self.x += 1 -# return self.x -# -# t = Test(0) -# assert isinstance(_cached_property.__get__(_cached_property(None), None, None), _cached_property) -# assert t.foo == 1 -# assert t.foo == 1 -# -# # def test_find_get() -> None: # class Obj: # def __init__(self, value: int):