Skip to content
14 changes: 10 additions & 4 deletions discord/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def _sync(self, data: GuildPayload) -> None:
if "channels" in data:
channels = data["channels"]
for c in channels:
factory, ch_type = _guild_channel_factory(c["type"])
factory, _ = _guild_channel_factory(c["type"])
if factory:
self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore

Expand Down Expand Up @@ -1020,7 +1020,10 @@ def get_member_named(self, name: str, /) -> Member | None:

# do the actual lookup and return if found
# if it isn't found then we'll do a full name lookup below.
result = utils.find(lambda m: m.name == name[:-5] and discriminator == potential_discriminator, members)
result = utils.find(
lambda m: m.name == name[:-5] and discriminator == potential_discriminator,
members,
)
if result is not None:
return result

Expand Down Expand Up @@ -1885,7 +1888,7 @@ async def fetch_channels(self) -> Sequence[GuildChannel]:
data = await self._state.http.get_all_guild_channels(self.id)

def convert(d):
factory, ch_type = _guild_channel_factory(d["type"])
factory, _ = _guild_channel_factory(d["type"])
if factory is None:
raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(d))

Expand Down Expand Up @@ -3312,7 +3315,10 @@ async def widget(self) -> Widget:
return Widget(state=self._state, data=data)

async def edit_widget(
self, *, enabled: bool | utils.Undefined = MISSING, channel: Snowflake | None | utils.Undefined = MISSING
self,
*,
enabled: bool | utils.Undefined = MISSING,
channel: Snowflake | None | utils.Undefined = MISSING,
) -> None:
"""|coro|

Expand Down
2 changes: 1 addition & 1 deletion discord/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ def parse_channel_update(self, data) -> None:
)

def parse_channel_create(self, data) -> None:
factory, ch_type = _channel_factory(data["type"])
factory, _ = _channel_factory(data["type"])
if factory is None:
_log.debug(
"CHANNEL_CREATE referencing an unknown channel type %s. Discarding.",
Expand Down
66 changes: 0 additions & 66 deletions discord/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@

from __future__ import annotations

from typing import (
Any,
)

from ..errors import HTTPException
from .public import (
MISSING,
UNICODE_EMOJIS,
Expand All @@ -56,7 +51,6 @@
"oauth_url",
"snowflake_time",
"find",
"get_or_fetch",
"utcnow",
"remove_markdown",
"escape_markdown",
Expand All @@ -71,63 +65,3 @@
"MISSING",
"UNICODE_EMOJIS",
)


async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any:
"""|coro|
Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it.
If the fetch also fails, an error will be raised.
Parameters
----------
obj: Any
The object to use the get or fetch methods in
attr: :class:`str`
The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute.
id: :class:`int`
The ID of the object
default: Any
The default value to return if the object is not found, instead of raising an error.
Returns
-------
Any
The object found or the default value.
Raises
------
:exc:`AttributeError`
The object is missing a ``get_`` or ``fetch_`` method
:exc:`NotFound`
Invalid ID for the object
:exc:`HTTPException`
An error occurred fetching the object
:exc:`Forbidden`
You do not have permission to fetch the object
Examples
--------
Getting a guild from a guild ID: ::
guild = await utils.get_or_fetch(client, "guild", guild_id)
Getting a channel from the guild. If the channel is not found, return None: ::
channel = await utils.get_or_fetch(guild, "channel", channel_id, default=None)
"""
getter = getattr(obj, f"get_{attr}")(id)
if getter is None:
try:
getter = await getattr(obj, f"fetch_{attr}")(id)
except AttributeError as e:
getter = await getattr(obj, f"_fetch_{attr}")(id)
if getter is None:
raise ValueError(f"Could not find {attr} with id {id} on {obj}") from e
except (HTTPException, ValueError):
if default is not MISSING:
return default
else:
raise
return getter
60 changes: 28 additions & 32 deletions discord/utils/private.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
Sequence,
TypeVar,
Union,
get_args,
overload,
reveal_type,
)

from ..errors import HTTPException, InvalidArgument
Expand Down Expand Up @@ -69,9 +71,6 @@ def resolve_invite(invite: Invite | str) -> str:
return invite


__all__ = ("resolve_invite",)


def get_as_snowflake(data: Any, key: str) -> int | None:
try:
value = data[key]
Expand Down Expand Up @@ -111,7 +110,7 @@ def parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
return (reset - now).total_seconds()


def string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
def string_width(string: str, *, _IS_ASCII: re.Pattern[str] = _IS_ASCII) -> int:
"""Returns string's width."""
match = _IS_ASCII.match(string)
if match:
Expand All @@ -138,7 +137,7 @@ def resolve_template(code: Template | str) -> str:
:class:`str`
The template code.
"""
from ..template import Template # noqa: PLC0415 # circular import
from ..template import Template # noqa: PLC0415

if isinstance(code, Template):
return code.code
Expand Down Expand Up @@ -167,10 +166,6 @@ def parse_time(timestamp: None) -> None: ...
def parse_time(timestamp: str) -> datetime.datetime: ...


@overload
def parse_time(timestamp: str | None) -> datetime.datetime | None: ...


def parse_time(timestamp: str | None) -> datetime.datetime | None:
"""A helper function to convert an ISO 8601 timestamp to a datetime object.

Expand Down Expand Up @@ -242,7 +237,7 @@ def deprecated(
stacklevel: int = 3,
*,
use_qualname: bool = True,
) -> Callable[[Callable[[P], T]], Callable[[P], T]]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""A decorator implementation of :func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when
the decorated function is called.

Expand All @@ -267,7 +262,7 @@ def deprecated(
will display as ``login``. Defaults to ``True``.
"""

def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]:
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
warn_deprecated(
Expand All @@ -289,11 +284,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> T:


def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]:
params = []
params: list[Any] = []
literal_cls = type(Literal[0])
for p in parameters:
if isinstance(p, literal_cls):
params.extend(p.__args__)
params.extend(get_args(p))
else:
params.append(p)
return tuple(params)
Expand All @@ -311,7 +306,7 @@ def evaluate_annotation(
cache: dict[str, Any],
*,
implicit_str: bool = True,
):
) -> Any:
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals
Expand All @@ -329,8 +324,8 @@ def evaluate_annotation(
is_literal = False
args = tp.__args__
if not hasattr(tp, "__origin__"):
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
converted = Union[args] # type: ignore # noqa: UP007
if PY_310 and tp.__class__ is types.UnionType:
converted = Union[args] # noqa: UP007
return evaluate_annotation(converted, globals, locals, cache)

return tp
Expand Down Expand Up @@ -381,7 +376,7 @@ def resolve_annotation(
return evaluate_annotation(annotation, globalns, locals, cache)


def delay_task(delay: float, func: Coroutine):
def delay_task(delay: float, func: Coroutine[Any, Any, Any]):
async def inner_call():
await asyncio.sleep(delay)
try:
Expand All @@ -404,11 +399,12 @@ async def async_all(gen: Iterable[Any]) -> bool:
async def maybe_awaitable(f: Callable[P, T | Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T:
value = f(*args, **kwargs)
if isawaitable(value):
reveal_type(f)
return await value
return value


async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Future[T]]:
async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Task[T]]:
ensured = [asyncio.ensure_future(fut) for fut in futures]
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)

Expand All @@ -418,7 +414,7 @@ async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> s
return done


class SnowflakeList(array.array):
class SnowflakeList(array.array[int]):
"""Internal data storage class to efficiently store a list of snowflakes.

This should have the following characteristics:
Expand All @@ -437,7 +433,7 @@ class SnowflakeList(array.array):
def __init__(self, data: Iterable[int], *, is_sorted: bool = False): ...

def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore
return super().__new__(cls, "Q", data if is_sorted else sorted(data))

def add(self, element: int) -> None:
i = bisect_left(self, element)
Expand All @@ -452,22 +448,22 @@ def has(self, element: int) -> bool:
return i != len(self) and self[i] == element


def copy_doc(original: Callable) -> Callable[[T], T]:
def copy_doc(original: Callable[..., object]) -> Callable[[T], T]:
def decorator(overridden: T) -> T:
overridden.__doc__ = original.__doc__
overridden.__signature__ = signature(original) # type: ignore
overridden.__signature__ = signature(original)
return overridden

return decorator


class SequenceProxy(collections.abc.Sequence, Generic[T_co]):
class SequenceProxy(collections.abc.Sequence[T_co], Generic[T_co]):
"""Read-only proxy of a Sequence."""

def __init__(self, proxied: Sequence[T_co]):
self.__proxied = proxied

def __getitem__(self, idx: int) -> T_co:
def __getitem__(self, idx: int) -> T_co: # type: ignore[override]
return self.__proxied[idx]

def __len__(self) -> int:
Expand All @@ -482,7 +478,7 @@ def __iter__(self) -> Iterator[T_co]:
def __reversed__(self) -> Iterator[T_co]:
return reversed(self.__proxied)

def index(self, value: Any, *args, **kwargs) -> int:
def index(self, value: Any, *args: Any, **kwargs: Any) -> int:
return self.__proxied.index(value, *args, **kwargs)

def count(self, value: Any) -> int:
Expand Down Expand Up @@ -515,10 +511,10 @@ def __get__(self, instance: T | None, owner: type[T]) -> Any:

def get_slots(cls: type[Any]) -> Iterator[str]:
for mro in reversed(cls.__mro__):
try:
yield from mro.__slots__
except AttributeError:
slots = getattr(mro, "__slots__", None)
if slots is None:
continue
yield from slots


def cached_slot_property(
Expand All @@ -533,15 +529,15 @@ def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]:
try:
import msgspec

def to_json(obj: Any) -> str: # type: ignore
def _to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction]
return msgspec.json.encode(obj).decode("utf-8")

from_json = msgspec.json.decode # type: ignore
_from_json = msgspec.json.decode

except ModuleNotFoundError:
import json

def to_json(obj: Any) -> str:
def _to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction]
return json.dumps(obj, separators=(",", ":"), ensure_ascii=True)

from_json = json.loads
_from_json = json.loads
Loading
Loading