diff --git a/CHANGELOG.md b/CHANGELOG.md index 718642108d..253c13e547 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2775](https://github.com/Pycord-Development/pycord/pull/2775)) - Added `discord.Interaction.created_at`. ([#2801](https://github.com/Pycord-Development/pycord/pull/2801)) +- Added `Guild.get_or_fetch()` and `Client.get_or_fetch()` shortcut methods. + ([#2776](https://github.com/Pycord-Development/pycord/pull/2776)) - Added `User.nameplate` property. ([#2817](https://github.com/Pycord-Development/pycord/pull/2817)) - Added role gradients support with `Role.colours` and the `RoleColours` class. @@ -168,6 +170,9 @@ These changes are available on the `master` branch, but have not yet been releas ([#2501](https://github.com/Pycord-Development/pycord/pull/2501)) - Deprecated `Interaction.cached_channel` in favor of `Interaction.channel`. ([#2658](https://github.com/Pycord-Development/pycord/pull/2658)) +- Deprecated `utils.get_or_fetch(attr, id)` and `Client.get_or_fetch_user(id)` in favor + of `utils.get_or_fetch(object_type, object_id)`. + ([#2776](https://github.com/Pycord-Development/pycord/pull/2776)) ### Removed diff --git a/discord/client.py b/discord/client.py index 2d0f1b8770..bd83326dbd 100644 --- a/discord/client.py +++ b/discord/client.py @@ -31,7 +31,15 @@ import sys import traceback from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generator, Sequence, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Generator, + Sequence, + TypeVar, +) import aiohttp @@ -60,18 +68,26 @@ from .threads import Thread from .ui.view import View from .user import ClientUser, User -from .utils import MISSING +from .utils import _D, _FETCHABLE, MISSING from .voice_client import VoiceClient from .webhook import Webhook from .widget import Widget if TYPE_CHECKING: from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime - from .channel import DMChannel + from .channel import ( + CategoryChannel, + DMChannel, + ForumChannel, + StageChannel, + TextChannel, + VoiceChannel, + ) from .interaction import Interaction from .member import Member from .message import Message from .poll import Poll + from .threads import Thread, ThreadMember from .ui.item import Item from .voice_client import VoiceProtocol @@ -1147,7 +1163,12 @@ def get_all_members(self) -> Generator[Member]: for guild in self.guilds: yield from guild.members - async def get_or_fetch_user(self, id: int, /) -> User | None: + @utils.deprecated( + instead="Client.get_or_fetch(User, id)", + since="2.7", + removed="3.0", + ) + async def get_or_fetch_user(self, id: int, /) -> User | None: # TODO: Remove in 3.0 """|coro| Looks up a user in the user cache or fetches if not found. @@ -1163,7 +1184,47 @@ async def get_or_fetch_user(self, id: int, /) -> User | None: The user or ``None`` if not found. """ - return await utils.get_or_fetch(obj=self, attr="user", id=id, default=None) + return await self.get_or_fetch(object_type=User, object_id=id, default=None) + + async def get_or_fetch( + self: Client, + object_type: type[_FETCHABLE], + object_id: int | None, + default: _D = None, + ) -> _FETCHABLE | _D | None: + """ + Shortcut method to get data from an object either by returning the cached version, or if it does not exist, attempting to fetch it from the API. + + Parameters + ---------- + object_type: VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | User | Guild | GuildEmoji | AppEmoji + Type of object to fetch or get. + object_id: int | None + ID of object to get. If None, returns default if provided, else None. + default: Any | None + A default to return instead of raising if fetch fails. + + Returns + ------- + VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | User | Guild | GuildEmoji | AppEmoji | None + The object if found, or `default` if provided when not found. + + Raises + ------ + :exc:`TypeError` + Raised when required parameters are missing or invalid types are provided. + :exc:`InvalidArgument` + Raised when an unsupported or incompatible object type is used. + """ + try: + return await utils.get_or_fetch( + obj=self, + object_type=object_type, + object_id=object_id, + default=default, + ) + except (HTTPException, ValueError): + return default # listeners/waiters diff --git a/discord/guild.py b/discord/guild.py index 488f4aa074..e13f8db815 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -63,7 +63,7 @@ VoiceRegion, try_enum, ) -from .errors import ClientException, InvalidArgument, InvalidData +from .errors import ClientException, HTTPException, InvalidArgument, InvalidData from .file import File from .flags import SystemChannelFlags from .integrations import Integration, _integration_factory @@ -85,6 +85,7 @@ from .sticker import GuildSticker from .threads import Thread, ThreadMember from .user import User +from .utils import _D, _FETCHABLE from .welcome_screen import WelcomeScreen, WelcomeScreenChannel from .widget import Widget @@ -863,6 +864,49 @@ def get_member(self, user_id: int, /) -> Member | None: """ return self._members.get(user_id) + async def get_or_fetch( + self: Guild, + object_type: type[_FETCHABLE], + object_id: int | None, + default: _D = None, + ) -> _FETCHABLE | _D | None: + """ + Shortcut method to get data from this guild either by returning the cached version, + or if it does not exist, attempting to fetch it from the API. + + Parameters + ---------- + object_type: VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | Role | Member | GuildEmoji + Type of object to fetch or get. + + object_id: :class:`int` | None + ID of the object to get. If ``None``, returns ``default`` if provided, otherwise ``None``. + + default : Any | None + The value to return instead of raising if fetching fails or if ``object_id`` is ``None``. + + Returns + ------- + VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | Role | Member | GuildEmoji | None + The object if found, or `default` if provided when not found. + + Raises + ------ + :exc:`TypeError` + Raised when required parameters are missing or invalid types are provided. + :exc:`InvalidArgument` + Raised when an unsupported or incompatible object type is used. + """ + try: + return await utils.get_or_fetch( + obj=self, + object_type=object_type, + object_id=object_id, + default=default, + ) + except (HTTPException, ValueError): + return default + @property def premium_subscribers(self) -> list[Member]: """A list of members who have "boosted" this guild.""" @@ -2699,6 +2743,26 @@ async def delete_sticker( """ await self._state.http.delete_guild_sticker(self.id, sticker.id, reason) + def get_emoji(self, emoji_id: int, /) -> GuildEmoji | None: + """Returns an emoji with the given ID. + + .. versionadded:: 2.7 + + Parameters + ---------- + emoji_id: int + The ID to search for. + + Returns + ------- + Optional[:class:`Emoji`] + The returned Emoji or ``None`` if not found. + """ + emoji = self._state.get_emoji(emoji_id) + if emoji and emoji.guild == self: + return emoji + return None + async def fetch_emojis(self) -> list[GuildEmoji]: r"""|coro| diff --git a/discord/utils.py b/discord/utils.py index 17cee5fc10..b646a93f14 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -63,6 +63,23 @@ overload, ) +if TYPE_CHECKING: + from discord import ( + Client, + VoiceChannel, + TextChannel, + ForumChannel, + StageChannel, + CategoryChannel, + Thread, + Member, + User, + Guild, + Role, + GuildEmoji, + AppEmoji, + ) + from .errors import HTTPException, InvalidArgument try: @@ -98,6 +115,7 @@ "generate_snowflake", "basic_autocomplete", "filter_params", + "MISSING", ) DISCORD_EPOCH = 1420070400000 @@ -577,64 +595,225 @@ def get(iterable: Iterable[T], **attrs: Any) -> T | None: return None -async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any: - """|coro| +_FETCHABLE = TypeVar( + "_FETCHABLE", + bound="VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | Member | User | Guild | Role | GuildEmoji | AppEmoji", +) +_D = TypeVar("_D") +_Getter = Callable[[Any, int], Any] +_Fetcher = Callable[[Any, int], Awaitable[Any]] - Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it. - If the fetch also fails, an error will be raised. + +# TODO: In version 3.0, remove the 'attr' and 'id' arguments. +# Also, eliminate the default 'MISSING' value for both 'object_type' and 'object_id'. +@overload +async def get_or_fetch( + obj: Guild | Client, + object_type: type[_FETCHABLE], + object_id: Literal[None], + default: _D = ..., + attr: str = ..., + id: int = ..., +) -> None | _D: ... + + +@overload +async def get_or_fetch( + obj: Guild | Client, + object_type: type[_FETCHABLE], + object_id: int, + default: _D, + attr: str = ..., + id: int = ..., +) -> _FETCHABLE | _D: ... + + +@overload +async def get_or_fetch( + obj: Guild | Client, + object_type: type[_FETCHABLE], + object_id: int, + *, + attr: str = ..., + id: int = ..., +) -> _FETCHABLE: ... + + +async def get_or_fetch( + obj: Guild | Client, + object_type: type[_FETCHABLE] = MISSING, + object_id: int | None = MISSING, + default: _D = MISSING, + attr: str = MISSING, + id: int = MISSING, +) -> _FETCHABLE | _D | None: + """ + Shortcut method to get data from an object either by returning the cached version, or if it does not exist, attempting to fetch it from the API. Parameters ---------- - obj: Any - The object to use the get or fetch methods in - attr: :class:`str` - The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute. - id: :class:`int` - The ID of the object - default: Any - The default value to return if the object is not found, instead of raising an error. + obj : Guild | Client + The object to operate on. + object_type: VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | User | Guild | Role | Member | GuildEmoji | AppEmoji + Type of object to fetch or get. + + object_id: int | None + ID of object to get. + + default : Any | None + The value to return instead of raising if fetching fails. Returns ------- - Any - The object found or the default value. + + VoiceChannel | TextChannel | ForumChannel | StageChannel | CategoryChannel | Thread | User | Guild | Role | Member | GuildEmoji | AppEmoji | None + The object if found, or `default` if provided when not found. + Returns `None` only if `object_id` is None and no `default` is given. Raises ------ - :exc:`AttributeError` - The object is missing a ``get_`` or ``fetch_`` method + :exc:`TypeError` + Raised when required parameters are missing or invalid types are provided. + :exc:`InvalidArgument` + Raised when an unsupported or incompatible object type is used. :exc:`NotFound` - Invalid ID for the object + Invalid ID for the object. :exc:`HTTPException` - An error occurred fetching the object + An error occurred fetching the object. :exc:`Forbidden` - You do not have permission to fetch the object + You do not have permission to fetch the object. + """ + from discord import Client, Guild, Member, Role, User + + if object_id is None: + return default if default is not MISSING else None + + # Temporary backward compatibility for 'attr' and 'id'. + # This entire if block should be removed in version 3.0. + if attr is not MISSING or id is not MISSING or isinstance(object_type, str): + warn_deprecated( + name="get_or_fetch(obj, attr='type', id=...)", + instead="get_or_fetch(obj, object_type=Type, object_id=...)", + since="2.7", + removed="3.0", + ) - Examples - -------- + deprecated_attr = attr if attr is not MISSING else object_type + deprecated_id = id if id is not MISSING else object_id + + if isinstance(deprecated_attr, str): + mapped_type = _get_string_to_type_map().get(deprecated_attr.lower()) + if mapped_type is None: + raise InvalidArgument( + f"Unknown type string '{deprecated_attr}' used. Please use a valid class like `discord.Member` instead." + ) + object_type = mapped_type + elif isinstance(deprecated_attr, type): + object_type = deprecated_attr + else: + raise TypeError( + f"Invalid `attr` or `object_type`: expected a string or class, got {type(deprecated_attr).__name__}." + ) - Getting a guild from a guild ID: :: + object_id = deprecated_id - guild = await utils.get_or_fetch(client, 'guild', guild_id) + if object_type is MISSING or object_id is MISSING: + raise TypeError("required parameters: `object_type` and `object_id`.") - Getting a channel from the guild. If the channel is not found, return None: :: + if isinstance(obj, Guild) and object_type is User: + raise InvalidArgument( + "Guild cannot get_or_fetch discord.User. Use Client instead." + ) + elif isinstance(obj, Client) and object_type is Member: + raise InvalidArgument("Client cannot get_or_fetch Member. Use Guild instead.") + elif isinstance(obj, Client) and object_type is Role: + raise InvalidArgument("Client cannot get_or_fetch Role. Use Guild instead.") + elif isinstance(obj, Guild) and object_type is Guild: + raise InvalidArgument("Guild cannot get_or_fetch Guild. Use Client instead.") - channel = await utils.get_or_fetch(guild, 'channel', channel_id, default=None) - """ - getter = getattr(obj, f"get_{attr}")(id) - if getter is None: - try: - getter = await getattr(obj, f"fetch_{attr}")(id) - except AttributeError: - getter = await getattr(obj, f"_fetch_{attr}")(id) - if getter is None: - raise ValueError(f"Could not find {attr} with id {id} on {obj}") - except (HTTPException, ValueError): - if default is not MISSING: - return default - else: - raise - return getter + try: + getter, fetcher = _get_getter_fetcher_map()[object_type] + except KeyError: + raise InvalidArgument( + f"Class {object_type.__name__} cannot be used with discord.{type(obj).__name__}.get_or_fetch()" + ) + + result = getter(obj, object_id) + if result is not None: + return result + + try: + return await fetcher(obj, object_id) + except (HTTPException, ValueError): + if default is not MISSING: + return default + raise + + +@functools.lru_cache(maxsize=1) +def _get_string_to_type_map() -> dict[str, type]: + """Return a cached map of lowercase strings -> discord types.""" + from discord import Guild, Member, Role, User, abc, emoji + + return { + "channel": abc.GuildChannel, + "member": Member, + "user": User, + "guild": Guild, + "emoji": emoji._EmojiTag, + "appemoji": AppEmoji, + "role": Role, + } + + +@functools.lru_cache(maxsize=1) +def _get_getter_fetcher_map() -> dict[type, tuple[_Getter, _Fetcher]]: + """Return a cached map of type names -> (getter, fetcher) functions.""" + from discord import Guild, Member, Role, User, abc, emoji + + base_map: dict[type, tuple[_Getter, _Fetcher]] = { + Member: ( + lambda obj, oid: obj.get_member(oid), + lambda obj, oid: obj.fetch_member(oid), + ), + Role: ( + lambda obj, oid: obj.get_role(oid), + lambda obj, oid: obj._fetch_role(oid), + ), + User: ( + lambda obj, oid: obj.get_user(oid), + lambda obj, oid: obj.fetch_user(oid), + ), + Guild: ( + lambda obj, oid: obj.get_guild(oid), + lambda obj, oid: obj.fetch_guild(oid), + ), + emoji._EmojiTag: ( + lambda obj, oid: obj.get_emoji(oid), + lambda obj, oid: obj.fetch_emoji(oid), + ), + abc.GuildChannel: ( + lambda obj, oid: obj.get_channel(oid), + lambda obj, oid: obj.fetch_channel(oid), + ), + } + + expanded: dict[type, tuple[_Getter, _Fetcher]] = {} + for base, funcs in base_map.items(): + expanded[base] = funcs + for subclass in _all_subclasses(base): + if subclass not in expanded: + expanded[subclass] = funcs + + return expanded + + +def _all_subclasses(cls: type) -> set[type]: + """Recursively collect all subclasses of a class.""" + subs = set(cls.__subclasses__()) + for sub in cls.__subclasses__(): + subs |= _all_subclasses(sub) + return subs def _unique(iterable: Iterable[T]) -> list[T]: