diff --git a/discord/client.py b/discord/client.py index 93036aacde..71aed9351a 100644 --- a/discord/client.py +++ b/discord/client.py @@ -63,7 +63,12 @@ from .ui.view import View from .user import ClientUser, User from .utils import MISSING -from .utils.private import SequenceProxy, bytes_to_base64_data, resolve_invite, resolve_template +from .utils.private import ( + SequenceProxy, + bytes_to_base64_data, + resolve_invite, + resolve_template, +) from .voice_client import VoiceClient from .webhook import Webhook from .widget import Widget @@ -629,7 +634,10 @@ async def login(self, token: str) -> None: data = await self.http.static_login(token.strip()) self._connection.user = ClientUser(state=self._connection, data=data) - print_banner(bot_name=self._connection.user.display_name, module=self._banner_module or "discord") + print_banner( + bot_name=self._connection.user.display_name, + module=self._banner_module or "discord", + ) start_logging(self._flavor, debug=self._debug) async def connect(self, *, reconnect: bool = True) -> None: @@ -1130,24 +1138,6 @@ def get_all_members(self) -> Generator[Member]: for guild in self.guilds: yield from guild.members - async def get_or_fetch_user(self, id: int, /) -> User | None: - """|coro| - - Looks up a user in the user cache or fetches if not found. - - Parameters - ---------- - id: :class:`int` - The ID to search for. - - Returns - ------- - Optional[:class:`~discord.User`] - The user or ``None`` if not found. - """ - - return await utils.get_or_fetch(obj=self, attr="user", id=id, default=None) - # listeners/waiters async def wait_until_ready(self) -> None: diff --git a/discord/guild.py b/discord/guild.py index 3b43d61af9..56250ecbc1 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -1020,7 +1020,10 @@ def get_member_named(self, name: str, /) -> Member | None: # do the actual lookup and return if found # if it isn't found then we'll do a full name lookup below. - result = utils.find(lambda m: m.name == name[:-5] and discriminator == potential_discriminator, members) + result = utils.find( + lambda m: m.name == name[:-5] and discriminator == potential_discriminator, + members, + ) if result is not None: return result @@ -3312,7 +3315,10 @@ async def widget(self) -> Widget: return Widget(state=self._state, data=data) async def edit_widget( - self, *, enabled: bool | utils.Undefined = MISSING, channel: Snowflake | None | utils.Undefined = MISSING + self, + *, + enabled: bool | utils.Undefined = MISSING, + channel: Snowflake | None | utils.Undefined = MISSING, ) -> None: """|coro| diff --git a/discord/utils/__init__.py b/discord/utils/__init__.py index 888646f2ef..c58d06ba7e 100644 --- a/discord/utils/__init__.py +++ b/discord/utils/__init__.py @@ -25,11 +25,6 @@ from __future__ import annotations -from typing import ( - Any, -) - -from ..errors import HTTPException from .public import ( MISSING, UNICODE_EMOJIS, @@ -56,7 +51,6 @@ "oauth_url", "snowflake_time", "find", - "get_or_fetch", "utcnow", "remove_markdown", "escape_markdown", @@ -71,63 +65,3 @@ "MISSING", "UNICODE_EMOJIS", ) - - -async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any: - """|coro| - - Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it. - If the fetch also fails, an error will be raised. - - Parameters - ---------- - obj: Any - The object to use the get or fetch methods in - attr: :class:`str` - The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute. - id: :class:`int` - The ID of the object - default: Any - The default value to return if the object is not found, instead of raising an error. - - Returns - ------- - Any - The object found or the default value. - - Raises - ------ - :exc:`AttributeError` - The object is missing a ``get_`` or ``fetch_`` method - :exc:`NotFound` - Invalid ID for the object - :exc:`HTTPException` - An error occurred fetching the object - :exc:`Forbidden` - You do not have permission to fetch the object - - Examples - -------- - - Getting a guild from a guild ID: :: - - guild = await utils.get_or_fetch(client, "guild", guild_id) - - Getting a channel from the guild. If the channel is not found, return None: :: - - channel = await utils.get_or_fetch(guild, "channel", channel_id, default=None) - """ - getter = getattr(obj, f"get_{attr}")(id) - if getter is None: - try: - getter = await getattr(obj, f"fetch_{attr}")(id) - except AttributeError as e: - getter = await getattr(obj, f"_fetch_{attr}")(id) - if getter is None: - raise ValueError(f"Could not find {attr} with id {id} on {obj}") from e - except (HTTPException, ValueError): - if default is not MISSING: - return default - else: - raise - return getter diff --git a/discord/utils/private.py b/discord/utils/private.py index 5eb248b153..7ed3612f9e 100644 --- a/discord/utils/private.py +++ b/discord/utils/private.py @@ -28,6 +28,7 @@ Sequence, TypeVar, Union, + get_args, overload, ) @@ -69,9 +70,6 @@ def resolve_invite(invite: Invite | str) -> str: return invite -__all__ = ("resolve_invite",) - - def get_as_snowflake(data: Any, key: str) -> int | None: try: value = data[key] @@ -111,7 +109,7 @@ def parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: return (reset - now).total_seconds() -def string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int: +def string_width(string: str, *, _IS_ASCII: re.Pattern[str] = _IS_ASCII) -> int: """Returns string's width.""" match = _IS_ASCII.match(string) if match: @@ -138,7 +136,7 @@ def resolve_template(code: Template | str) -> str: :class:`str` The template code. """ - from ..template import Template # noqa: PLC0415 # circular import + from ..template import Template # noqa: PLC0415 if isinstance(code, Template): return code.code @@ -167,10 +165,6 @@ def parse_time(timestamp: None) -> None: ... def parse_time(timestamp: str) -> datetime.datetime: ... -@overload -def parse_time(timestamp: str | None) -> datetime.datetime | None: ... - - def parse_time(timestamp: str | None) -> datetime.datetime | None: """A helper function to convert an ISO 8601 timestamp to a datetime object. @@ -218,7 +212,6 @@ def warn_deprecated( stacklevel: :class:`int` The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3. """ - warnings.simplefilter("always", DeprecationWarning) # turn off filter message = f"{name} is deprecated" if since: message += f" since version {since}" @@ -231,7 +224,6 @@ def warn_deprecated( message += f" See {reference} for more information." warnings.warn(message, stacklevel=stacklevel, category=DeprecationWarning) - warnings.simplefilter("default", DeprecationWarning) # reset filter def deprecated( @@ -242,7 +234,7 @@ def deprecated( stacklevel: int = 3, *, use_qualname: bool = True, -) -> Callable[[Callable[[P], T]], Callable[[P], T]]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """A decorator implementation of :func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when the decorated function is called. @@ -267,7 +259,7 @@ def deprecated( will display as ``login``. Defaults to ``True``. """ - def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]: + def actual_decorator(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) def decorated(*args: P.args, **kwargs: P.kwargs) -> T: warn_deprecated( @@ -289,11 +281,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> T: def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]: - params = [] + params: list[Any] = [] literal_cls = type(Literal[0]) for p in parameters: if isinstance(p, literal_cls): - params.extend(p.__args__) + params.extend(get_args(p)) else: params.append(p) return tuple(params) @@ -311,7 +303,7 @@ def evaluate_annotation( cache: dict[str, Any], *, implicit_str: bool = True, -): +) -> Any: if isinstance(tp, ForwardRef): tp = tp.__forward_arg__ # ForwardRefs always evaluate their internals @@ -329,8 +321,8 @@ def evaluate_annotation( is_literal = False args = tp.__args__ if not hasattr(tp, "__origin__"): - if PY_310 and tp.__class__ is types.UnionType: # type: ignore - converted = Union[args] # type: ignore # noqa: UP007 + if PY_310 and tp.__class__ is types.UnionType: + converted = Union[args] # noqa: UP007 return evaluate_annotation(converted, globals, locals, cache) return tp @@ -381,7 +373,7 @@ def resolve_annotation( return evaluate_annotation(annotation, globalns, locals, cache) -def delay_task(delay: float, func: Coroutine): +def delay_task(delay: float, func: Coroutine[Any, Any, Any]): async def inner_call(): await asyncio.sleep(delay) try: @@ -408,7 +400,7 @@ async def maybe_awaitable(f: Callable[P, T | Awaitable[T]], *args: P.args, **kwa return value -async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Future[T]]: +async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Task[T]]: ensured = [asyncio.ensure_future(fut) for fut in futures] done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) @@ -418,7 +410,7 @@ async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> s return done -class SnowflakeList(array.array): +class SnowflakeList(array.array[int]): """Internal data storage class to efficiently store a list of snowflakes. This should have the following characteristics: @@ -437,7 +429,7 @@ class SnowflakeList(array.array): def __init__(self, data: Iterable[int], *, is_sorted: bool = False): ... def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): - return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore + return super().__new__(cls, "Q", data if is_sorted else sorted(data)) def add(self, element: int) -> None: i = bisect_left(self, element) @@ -452,22 +444,27 @@ def has(self, element: int) -> bool: return i != len(self) and self[i] == element -def copy_doc(original: Callable) -> Callable[[T], T]: +def copy_doc(original: Callable[..., object]) -> Callable[[T], T]: def decorator(overridden: T) -> T: overridden.__doc__ = original.__doc__ - overridden.__signature__ = signature(original) # type: ignore + overridden.__signature__ = signature(original) # type: ignore[reportAttributeAccessIssue] return overridden return decorator -class SequenceProxy(collections.abc.Sequence, Generic[T_co]): +class SequenceProxy(collections.abc.Sequence[T_co], Generic[T_co]): """Read-only proxy of a Sequence.""" def __init__(self, proxied: Sequence[T_co]): self.__proxied = proxied - def __getitem__(self, idx: int) -> T_co: + @overload + def __getitem__(self, idx: int) -> T_co: ... + @overload + def __getitem__(self, idx: slice) -> Sequence[T_co]: ... + + def __getitem__(self, idx: int | slice) -> T_co | Sequence[T_co]: return self.__proxied[idx] def __len__(self) -> int: @@ -482,7 +479,7 @@ def __iter__(self) -> Iterator[T_co]: def __reversed__(self) -> Iterator[T_co]: return reversed(self.__proxied) - def index(self, value: Any, *args, **kwargs) -> int: + def index(self, value: Any, *args: Any, **kwargs: Any) -> int: return self.__proxied.index(value, *args, **kwargs) def count(self, value: Any) -> int: @@ -515,10 +512,10 @@ def __get__(self, instance: T | None, owner: type[T]) -> Any: def get_slots(cls: type[Any]) -> Iterator[str]: for mro in reversed(cls.__mro__): - try: - yield from mro.__slots__ - except AttributeError: + slots = getattr(mro, "__slots__", None) + if slots is None: continue + yield from slots def cached_slot_property( @@ -533,15 +530,15 @@ def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]: try: import msgspec - def to_json(obj: Any) -> str: # type: ignore + def to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction] return msgspec.json.encode(obj).decode("utf-8") - from_json = msgspec.json.decode # type: ignore + from_json = msgspec.json.decode except ModuleNotFoundError: import json - def to_json(obj: Any) -> str: + def to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction] return json.dumps(obj, separators=(",", ":"), ensure_ascii=True) from_json = json.loads diff --git a/discord/utils/public.py b/discord/utils/public.py index 1b600f265e..470ba37929 100644 --- a/discord/utils/public.py +++ b/discord/utils/public.py @@ -8,7 +8,7 @@ import re from collections.abc import Awaitable, Callable, Iterable from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast if TYPE_CHECKING: from ..abc import Snowflake @@ -16,7 +16,6 @@ from ..commands.options import OptionChoice from ..permissions import Permissions - T = TypeVar("T") @@ -121,16 +120,21 @@ async def autocomplete_callback(ctx: AutocompleteContext) -> V: if asyncio.iscoroutine(_values): _values = await _values + _values = cast(V, _values) if filter is None: - def _filter(ctx: AutocompleteContext, item: Any) -> bool: + def _filter(ctx: AutocompleteContext, item: OptionChoice | str | int | float) -> bool: item = getattr(item, "name", item) return str(item).lower().startswith(str(ctx.value or "").lower()) gen = (val for val in _values if _filter(ctx, val)) elif asyncio.iscoroutinefunction(filter): - gen = (val for val in _values if await filter(ctx, val)) + filtered_values: list[OptionChoice | str | int | float] = [] + for val in _values: + if await filter(ctx, val): + filtered_values.append(val) + gen = (val for val in _values) elif callable(filter): gen = (val for val in _values if filter(ctx, val)) @@ -138,7 +142,7 @@ def _filter(ctx: AutocompleteContext, item: Any) -> bool: else: raise TypeError("``filter`` must be callable.") - return iter(itertools.islice(gen, 25)) + return cast(V, iter(itertools.islice(gen, 25))) return autocomplete_callback @@ -459,7 +463,7 @@ def remove_markdown(text: str, *, ignore_links: bool = True) -> str: The text with the markdown special characters removed. """ - def replacement(match): + def replacement(match: re.Match[str]): groupdict = match.groupdict() return groupdict.get("url", "") @@ -496,7 +500,7 @@ def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = if not as_needed: - def replacement(match): + def replacement(match: re.Match[str]): groupdict = match.groupdict() is_url = groupdict.get("url") if is_url: diff --git a/pyproject.toml b/pyproject.toml index c08d77caa3..fe5c52be3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -302,4 +302,4 @@ show_error_codes = true ignore_errors = true [tool.pytest.ini_options] -asyncio_mode = "auto" +asyncio_mode = "auto" \ No newline at end of file