From 085065f41fdb69ea83bfaba1208a7975e0372509 Mon Sep 17 00:00:00 2001 From: Lumouille <144063653+Lumabots@users.noreply.github.com> Date: Wed, 10 Sep 2025 23:49:01 +0200 Subject: [PATCH 1/8] strict type --- discord/utils/__init__.py | 65 ---------------------- discord/utils/private.py | 111 +++++++++++++++++++++----------------- discord/utils/public.py | 56 ++++++++++++------- pyproject.toml | 4 ++ 4 files changed, 102 insertions(+), 134 deletions(-) diff --git a/discord/utils/__init__.py b/discord/utils/__init__.py index 888646f2ef..43da707397 100644 --- a/discord/utils/__init__.py +++ b/discord/utils/__init__.py @@ -25,11 +25,7 @@ from __future__ import annotations -from typing import ( - Any, -) -from ..errors import HTTPException from .public import ( MISSING, UNICODE_EMOJIS, @@ -56,7 +52,6 @@ "oauth_url", "snowflake_time", "find", - "get_or_fetch", "utcnow", "remove_markdown", "escape_markdown", @@ -71,63 +66,3 @@ "MISSING", "UNICODE_EMOJIS", ) - - -async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any: - """|coro| - - Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it. - If the fetch also fails, an error will be raised. - - Parameters - ---------- - obj: Any - The object to use the get or fetch methods in - attr: :class:`str` - The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute. - id: :class:`int` - The ID of the object - default: Any - The default value to return if the object is not found, instead of raising an error. - - Returns - ------- - Any - The object found or the default value. - - Raises - ------ - :exc:`AttributeError` - The object is missing a ``get_`` or ``fetch_`` method - :exc:`NotFound` - Invalid ID for the object - :exc:`HTTPException` - An error occurred fetching the object - :exc:`Forbidden` - You do not have permission to fetch the object - - Examples - -------- - - Getting a guild from a guild ID: :: - - guild = await utils.get_or_fetch(client, "guild", guild_id) - - Getting a channel from the guild. If the channel is not found, return None: :: - - channel = await utils.get_or_fetch(guild, "channel", channel_id, default=None) - """ - getter = getattr(obj, f"get_{attr}")(id) - if getter is None: - try: - getter = await getattr(obj, f"fetch_{attr}")(id) - except AttributeError as e: - getter = await getattr(obj, f"_fetch_{attr}")(id) - if getter is None: - raise ValueError(f"Could not find {attr} with id {id} on {obj}") from e - except (HTTPException, ValueError): - if default is not MISSING: - return default - else: - raise - return getter diff --git a/discord/utils/private.py b/discord/utils/private.py index 5eb248b153..f9479eb1c5 100644 --- a/discord/utils/private.py +++ b/discord/utils/private.py @@ -16,22 +16,24 @@ from typing import ( TYPE_CHECKING, Any, - Awaitable, + overload, Callable, - Coroutine, - ForwardRef, - Generic, + TypeVar, + ParamSpec, Iterable, - Iterator, Literal, - ParamSpec, - Sequence, - TypeVar, + ForwardRef, Union, - overload, + Coroutine, + Awaitable, + reveal_type, + Generic, + Sequence, + Iterator, + get_args, ) -from ..errors import HTTPException, InvalidArgument +from ..errors import InvalidArgument, HTTPException if TYPE_CHECKING: from ..invite import Invite @@ -69,9 +71,6 @@ def resolve_invite(invite: Invite | str) -> str: return invite -__all__ = ("resolve_invite",) - - def get_as_snowflake(data: Any, key: str) -> int | None: try: value = data[key] @@ -107,11 +106,13 @@ def parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: return float(reset_after) utc = datetime.timezone.utc now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc) + reset = datetime.datetime.fromtimestamp( + float(request.headers["X-Ratelimit-Reset"]), utc + ) return (reset - now).total_seconds() -def string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int: +def string_width(string: str, *, _IS_ASCII: re.Pattern[str] = _IS_ASCII) -> int: """Returns string's width.""" match = _IS_ASCII.match(string) if match: @@ -138,7 +139,7 @@ def resolve_template(code: Template | str) -> str: :class:`str` The template code. """ - from ..template import Template # noqa: PLC0415 # circular import + from ..template import Template # noqa: PLC0415 if isinstance(code, Template): return code.code @@ -167,10 +168,6 @@ def parse_time(timestamp: None) -> None: ... def parse_time(timestamp: str) -> datetime.datetime: ... -@overload -def parse_time(timestamp: str | None) -> datetime.datetime | None: ... - - def parse_time(timestamp: str | None) -> datetime.datetime | None: """A helper function to convert an ISO 8601 timestamp to a datetime object. @@ -242,7 +239,7 @@ def deprecated( stacklevel: int = 3, *, use_qualname: bool = True, -) -> Callable[[Callable[[P], T]], Callable[[P], T]]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """A decorator implementation of :func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when the decorated function is called. @@ -267,7 +264,7 @@ def deprecated( will display as ``login``. Defaults to ``True``. """ - def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]: + def actual_decorator(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) def decorated(*args: P.args, **kwargs: P.kwargs) -> T: warn_deprecated( @@ -289,11 +286,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> T: def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]: - params = [] + params: list[Any] = [] literal_cls = type(Literal[0]) for p in parameters: if isinstance(p, literal_cls): - params.extend(p.__args__) + params.extend(get_args(p)) else: params.append(p) return tuple(params) @@ -311,7 +308,7 @@ def evaluate_annotation( cache: dict[str, Any], *, implicit_str: bool = True, -): +) -> Any: if isinstance(tp, ForwardRef): tp = tp.__forward_arg__ # ForwardRefs always evaluate their internals @@ -329,8 +326,8 @@ def evaluate_annotation( is_literal = False args = tp.__args__ if not hasattr(tp, "__origin__"): - if PY_310 and tp.__class__ is types.UnionType: # type: ignore - converted = Union[args] # type: ignore # noqa: UP007 + if PY_310 and tp.__class__ is types.UnionType: + converted = Union[args] return evaluate_annotation(converted, globals, locals, cache) return tp @@ -347,11 +344,16 @@ def evaluate_annotation( is_literal = True evaluated_args = tuple( - evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args + evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) + for arg in args ) - if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): - raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") + if is_literal and not all( + isinstance(x, (str, int, bool, type(None))) for x in evaluated_args + ): + raise TypeError( + "Literal arguments must be of type str, int, bool, or NoneType." + ) if evaluated_args == args: return tp @@ -381,7 +383,7 @@ def resolve_annotation( return evaluate_annotation(annotation, globalns, locals, cache) -def delay_task(delay: float, func: Coroutine): +def delay_task(delay: float, func: Coroutine[Any, Any, Any]): async def inner_call(): await asyncio.sleep(delay) try: @@ -401,16 +403,23 @@ async def async_all(gen: Iterable[Any]) -> bool: return True -async def maybe_awaitable(f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: +async def maybe_awaitable( + f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs +) -> T: value = f(*args, **kwargs) if isawaitable(value): + reveal_type(f) return await value return value -async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Future[T]]: +async def sane_wait_for( + futures: Iterable[Awaitable[T]], *, timeout: float +) -> set[asyncio.Task[T]]: ensured = [asyncio.ensure_future(fut) for fut in futures] - done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) + done, pending = await asyncio.wait( + ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED + ) if len(pending) != 0: raise asyncio.TimeoutError() @@ -418,7 +427,7 @@ async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> s return done -class SnowflakeList(array.array): +class SnowflakeList(array.array[int]): """Internal data storage class to efficiently store a list of snowflakes. This should have the following characteristics: @@ -437,7 +446,7 @@ class SnowflakeList(array.array): def __init__(self, data: Iterable[int], *, is_sorted: bool = False): ... def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): - return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore + return super().__new__(cls, "Q", data if is_sorted else sorted(data)) def add(self, element: int) -> None: i = bisect_left(self, element) @@ -452,22 +461,22 @@ def has(self, element: int) -> bool: return i != len(self) and self[i] == element -def copy_doc(original: Callable) -> Callable[[T], T]: +def copy_doc(original: Callable[..., object]) -> Callable[[T], T]: def decorator(overridden: T) -> T: overridden.__doc__ = original.__doc__ - overridden.__signature__ = signature(original) # type: ignore + setattr(overridden, "__signature__", signature(original)) return overridden return decorator -class SequenceProxy(collections.abc.Sequence, Generic[T_co]): +class SequenceProxy(collections.abc.Sequence[T_co], Generic[T_co]): """Read-only proxy of a Sequence.""" def __init__(self, proxied: Sequence[T_co]): self.__proxied = proxied - def __getitem__(self, idx: int) -> T_co: + def __getitem__(self, idx: int) -> T_co: # type: ignore[override] return self.__proxied[idx] def __len__(self) -> int: @@ -482,7 +491,7 @@ def __iter__(self) -> Iterator[T_co]: def __reversed__(self) -> Iterator[T_co]: return reversed(self.__proxied) - def index(self, value: Any, *args, **kwargs) -> int: + def index(self, value: Any, *args: Any, **kwargs: Any) -> int: return self.__proxied.index(value, *args, **kwargs) def count(self, value: Any) -> int: @@ -493,10 +502,12 @@ class CachedSlotProperty(Generic[T, T_co]): def __init__(self, name: str, function: Callable[[T], T_co]) -> None: self.name = name self.function = function - self.__doc__ = function.__doc__ + self.__doc__ = getattr(function, "__doc__") @overload - def __get__(self, instance: None, owner: type[T]) -> CachedSlotProperty[T, T_co]: ... + def __get__( + self, instance: None, owner: type[T] + ) -> CachedSlotProperty[T, T_co]: ... @overload def __get__(self, instance: T, owner: type[T]) -> T_co: ... @@ -515,10 +526,10 @@ def __get__(self, instance: T | None, owner: type[T]) -> Any: def get_slots(cls: type[Any]) -> Iterator[str]: for mro in reversed(cls.__mro__): - try: - yield from mro.__slots__ - except AttributeError: + slots = getattr(mro, "__slots__", None) + if slots is None: continue + yield from slots def cached_slot_property( @@ -533,15 +544,15 @@ def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]: try: import msgspec - def to_json(obj: Any) -> str: # type: ignore + def _to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction] return msgspec.json.encode(obj).decode("utf-8") - from_json = msgspec.json.decode # type: ignore + _from_json = msgspec.json.decode except ModuleNotFoundError: import json - def to_json(obj: Any) -> str: + def _to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction] return json.dumps(obj, separators=(",", ":"), ensure_ascii=True) - from_json = json.loads + _from_json = json.loads diff --git a/discord/utils/public.py b/discord/utils/public.py index 1b600f265e..d715d9064a 100644 --- a/discord/utils/public.py +++ b/discord/utils/public.py @@ -1,14 +1,14 @@ from __future__ import annotations import asyncio +import re import datetime import importlib.resources import itertools import json -import re from collections.abc import Awaitable, Callable, Iterable from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast if TYPE_CHECKING: from ..abc import Snowflake @@ -16,7 +16,6 @@ from ..commands.options import OptionChoice from ..permissions import Permissions - T = TypeVar("T") @@ -48,14 +47,16 @@ def utcnow() -> datetime.datetime: return datetime.datetime.now(datetime.timezone.utc) -V = Iterable["OptionChoice"] | Iterable[str] | Iterable[int] | Iterable[float] +V = Iterable[OptionChoice] | Iterable[str] | Iterable[int] | Iterable[float] AV = Awaitable[V] -Values = V | Callable[["AutocompleteContext"], V | AV] | AV -AutocompleteFunc = Callable[["AutocompleteContext"], AV] -FilterFunc = Callable[["AutocompleteContext", Any], bool | Awaitable[bool]] +Values = V | Callable[[AutocompleteContext], V | AV] | AV +AutocompleteFunc = Callable[[AutocompleteContext], AV] +FilterFunc = Callable[[AutocompleteContext, Any], bool | Awaitable[bool]] -def basic_autocomplete(values: Values, *, filter: FilterFunc | None = None) -> AutocompleteFunc: +def basic_autocomplete( + values: Values, *, filter: FilterFunc | None = None +) -> AutocompleteFunc: """A helper function to make a basic autocomplete for slash commands. This is a pretty standard autocomplete and will return any options that start with the value from the user, case-insensitive. If the ``values`` parameter is callable, it will be called with the AutocompleteContext. @@ -121,16 +122,23 @@ async def autocomplete_callback(ctx: AutocompleteContext) -> V: if asyncio.iscoroutine(_values): _values = await _values + _values = cast(V, _values) if filter is None: - def _filter(ctx: AutocompleteContext, item: Any) -> bool: + def _filter( + ctx: AutocompleteContext, item: OptionChoice | str | int | float + ) -> bool: item = getattr(item, "name", item) return str(item).lower().startswith(str(ctx.value or "").lower()) gen = (val for val in _values if _filter(ctx, val)) elif asyncio.iscoroutinefunction(filter): - gen = (val for val in _values if await filter(ctx, val)) + filtered_values: list[OptionChoice | str | int | float] = [] + for val in _values: + if await filter(ctx, val): + filtered_values.append(val) + gen = (val for val in _values) elif callable(filter): gen = (val for val in _values if filter(ctx, val)) @@ -138,7 +146,7 @@ def _filter(ctx: AutocompleteContext, item: Any) -> bool: else: raise TypeError("``filter`` must be callable.") - return iter(itertools.islice(gen, 25)) + return cast(V, iter(itertools.islice(gen, 25))) return autocomplete_callback @@ -269,7 +277,9 @@ def oauth_url( TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"] -def format_dt(dt: datetime.datetime | datetime.time, /, style: TimestampStyle | None = None) -> str: +def format_dt( + dt: datetime.datetime | datetime.time, /, style: TimestampStyle | None = None +) -> str: """A helper function to format a :class:`datetime.datetime` for presentation within Discord. This allows for a locale-independent way of presenting data using Discord specific Markdown. @@ -408,7 +418,9 @@ def raw_role_mentions(text: str) -> list[int]: return [int(x) for x in RAW_ROLE_PATTERN.findall(text)] -_MARKDOWN_ESCAPE_SUBREGEX = "|".join(r"\{0}(?=([\s\S]*((? str: The text with the markdown special characters removed. """ - def replacement(match): + def replacement(match: re.Match[str]): groupdict = match.groupdict() return groupdict.get("url", "") regex = _MARKDOWN_STOCK_REGEX if ignore_links: regex = f"(?:{_URL_REGEX}|{regex})" - return re.sub(regex, replacement, text, count=0, flags=re.MULTILINE) + return re.sub(regex, replacement, text, 0, re.MULTILINE) -def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str: +def escape_markdown( + text: str, *, as_needed: bool = False, ignore_links: bool = True +) -> str: r"""A helper function that escapes Discord's markdown. Parameters @@ -496,7 +510,7 @@ def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = if not as_needed: - def replacement(match): + def replacement(match: re.Match[str]): groupdict = match.groupdict() is_url = groupdict.get("url") if is_url: @@ -506,7 +520,7 @@ def replacement(match): regex = _MARKDOWN_STOCK_REGEX if ignore_links: regex = f"(?:{_URL_REGEX}|{regex})" - return re.sub(regex, replacement, text, count=0, flags=re.MULTILINE | re.X) + return re.sub(regex, replacement, text, 0, re.MULTILINE | re.X) else: text = re.sub(r"\\", r"\\\\", text) return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text) @@ -538,7 +552,11 @@ def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> T | None: return None -with importlib.resources.files(__package__).joinpath("../emojis.json").open(encoding="utf-8") as f: +with ( + importlib.resources.files(__package__) + .joinpath("../emojis.json") + .open(encoding="utf-8") as f +): EMOJIS_MAP = json.load(f) UNICODE_EMOJIS = set(EMOJIS_MAP.values()) diff --git a/pyproject.toml b/pyproject.toml index 286c1e60a7..76e972c78e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -303,3 +303,7 @@ ignore_errors = true [tool.pytest.ini_options] asyncio_mode = "auto" + +[tool.basedpyright] +typeCheckingMode = "strict" + From 9911df6fddd6e1c269c6bb2e4033be9e85b8d320 Mon Sep 17 00:00:00 2001 From: Lumouille <144063653+Lumabots@users.noreply.github.com> Date: Thu, 11 Sep 2025 00:07:49 +0200 Subject: [PATCH 2/8] fix precommit --- discord/guild.py | 14 +++++--- discord/state.py | 2 +- discord/utils/__init__.py | 1 - discord/utils/private.py | 61 +++++++++++++--------------------- discord/utils/public.py | 32 +++++------------- pyproject.toml | 3 -- tests/test_typing_annotated.py | 2 +- 7 files changed, 45 insertions(+), 70 deletions(-) diff --git a/discord/guild.py b/discord/guild.py index 6d9876c762..692401f584 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -534,7 +534,7 @@ def _sync(self, data: GuildPayload) -> None: if "channels" in data: channels = data["channels"] for c in channels: - factory, ch_type = _guild_channel_factory(c["type"]) + factory, _ = _guild_channel_factory(c["type"]) if factory: self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore @@ -1020,7 +1020,10 @@ def get_member_named(self, name: str, /) -> Member | None: # do the actual lookup and return if found # if it isn't found then we'll do a full name lookup below. - result = utils.find(lambda m: m.name == name[:-5] and discriminator == potential_discriminator, members) + result = utils.find( + lambda m: m.name == name[:-5] and discriminator == potential_discriminator, + members, + ) if result is not None: return result @@ -1885,7 +1888,7 @@ async def fetch_channels(self) -> Sequence[GuildChannel]: data = await self._state.http.get_all_guild_channels(self.id) def convert(d): - factory, ch_type = _guild_channel_factory(d["type"]) + factory, _ = _guild_channel_factory(d["type"]) if factory is None: raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(d)) @@ -3312,7 +3315,10 @@ async def widget(self) -> Widget: return Widget(state=self._state, data=data) async def edit_widget( - self, *, enabled: bool | utils.Undefined = MISSING, channel: Snowflake | None | utils.Undefined = MISSING + self, + *, + enabled: bool | utils.Undefined = MISSING, + channel: Snowflake | None | utils.Undefined = MISSING, ) -> None: """|coro| diff --git a/discord/state.py b/discord/state.py index 4cb14830fd..43753bfb65 100644 --- a/discord/state.py +++ b/discord/state.py @@ -985,7 +985,7 @@ def parse_channel_update(self, data) -> None: ) def parse_channel_create(self, data) -> None: - factory, ch_type = _channel_factory(data["type"]) + factory, _ = _channel_factory(data["type"]) if factory is None: _log.debug( "CHANNEL_CREATE referencing an unknown channel type %s. Discarding.", diff --git a/discord/utils/__init__.py b/discord/utils/__init__.py index 43da707397..c58d06ba7e 100644 --- a/discord/utils/__init__.py +++ b/discord/utils/__init__.py @@ -25,7 +25,6 @@ from __future__ import annotations - from .public import ( MISSING, UNICODE_EMOJIS, diff --git a/discord/utils/private.py b/discord/utils/private.py index f9479eb1c5..693bceca84 100644 --- a/discord/utils/private.py +++ b/discord/utils/private.py @@ -16,24 +16,24 @@ from typing import ( TYPE_CHECKING, Any, - overload, + Awaitable, Callable, - TypeVar, - ParamSpec, - Iterable, - Literal, - ForwardRef, - Union, Coroutine, - Awaitable, - reveal_type, + ForwardRef, Generic, - Sequence, + Iterable, Iterator, + Literal, + ParamSpec, + Sequence, + TypeVar, + Union, get_args, + overload, + reveal_type, ) -from ..errors import InvalidArgument, HTTPException +from ..errors import HTTPException, InvalidArgument if TYPE_CHECKING: from ..invite import Invite @@ -106,9 +106,7 @@ def parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: return float(reset_after) utc = datetime.timezone.utc now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp( - float(request.headers["X-Ratelimit-Reset"]), utc - ) + reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc) return (reset - now).total_seconds() @@ -327,7 +325,9 @@ def evaluate_annotation( args = tp.__args__ if not hasattr(tp, "__origin__"): if PY_310 and tp.__class__ is types.UnionType: - converted = Union[args] + converted = args[0] + for arg in args[1:]: + converted = converted | arg return evaluate_annotation(converted, globals, locals, cache) return tp @@ -344,16 +344,11 @@ def evaluate_annotation( is_literal = True evaluated_args = tuple( - evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) - for arg in args + evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args ) - if is_literal and not all( - isinstance(x, (str, int, bool, type(None))) for x in evaluated_args - ): - raise TypeError( - "Literal arguments must be of type str, int, bool, or NoneType." - ) + if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): + raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") if evaluated_args == args: return tp @@ -403,9 +398,7 @@ async def async_all(gen: Iterable[Any]) -> bool: return True -async def maybe_awaitable( - f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs -) -> T: +async def maybe_awaitable(f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: value = f(*args, **kwargs) if isawaitable(value): reveal_type(f) @@ -413,13 +406,9 @@ async def maybe_awaitable( return value -async def sane_wait_for( - futures: Iterable[Awaitable[T]], *, timeout: float -) -> set[asyncio.Task[T]]: +async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Task[T]]: ensured = [asyncio.ensure_future(fut) for fut in futures] - done, pending = await asyncio.wait( - ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED - ) + done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) if len(pending) != 0: raise asyncio.TimeoutError() @@ -464,7 +453,7 @@ def has(self, element: int) -> bool: def copy_doc(original: Callable[..., object]) -> Callable[[T], T]: def decorator(overridden: T) -> T: overridden.__doc__ = original.__doc__ - setattr(overridden, "__signature__", signature(original)) + overridden.__signature__ = signature(original) return overridden return decorator @@ -502,12 +491,10 @@ class CachedSlotProperty(Generic[T, T_co]): def __init__(self, name: str, function: Callable[[T], T_co]) -> None: self.name = name self.function = function - self.__doc__ = getattr(function, "__doc__") + self.__doc__ = function.__doc__ @overload - def __get__( - self, instance: None, owner: type[T] - ) -> CachedSlotProperty[T, T_co]: ... + def __get__(self, instance: None, owner: type[T]) -> CachedSlotProperty[T, T_co]: ... @overload def __get__(self, instance: T, owner: type[T]) -> T_co: ... diff --git a/discord/utils/public.py b/discord/utils/public.py index d715d9064a..875edb05ed 100644 --- a/discord/utils/public.py +++ b/discord/utils/public.py @@ -1,11 +1,11 @@ from __future__ import annotations import asyncio -import re import datetime import importlib.resources import itertools import json +import re from collections.abc import Awaitable, Callable, Iterable from enum import Enum, auto from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast @@ -54,9 +54,7 @@ def utcnow() -> datetime.datetime: FilterFunc = Callable[[AutocompleteContext, Any], bool | Awaitable[bool]] -def basic_autocomplete( - values: Values, *, filter: FilterFunc | None = None -) -> AutocompleteFunc: +def basic_autocomplete(values: Values, *, filter: FilterFunc | None = None) -> AutocompleteFunc: """A helper function to make a basic autocomplete for slash commands. This is a pretty standard autocomplete and will return any options that start with the value from the user, case-insensitive. If the ``values`` parameter is callable, it will be called with the AutocompleteContext. @@ -125,9 +123,7 @@ async def autocomplete_callback(ctx: AutocompleteContext) -> V: _values = cast(V, _values) if filter is None: - def _filter( - ctx: AutocompleteContext, item: OptionChoice | str | int | float - ) -> bool: + def _filter(ctx: AutocompleteContext, item: OptionChoice | str | int | float) -> bool: item = getattr(item, "name", item) return str(item).lower().startswith(str(ctx.value or "").lower()) @@ -277,9 +273,7 @@ def oauth_url( TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"] -def format_dt( - dt: datetime.datetime | datetime.time, /, style: TimestampStyle | None = None -) -> str: +def format_dt(dt: datetime.datetime | datetime.time, /, style: TimestampStyle | None = None) -> str: """A helper function to format a :class:`datetime.datetime` for presentation within Discord. This allows for a locale-independent way of presenting data using Discord specific Markdown. @@ -418,9 +412,7 @@ def raw_role_mentions(text: str) -> list[int]: return [int(x) for x in RAW_ROLE_PATTERN.findall(text)] -_MARKDOWN_ESCAPE_SUBREGEX = "|".join( - r"\{0}(?=([\s\S]*((? str: +def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str: r"""A helper function that escapes Discord's markdown. Parameters @@ -520,7 +510,7 @@ def replacement(match: re.Match[str]): regex = _MARKDOWN_STOCK_REGEX if ignore_links: regex = f"(?:{_URL_REGEX}|{regex})" - return re.sub(regex, replacement, text, 0, re.MULTILINE | re.X) + return re.sub(regex, replacement, text, count=0, flags=re.MULTILINE | re.X) else: text = re.sub(r"\\", r"\\\\", text) return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text) @@ -552,11 +542,7 @@ def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> T | None: return None -with ( - importlib.resources.files(__package__) - .joinpath("../emojis.json") - .open(encoding="utf-8") as f -): +with importlib.resources.files(__package__).joinpath("../emojis.json").open(encoding="utf-8") as f: EMOJIS_MAP = json.load(f) UNICODE_EMOJIS = set(EMOJIS_MAP.values()) diff --git a/pyproject.toml b/pyproject.toml index 76e972c78e..29d25a0d87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -304,6 +304,3 @@ ignore_errors = true [tool.pytest.ini_options] asyncio_mode = "auto" -[tool.basedpyright] -typeCheckingMode = "strict" - diff --git a/tests/test_typing_annotated.py b/tests/test_typing_annotated.py index ede5530777..f867469072 100644 --- a/tests/test_typing_annotated.py +++ b/tests/test_typing_annotated.py @@ -77,7 +77,7 @@ async def echo(self, ctx, txt: Annotated[str, discord.Option(description="Some t def test_typing_annotated_optional(): - async def echo(ctx, txt: Annotated[Optional[str], discord.Option()]): + async def echo(ctx, txt: Annotated[str | None, discord.Option()]): await ctx.respond(txt) cmd = SlashCommand(echo) From 0f3a3751a7a4dbafbf0c1919b242bb71fb45eae3 Mon Sep 17 00:00:00 2001 From: Lumouille <144063653+Lumabots@users.noreply.github.com> Date: Thu, 11 Sep 2025 00:11:53 +0200 Subject: [PATCH 3/8] usage of noqa --- discord/utils/private.py | 4 +--- pyproject.toml | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/discord/utils/private.py b/discord/utils/private.py index 693bceca84..8519d9fec0 100644 --- a/discord/utils/private.py +++ b/discord/utils/private.py @@ -325,9 +325,7 @@ def evaluate_annotation( args = tp.__args__ if not hasattr(tp, "__origin__"): if PY_310 and tp.__class__ is types.UnionType: - converted = args[0] - for arg in args[1:]: - converted = converted | arg + converted = Union[args] # noqa: UP007 return evaluate_annotation(converted, globals, locals, cache) return tp diff --git a/pyproject.toml b/pyproject.toml index 29d25a0d87..286c1e60a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -303,4 +303,3 @@ ignore_errors = true [tool.pytest.ini_options] asyncio_mode = "auto" - From 7ccae38425262b6ae11a602f42c26892bbc6532e Mon Sep 17 00:00:00 2001 From: Lumouille <144063653+Lumabots@users.noreply.github.com> Date: Sat, 13 Sep 2025 16:51:11 +0200 Subject: [PATCH 4/8] paillat comment --- discord/utils/private.py | 52 ++++++++++++++++++++++++++-------------- pyproject.toml | 3 +++ 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/discord/utils/private.py b/discord/utils/private.py index 8519d9fec0..b95c21eb6b 100644 --- a/discord/utils/private.py +++ b/discord/utils/private.py @@ -30,7 +30,6 @@ Union, get_args, overload, - reveal_type, ) from ..errors import HTTPException, InvalidArgument @@ -106,7 +105,9 @@ def parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: return float(reset_after) utc = datetime.timezone.utc now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc) + reset = datetime.datetime.fromtimestamp( + float(request.headers["X-Ratelimit-Reset"]), utc + ) return (reset - now).total_seconds() @@ -213,7 +214,6 @@ def warn_deprecated( stacklevel: :class:`int` The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3. """ - warnings.simplefilter("always", DeprecationWarning) # turn off filter message = f"{name} is deprecated" if since: message += f" since version {since}" @@ -226,7 +226,6 @@ def warn_deprecated( message += f" See {reference} for more information." warnings.warn(message, stacklevel=stacklevel, category=DeprecationWarning) - warnings.simplefilter("default", DeprecationWarning) # reset filter def deprecated( @@ -342,11 +341,16 @@ def evaluate_annotation( is_literal = True evaluated_args = tuple( - evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args + evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) + for arg in args ) - if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): - raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") + if is_literal and not all( + isinstance(x, (str, int, bool, type(None))) for x in evaluated_args + ): + raise TypeError( + "Literal arguments must be of type str, int, bool, or NoneType." + ) if evaluated_args == args: return tp @@ -396,17 +400,22 @@ async def async_all(gen: Iterable[Any]) -> bool: return True -async def maybe_awaitable(f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: +async def maybe_awaitable( + f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs +) -> T: value = f(*args, **kwargs) if isawaitable(value): - reveal_type(f) return await value return value -async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Task[T]]: +async def sane_wait_for( + futures: Iterable[Awaitable[T]], *, timeout: float +) -> set[asyncio.Task[T]]: ensured = [asyncio.ensure_future(fut) for fut in futures] - done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) + done, pending = await asyncio.wait( + ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED + ) if len(pending) != 0: raise asyncio.TimeoutError() @@ -451,7 +460,7 @@ def has(self, element: int) -> bool: def copy_doc(original: Callable[..., object]) -> Callable[[T], T]: def decorator(overridden: T) -> T: overridden.__doc__ = original.__doc__ - overridden.__signature__ = signature(original) + overridden.__signature__ = signature(original) # pyright: ignore[reportAttributeAccessIssue] return overridden return decorator @@ -463,7 +472,12 @@ class SequenceProxy(collections.abc.Sequence[T_co], Generic[T_co]): def __init__(self, proxied: Sequence[T_co]): self.__proxied = proxied - def __getitem__(self, idx: int) -> T_co: # type: ignore[override] + @overload + def __getitem__(self, idx: int) -> T_co: ... + @overload + def __getitem__(self, idx: slice) -> Sequence[T_co]: ... + + def __getitem__(self, idx: int | slice) -> T_co | Sequence[T_co]: return self.__proxied[idx] def __len__(self) -> int: @@ -492,7 +506,9 @@ def __init__(self, name: str, function: Callable[[T], T_co]) -> None: self.__doc__ = function.__doc__ @overload - def __get__(self, instance: None, owner: type[T]) -> CachedSlotProperty[T, T_co]: ... + def __get__( + self, instance: None, owner: type[T] + ) -> CachedSlotProperty[T, T_co]: ... @overload def __get__(self, instance: T, owner: type[T]) -> T_co: ... @@ -529,15 +545,15 @@ def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]: try: import msgspec - def _to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction] + def to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction] return msgspec.json.encode(obj).decode("utf-8") - _from_json = msgspec.json.decode + from_json = msgspec.json.decode except ModuleNotFoundError: import json - def _to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction] + def to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction] return json.dumps(obj, separators=(",", ":"), ensure_ascii=True) - _from_json = json.loads + from_json = json.loads diff --git a/pyproject.toml b/pyproject.toml index c08d77caa3..70b3caabd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -303,3 +303,6 @@ ignore_errors = true [tool.pytest.ini_options] asyncio_mode = "auto" + +[tool.basedpyright] +typeCheckingMode = "strict" From 26284d16afd280e4e57e252a84480fb26cbcaf6a Mon Sep 17 00:00:00 2001 From: Lumouille <144063653+Lumabots@users.noreply.github.com> Date: Wed, 17 Sep 2025 14:46:41 +0200 Subject: [PATCH 5/8] paillat comment --- discord/client.py | 136 ++++++++++++++++++++++++++++++---------------- 1 file changed, 90 insertions(+), 46 deletions(-) diff --git a/discord/client.py b/discord/client.py index 93036aacde..61816403bc 100644 --- a/discord/client.py +++ b/discord/client.py @@ -63,7 +63,12 @@ from .ui.view import View from .user import ClientUser, User from .utils import MISSING -from .utils.private import SequenceProxy, bytes_to_base64_data, resolve_invite, resolve_template +from .utils.private import ( + SequenceProxy, + bytes_to_base64_data, + resolve_invite, + resolve_template, +) from .voice_client import VoiceClient from .webhook import Webhook from .widget import Widget @@ -234,8 +239,12 @@ def __init__( self._banner_module = options.get("banner_module") # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop - self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = {} + self.loop: asyncio.AbstractEventLoop = ( + asyncio.get_event_loop() if loop is None else loop + ) + self._listeners: dict[ + str, list[tuple[asyncio.Future, Callable[..., bool]]] + ] = {} self.shard_id: int | None = options.get("shard_id") self.shard_count: int | None = options.get("shard_count") @@ -253,7 +262,9 @@ def __init__( self._handlers: dict[str, Callable] = {"ready": self._handle_ready} - self._hooks: dict[str, Callable] = {"before_identify": self._call_before_identify_hook} + self._hooks: dict[str, Callable] = { + "before_identify": self._call_before_identify_hook + } self._enable_debug_events: bool = options.pop("enable_debug_events", False) self._connection: ConnectionState = self._get_state(**options) @@ -292,7 +303,9 @@ async def __aexit__( # internals - def _get_websocket(self, guild_id: int | None = None, *, shard_id: int | None = None) -> DiscordWebSocket: + def _get_websocket( + self, guild_id: int | None = None, *, shard_id: int | None = None + ) -> DiscordWebSocket: return self.ws def _get_state(self, **options: Any) -> ConnectionState: @@ -541,7 +554,9 @@ async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None: print(f"Ignoring exception in {event_method}", file=sys.stderr) traceback.print_exc() - async def on_view_error(self, error: Exception, item: Item, interaction: Interaction) -> None: + async def on_view_error( + self, error: Exception, item: Item, interaction: Interaction + ) -> None: """|coro| The default view error handler provided by the client. @@ -553,7 +568,9 @@ async def on_view_error(self, error: Exception, item: Item, interaction: Interac f"Ignoring exception in view {interaction.view} for item {item}:", file=sys.stderr, ) - traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) + traceback.print_exception( + error.__class__, error, error.__traceback__, file=sys.stderr + ) async def on_modal_error(self, error: Exception, interaction: Interaction) -> None: """|coro| @@ -565,17 +582,23 @@ async def on_modal_error(self, error: Exception, interaction: Interaction) -> No """ print(f"Ignoring exception in modal {interaction.modal}:", file=sys.stderr) - traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) + traceback.print_exception( + error.__class__, error, error.__traceback__, file=sys.stderr + ) # hooks - async def _call_before_identify_hook(self, shard_id: int | None, *, initial: bool = False) -> None: + async def _call_before_identify_hook( + self, shard_id: int | None, *, initial: bool = False + ) -> None: # This hook is an internal hook that actually calls the public one. # It allows the library to have its own hook without stepping on the # toes of those who need to override their own hook. await self.before_identify_hook(shard_id, initial=initial) - async def before_identify_hook(self, shard_id: int | None, *, initial: bool = False) -> None: + async def before_identify_hook( + self, shard_id: int | None, *, initial: bool = False + ) -> None: """|coro| A hook that is called before IDENTIFYing a session. This is useful @@ -622,14 +645,19 @@ async def login(self, token: str) -> None: passing status code. """ if not isinstance(token, str): - raise TypeError(f"token must be of type str, not {token.__class__.__name__}") + raise TypeError( + f"token must be of type str, not {token.__class__.__name__}" + ) _log.info("logging in using static token") data = await self.http.static_login(token.strip()) self._connection.user = ClientUser(state=self._connection, data=data) - print_banner(bot_name=self._connection.user.display_name, module=self._banner_module or "discord") + print_banner( + bot_name=self._connection.user.display_name, + module=self._banner_module or "discord", + ) start_logging(self._flavor, debug=self._debug) async def connect(self, *, reconnect: bool = True) -> None: @@ -726,7 +754,9 @@ async def connect(self, *, reconnect: bool = True) -> None: # This is apparently what the official Discord client does. if self.ws is None: continue - ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id) + ws_params.update( + sequence=self.ws.sequence, resume=True, session=self.ws.session_id + ) async def close(self) -> None: """|coro| @@ -894,7 +924,9 @@ def allowed_mentions(self, value: AllowedMentions | None) -> None: if value is None or isinstance(value, AllowedMentions): self._connection.allowed_mentions = value else: - raise TypeError(f"allowed_mentions must be AllowedMentions not {value.__class__!r}") + raise TypeError( + f"allowed_mentions must be AllowedMentions not {value.__class__!r}" + ) @property def intents(self) -> Intents: @@ -968,7 +1000,9 @@ def get_message(self, id: int, /) -> Message | None: """ return self._connection._get_message(id) - def get_partial_messageable(self, id: int, *, type: ChannelType | None = None) -> PartialMessageable: + def get_partial_messageable( + self, id: int, *, type: ChannelType | None = None + ) -> PartialMessageable: """Returns a partial messageable with the given channel ID. This is useful if you have a channel_id but don't want to do an API call @@ -1130,24 +1164,6 @@ def get_all_members(self) -> Generator[Member]: for guild in self.guilds: yield from guild.members - async def get_or_fetch_user(self, id: int, /) -> User | None: - """|coro| - - Looks up a user in the user cache or fetches if not found. - - Parameters - ---------- - id: :class:`int` - The ID to search for. - - Returns - ------- - Optional[:class:`~discord.User`] - The user or ``None`` if not found. - """ - - return await utils.get_or_fetch(obj=self, attr="user", id=id, default=None) - # listeners/waiters async def wait_until_ready(self) -> None: @@ -1315,7 +1331,9 @@ async def my_message(message): name, ) - def remove_listener(self, func: Coro, name: str | utils.Undefined = MISSING) -> None: + def remove_listener( + self, func: Coro, name: str | utils.Undefined = MISSING + ) -> None: """Removes a listener from the pool of listeners. Parameters @@ -1335,7 +1353,9 @@ def remove_listener(self, func: Coro, name: str | utils.Undefined = MISSING) -> except ValueError: pass - def listen(self, name: str | utils.Undefined = MISSING, once: bool = False) -> Callable[[Coro], Coro]: + def listen( + self, name: str | utils.Undefined = MISSING, once: bool = False + ) -> Callable[[Coro], Coro]: """A decorator that registers another function as an external event listener. Basically this allows you to listen to multiple events from different places e.g. such as :func:`.on_ready` @@ -1547,7 +1567,9 @@ def fetch_guilds( All parameters are optional. """ - return GuildIterator(self, limit=limit, before=before, after=after, with_counts=with_counts) + return GuildIterator( + self, limit=limit, before=before, after=after, with_counts=with_counts + ) async def fetch_template(self, code: Template | str) -> Template: """|coro| @@ -1866,7 +1888,9 @@ async def fetch_user(self, user_id: int, /) -> User: data = await self.http.get_user(user_id) return User(state=self._connection, data=data) - async def fetch_channel(self, channel_id: int, /) -> GuildChannel | PrivateChannel | Thread: + async def fetch_channel( + self, channel_id: int, / + ) -> GuildChannel | PrivateChannel | Thread: """|coro| Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID. @@ -1897,7 +1921,9 @@ async def fetch_channel(self, channel_id: int, /) -> GuildChannel | PrivateChann factory, ch_type = _threaded_channel_factory(data["type"]) if factory is None: - raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data)) + raise InvalidData( + "Unknown channel type {type} for channel ID {id}.".format_map(data) + ) if ch_type in (ChannelType.group, ChannelType.private): # the factory will be a DMChannel or GroupChannel here @@ -1971,7 +1997,10 @@ async def fetch_premium_sticker_packs(self) -> list[StickerPack]: Retrieving the sticker packs failed. """ data = await self.http.list_premium_sticker_packs() - return [StickerPack(state=self._connection, data=pack) for pack in data["sticker_packs"]] + return [ + StickerPack(state=self._connection, data=pack) + for pack in data["sticker_packs"] + ] async def create_dm(self, user: Snowflake) -> DMChannel: """|coro| @@ -2031,7 +2060,9 @@ def add_view(self, view: View, *, message_id: int | None = None) -> None: raise TypeError(f"expected an instance of View not {view.__class__!r}") if not view.is_persistent(): - raise ValueError("View is not persistent. Items need to have a custom_id set and View must have no timeout") + raise ValueError( + "View is not persistent. Items need to have a custom_id set and View must have no timeout" + ) self._connection.store_view(view, message_id) @@ -2057,7 +2088,9 @@ async def fetch_role_connection_metadata_records( List[:class:`.ApplicationRoleConnectionMetadata`] The bot's role connection metadata records. """ - data = await self._connection.http.get_application_role_connection_metadata_records(self.application_id) + data = await self._connection.http.get_application_role_connection_metadata_records( + self.application_id + ) return [ApplicationRoleConnectionMetadata.from_dict(r) for r in data] async def update_role_connection_metadata_records( @@ -2196,8 +2229,13 @@ async def fetch_emojis(self) -> list[AppEmoji]: List[:class:`AppEmoji`] The retrieved emojis. """ - data = await self._connection.http.get_all_application_emojis(self.application_id) - return [self._connection.maybe_store_app_emoji(self.application_id, d) for d in data["items"]] + data = await self._connection.http.get_all_application_emojis( + self.application_id + ) + return [ + self._connection.maybe_store_app_emoji(self.application_id, d) + for d in data["items"] + ] async def fetch_emoji(self, emoji_id: int, /) -> AppEmoji: """|coro| @@ -2221,7 +2259,9 @@ async def fetch_emoji(self, emoji_id: int, /) -> AppEmoji: HTTPException An error occurred fetching the emoji. """ - data = await self._connection.http.get_application_emoji(self.application_id, emoji_id) + data = await self._connection.http.get_application_emoji( + self.application_id, emoji_id + ) return self._connection.maybe_store_app_emoji(self.application_id, data) async def create_emoji( @@ -2256,7 +2296,9 @@ async def create_emoji( """ img = bytes_to_base64_data(image) - data = await self._connection.http.create_application_emoji(self.application_id, name, img) + data = await self._connection.http.create_application_emoji( + self.application_id, name, img + ) return self._connection.maybe_store_app_emoji(self.application_id, data) async def delete_emoji(self, emoji: Snowflake) -> None: @@ -2275,6 +2317,8 @@ async def delete_emoji(self, emoji: Snowflake) -> None: An error occurred deleting the emoji. """ - await self._connection.http.delete_application_emoji(self.application_id, emoji.id) + await self._connection.http.delete_application_emoji( + self.application_id, emoji.id + ) if self._connection.cache_app_emojis and self._connection.get_emoji(emoji.id): self._connection.remove_emoji(emoji) From e2d3def8301c34c598bdb532f2c8ec2599a531f9 Mon Sep 17 00:00:00 2001 From: Lumouille <144063653+Lumabots@users.noreply.github.com> Date: Thu, 18 Sep 2025 15:32:03 +0200 Subject: [PATCH 6/8] run precommit --- discord/client.py | 106 ++++++++++----------------------------- discord/utils/private.py | 31 +++--------- 2 files changed, 34 insertions(+), 103 deletions(-) diff --git a/discord/client.py b/discord/client.py index 61816403bc..71aed9351a 100644 --- a/discord/client.py +++ b/discord/client.py @@ -239,12 +239,8 @@ def __init__( self._banner_module = options.get("banner_module") # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore - self.loop: asyncio.AbstractEventLoop = ( - asyncio.get_event_loop() if loop is None else loop - ) - self._listeners: dict[ - str, list[tuple[asyncio.Future, Callable[..., bool]]] - ] = {} + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop + self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = {} self.shard_id: int | None = options.get("shard_id") self.shard_count: int | None = options.get("shard_count") @@ -262,9 +258,7 @@ def __init__( self._handlers: dict[str, Callable] = {"ready": self._handle_ready} - self._hooks: dict[str, Callable] = { - "before_identify": self._call_before_identify_hook - } + self._hooks: dict[str, Callable] = {"before_identify": self._call_before_identify_hook} self._enable_debug_events: bool = options.pop("enable_debug_events", False) self._connection: ConnectionState = self._get_state(**options) @@ -303,9 +297,7 @@ async def __aexit__( # internals - def _get_websocket( - self, guild_id: int | None = None, *, shard_id: int | None = None - ) -> DiscordWebSocket: + def _get_websocket(self, guild_id: int | None = None, *, shard_id: int | None = None) -> DiscordWebSocket: return self.ws def _get_state(self, **options: Any) -> ConnectionState: @@ -554,9 +546,7 @@ async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None: print(f"Ignoring exception in {event_method}", file=sys.stderr) traceback.print_exc() - async def on_view_error( - self, error: Exception, item: Item, interaction: Interaction - ) -> None: + async def on_view_error(self, error: Exception, item: Item, interaction: Interaction) -> None: """|coro| The default view error handler provided by the client. @@ -568,9 +558,7 @@ async def on_view_error( f"Ignoring exception in view {interaction.view} for item {item}:", file=sys.stderr, ) - traceback.print_exception( - error.__class__, error, error.__traceback__, file=sys.stderr - ) + traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) async def on_modal_error(self, error: Exception, interaction: Interaction) -> None: """|coro| @@ -582,23 +570,17 @@ async def on_modal_error(self, error: Exception, interaction: Interaction) -> No """ print(f"Ignoring exception in modal {interaction.modal}:", file=sys.stderr) - traceback.print_exception( - error.__class__, error, error.__traceback__, file=sys.stderr - ) + traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) # hooks - async def _call_before_identify_hook( - self, shard_id: int | None, *, initial: bool = False - ) -> None: + async def _call_before_identify_hook(self, shard_id: int | None, *, initial: bool = False) -> None: # This hook is an internal hook that actually calls the public one. # It allows the library to have its own hook without stepping on the # toes of those who need to override their own hook. await self.before_identify_hook(shard_id, initial=initial) - async def before_identify_hook( - self, shard_id: int | None, *, initial: bool = False - ) -> None: + async def before_identify_hook(self, shard_id: int | None, *, initial: bool = False) -> None: """|coro| A hook that is called before IDENTIFYing a session. This is useful @@ -645,9 +627,7 @@ async def login(self, token: str) -> None: passing status code. """ if not isinstance(token, str): - raise TypeError( - f"token must be of type str, not {token.__class__.__name__}" - ) + raise TypeError(f"token must be of type str, not {token.__class__.__name__}") _log.info("logging in using static token") @@ -754,9 +734,7 @@ async def connect(self, *, reconnect: bool = True) -> None: # This is apparently what the official Discord client does. if self.ws is None: continue - ws_params.update( - sequence=self.ws.sequence, resume=True, session=self.ws.session_id - ) + ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id) async def close(self) -> None: """|coro| @@ -924,9 +902,7 @@ def allowed_mentions(self, value: AllowedMentions | None) -> None: if value is None or isinstance(value, AllowedMentions): self._connection.allowed_mentions = value else: - raise TypeError( - f"allowed_mentions must be AllowedMentions not {value.__class__!r}" - ) + raise TypeError(f"allowed_mentions must be AllowedMentions not {value.__class__!r}") @property def intents(self) -> Intents: @@ -1000,9 +976,7 @@ def get_message(self, id: int, /) -> Message | None: """ return self._connection._get_message(id) - def get_partial_messageable( - self, id: int, *, type: ChannelType | None = None - ) -> PartialMessageable: + def get_partial_messageable(self, id: int, *, type: ChannelType | None = None) -> PartialMessageable: """Returns a partial messageable with the given channel ID. This is useful if you have a channel_id but don't want to do an API call @@ -1331,9 +1305,7 @@ async def my_message(message): name, ) - def remove_listener( - self, func: Coro, name: str | utils.Undefined = MISSING - ) -> None: + def remove_listener(self, func: Coro, name: str | utils.Undefined = MISSING) -> None: """Removes a listener from the pool of listeners. Parameters @@ -1353,9 +1325,7 @@ def remove_listener( except ValueError: pass - def listen( - self, name: str | utils.Undefined = MISSING, once: bool = False - ) -> Callable[[Coro], Coro]: + def listen(self, name: str | utils.Undefined = MISSING, once: bool = False) -> Callable[[Coro], Coro]: """A decorator that registers another function as an external event listener. Basically this allows you to listen to multiple events from different places e.g. such as :func:`.on_ready` @@ -1567,9 +1537,7 @@ def fetch_guilds( All parameters are optional. """ - return GuildIterator( - self, limit=limit, before=before, after=after, with_counts=with_counts - ) + return GuildIterator(self, limit=limit, before=before, after=after, with_counts=with_counts) async def fetch_template(self, code: Template | str) -> Template: """|coro| @@ -1888,9 +1856,7 @@ async def fetch_user(self, user_id: int, /) -> User: data = await self.http.get_user(user_id) return User(state=self._connection, data=data) - async def fetch_channel( - self, channel_id: int, / - ) -> GuildChannel | PrivateChannel | Thread: + async def fetch_channel(self, channel_id: int, /) -> GuildChannel | PrivateChannel | Thread: """|coro| Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID. @@ -1921,9 +1887,7 @@ async def fetch_channel( factory, ch_type = _threaded_channel_factory(data["type"]) if factory is None: - raise InvalidData( - "Unknown channel type {type} for channel ID {id}.".format_map(data) - ) + raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data)) if ch_type in (ChannelType.group, ChannelType.private): # the factory will be a DMChannel or GroupChannel here @@ -1997,10 +1961,7 @@ async def fetch_premium_sticker_packs(self) -> list[StickerPack]: Retrieving the sticker packs failed. """ data = await self.http.list_premium_sticker_packs() - return [ - StickerPack(state=self._connection, data=pack) - for pack in data["sticker_packs"] - ] + return [StickerPack(state=self._connection, data=pack) for pack in data["sticker_packs"]] async def create_dm(self, user: Snowflake) -> DMChannel: """|coro| @@ -2060,9 +2021,7 @@ def add_view(self, view: View, *, message_id: int | None = None) -> None: raise TypeError(f"expected an instance of View not {view.__class__!r}") if not view.is_persistent(): - raise ValueError( - "View is not persistent. Items need to have a custom_id set and View must have no timeout" - ) + raise ValueError("View is not persistent. Items need to have a custom_id set and View must have no timeout") self._connection.store_view(view, message_id) @@ -2088,9 +2047,7 @@ async def fetch_role_connection_metadata_records( List[:class:`.ApplicationRoleConnectionMetadata`] The bot's role connection metadata records. """ - data = await self._connection.http.get_application_role_connection_metadata_records( - self.application_id - ) + data = await self._connection.http.get_application_role_connection_metadata_records(self.application_id) return [ApplicationRoleConnectionMetadata.from_dict(r) for r in data] async def update_role_connection_metadata_records( @@ -2229,13 +2186,8 @@ async def fetch_emojis(self) -> list[AppEmoji]: List[:class:`AppEmoji`] The retrieved emojis. """ - data = await self._connection.http.get_all_application_emojis( - self.application_id - ) - return [ - self._connection.maybe_store_app_emoji(self.application_id, d) - for d in data["items"] - ] + data = await self._connection.http.get_all_application_emojis(self.application_id) + return [self._connection.maybe_store_app_emoji(self.application_id, d) for d in data["items"]] async def fetch_emoji(self, emoji_id: int, /) -> AppEmoji: """|coro| @@ -2259,9 +2211,7 @@ async def fetch_emoji(self, emoji_id: int, /) -> AppEmoji: HTTPException An error occurred fetching the emoji. """ - data = await self._connection.http.get_application_emoji( - self.application_id, emoji_id - ) + data = await self._connection.http.get_application_emoji(self.application_id, emoji_id) return self._connection.maybe_store_app_emoji(self.application_id, data) async def create_emoji( @@ -2296,9 +2246,7 @@ async def create_emoji( """ img = bytes_to_base64_data(image) - data = await self._connection.http.create_application_emoji( - self.application_id, name, img - ) + data = await self._connection.http.create_application_emoji(self.application_id, name, img) return self._connection.maybe_store_app_emoji(self.application_id, data) async def delete_emoji(self, emoji: Snowflake) -> None: @@ -2317,8 +2265,6 @@ async def delete_emoji(self, emoji: Snowflake) -> None: An error occurred deleting the emoji. """ - await self._connection.http.delete_application_emoji( - self.application_id, emoji.id - ) + await self._connection.http.delete_application_emoji(self.application_id, emoji.id) if self._connection.cache_app_emojis and self._connection.get_emoji(emoji.id): self._connection.remove_emoji(emoji) diff --git a/discord/utils/private.py b/discord/utils/private.py index b95c21eb6b..bff476fda4 100644 --- a/discord/utils/private.py +++ b/discord/utils/private.py @@ -105,9 +105,7 @@ def parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: return float(reset_after) utc = datetime.timezone.utc now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp( - float(request.headers["X-Ratelimit-Reset"]), utc - ) + reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc) return (reset - now).total_seconds() @@ -341,16 +339,11 @@ def evaluate_annotation( is_literal = True evaluated_args = tuple( - evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) - for arg in args + evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args ) - if is_literal and not all( - isinstance(x, (str, int, bool, type(None))) for x in evaluated_args - ): - raise TypeError( - "Literal arguments must be of type str, int, bool, or NoneType." - ) + if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): + raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") if evaluated_args == args: return tp @@ -400,22 +393,16 @@ async def async_all(gen: Iterable[Any]) -> bool: return True -async def maybe_awaitable( - f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs -) -> T: +async def maybe_awaitable(f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: value = f(*args, **kwargs) if isawaitable(value): return await value return value -async def sane_wait_for( - futures: Iterable[Awaitable[T]], *, timeout: float -) -> set[asyncio.Task[T]]: +async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Task[T]]: ensured = [asyncio.ensure_future(fut) for fut in futures] - done, pending = await asyncio.wait( - ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED - ) + done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) if len(pending) != 0: raise asyncio.TimeoutError() @@ -506,9 +493,7 @@ def __init__(self, name: str, function: Callable[[T], T_co]) -> None: self.__doc__ = function.__doc__ @overload - def __get__( - self, instance: None, owner: type[T] - ) -> CachedSlotProperty[T, T_co]: ... + def __get__(self, instance: None, owner: type[T]) -> CachedSlotProperty[T, T_co]: ... @overload def __get__(self, instance: T, owner: type[T]) -> T_co: ... From 527cce1530b039a2bd98eed1cf724c1561f53b5f Mon Sep 17 00:00:00 2001 From: Lumouille <144063653+Lumabots@users.noreply.github.com> Date: Thu, 18 Sep 2025 15:38:28 +0200 Subject: [PATCH 7/8] test check --- discord/utils/private.py | 2 +- discord/utils/public.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/discord/utils/private.py b/discord/utils/private.py index bff476fda4..7ed3612f9e 100644 --- a/discord/utils/private.py +++ b/discord/utils/private.py @@ -447,7 +447,7 @@ def has(self, element: int) -> bool: def copy_doc(original: Callable[..., object]) -> Callable[[T], T]: def decorator(overridden: T) -> T: overridden.__doc__ = original.__doc__ - overridden.__signature__ = signature(original) # pyright: ignore[reportAttributeAccessIssue] + overridden.__signature__ = signature(original) # type: ignore[reportAttributeAccessIssue] return overridden return decorator diff --git a/discord/utils/public.py b/discord/utils/public.py index 875edb05ed..470ba37929 100644 --- a/discord/utils/public.py +++ b/discord/utils/public.py @@ -47,11 +47,11 @@ def utcnow() -> datetime.datetime: return datetime.datetime.now(datetime.timezone.utc) -V = Iterable[OptionChoice] | Iterable[str] | Iterable[int] | Iterable[float] +V = Iterable["OptionChoice"] | Iterable[str] | Iterable[int] | Iterable[float] AV = Awaitable[V] -Values = V | Callable[[AutocompleteContext], V | AV] | AV -AutocompleteFunc = Callable[[AutocompleteContext], AV] -FilterFunc = Callable[[AutocompleteContext, Any], bool | Awaitable[bool]] +Values = V | Callable[["AutocompleteContext"], V | AV] | AV +AutocompleteFunc = Callable[["AutocompleteContext"], AV] +FilterFunc = Callable[["AutocompleteContext", Any], bool | Awaitable[bool]] def basic_autocomplete(values: Values, *, filter: FilterFunc | None = None) -> AutocompleteFunc: From 626a4997104e757a6882612790669a7738c8324e Mon Sep 17 00:00:00 2001 From: Lumouille <144063653+Lumabots@users.noreply.github.com> Date: Mon, 29 Sep 2025 11:27:06 +0300 Subject: [PATCH 8/8] fix: restore asyncio_mode setting in pytest configuration --- pyproject.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 70b3caabd0..fe5c52be3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -302,7 +302,4 @@ show_error_codes = true ignore_errors = true [tool.pytest.ini_options] -asyncio_mode = "auto" - -[tool.basedpyright] -typeCheckingMode = "strict" +asyncio_mode = "auto" \ No newline at end of file