Skip to content
30 changes: 10 additions & 20 deletions discord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@
from .ui.view import View
from .user import ClientUser, User
from .utils import MISSING
from .utils.private import SequenceProxy, bytes_to_base64_data, resolve_invite, resolve_template
from .utils.private import (
SequenceProxy,
bytes_to_base64_data,
resolve_invite,
resolve_template,
)
from .voice_client import VoiceClient
from .webhook import Webhook
from .widget import Widget
Expand Down Expand Up @@ -629,7 +634,10 @@ async def login(self, token: str) -> None:
data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data)

print_banner(bot_name=self._connection.user.display_name, module=self._banner_module or "discord")
print_banner(
bot_name=self._connection.user.display_name,
module=self._banner_module or "discord",
)
start_logging(self._flavor, debug=self._debug)

async def connect(self, *, reconnect: bool = True) -> None:
Expand Down Expand Up @@ -1130,24 +1138,6 @@ def get_all_members(self) -> Generator[Member]:
for guild in self.guilds:
yield from guild.members

async def get_or_fetch_user(self, id: int, /) -> User | None:
"""|coro|

Looks up a user in the user cache or fetches if not found.

Parameters
----------
id: :class:`int`
The ID to search for.

Returns
-------
Optional[:class:`~discord.User`]
The user or ``None`` if not found.
"""

return await utils.get_or_fetch(obj=self, attr="user", id=id, default=None)

# listeners/waiters

async def wait_until_ready(self) -> None:
Expand Down
10 changes: 8 additions & 2 deletions discord/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,10 @@ def get_member_named(self, name: str, /) -> Member | None:

# do the actual lookup and return if found
# if it isn't found then we'll do a full name lookup below.
result = utils.find(lambda m: m.name == name[:-5] and discriminator == potential_discriminator, members)
result = utils.find(
lambda m: m.name == name[:-5] and discriminator == potential_discriminator,
members,
)
if result is not None:
return result

Expand Down Expand Up @@ -3312,7 +3315,10 @@ async def widget(self) -> Widget:
return Widget(state=self._state, data=data)

async def edit_widget(
self, *, enabled: bool | utils.Undefined = MISSING, channel: Snowflake | None | utils.Undefined = MISSING
self,
*,
enabled: bool | utils.Undefined = MISSING,
channel: Snowflake | None | utils.Undefined = MISSING,
) -> None:
"""|coro|

Expand Down
66 changes: 0 additions & 66 deletions discord/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@

from __future__ import annotations

from typing import (
Any,
)

from ..errors import HTTPException
from .public import (
MISSING,
UNICODE_EMOJIS,
Expand All @@ -56,7 +51,6 @@
"oauth_url",
"snowflake_time",
"find",
"get_or_fetch",
"utcnow",
"remove_markdown",
"escape_markdown",
Expand All @@ -71,63 +65,3 @@
"MISSING",
"UNICODE_EMOJIS",
)


async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any:
"""|coro|
Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it.
If the fetch also fails, an error will be raised.
Parameters
----------
obj: Any
The object to use the get or fetch methods in
attr: :class:`str`
The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute.
id: :class:`int`
The ID of the object
default: Any
The default value to return if the object is not found, instead of raising an error.
Returns
-------
Any
The object found or the default value.
Raises
------
:exc:`AttributeError`
The object is missing a ``get_`` or ``fetch_`` method
:exc:`NotFound`
Invalid ID for the object
:exc:`HTTPException`
An error occurred fetching the object
:exc:`Forbidden`
You do not have permission to fetch the object
Examples
--------
Getting a guild from a guild ID: ::
guild = await utils.get_or_fetch(client, "guild", guild_id)
Getting a channel from the guild. If the channel is not found, return None: ::
channel = await utils.get_or_fetch(guild, "channel", channel_id, default=None)
"""
getter = getattr(obj, f"get_{attr}")(id)
if getter is None:
try:
getter = await getattr(obj, f"fetch_{attr}")(id)
except AttributeError as e:
getter = await getattr(obj, f"_fetch_{attr}")(id)
if getter is None:
raise ValueError(f"Could not find {attr} with id {id} on {obj}") from e
except (HTTPException, ValueError):
if default is not MISSING:
return default
else:
raise
return getter
63 changes: 30 additions & 33 deletions discord/utils/private.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Sequence,
TypeVar,
Union,
get_args,
overload,
)

Expand Down Expand Up @@ -69,9 +70,6 @@ def resolve_invite(invite: Invite | str) -> str:
return invite


__all__ = ("resolve_invite",)


def get_as_snowflake(data: Any, key: str) -> int | None:
try:
value = data[key]
Expand Down Expand Up @@ -111,7 +109,7 @@ def parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
return (reset - now).total_seconds()


def string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
def string_width(string: str, *, _IS_ASCII: re.Pattern[str] = _IS_ASCII) -> int:
"""Returns string's width."""
match = _IS_ASCII.match(string)
if match:
Expand All @@ -138,7 +136,7 @@ def resolve_template(code: Template | str) -> str:
:class:`str`
The template code.
"""
from ..template import Template # noqa: PLC0415 # circular import
from ..template import Template # noqa: PLC0415

if isinstance(code, Template):
return code.code
Expand Down Expand Up @@ -167,10 +165,6 @@ def parse_time(timestamp: None) -> None: ...
def parse_time(timestamp: str) -> datetime.datetime: ...


@overload
def parse_time(timestamp: str | None) -> datetime.datetime | None: ...


def parse_time(timestamp: str | None) -> datetime.datetime | None:
"""A helper function to convert an ISO 8601 timestamp to a datetime object.

Expand Down Expand Up @@ -218,7 +212,6 @@ def warn_deprecated(
stacklevel: :class:`int`
The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3.
"""
warnings.simplefilter("always", DeprecationWarning) # turn off filter
message = f"{name} is deprecated"
if since:
message += f" since version {since}"
Expand All @@ -231,7 +224,6 @@ def warn_deprecated(
message += f" See {reference} for more information."

warnings.warn(message, stacklevel=stacklevel, category=DeprecationWarning)
warnings.simplefilter("default", DeprecationWarning) # reset filter


def deprecated(
Expand All @@ -242,7 +234,7 @@ def deprecated(
stacklevel: int = 3,
*,
use_qualname: bool = True,
) -> Callable[[Callable[[P], T]], Callable[[P], T]]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""A decorator implementation of :func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when
the decorated function is called.

Expand All @@ -267,7 +259,7 @@ def deprecated(
will display as ``login``. Defaults to ``True``.
"""

def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]:
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
warn_deprecated(
Expand All @@ -289,11 +281,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> T:


def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]:
params = []
params: list[Any] = []
literal_cls = type(Literal[0])
for p in parameters:
if isinstance(p, literal_cls):
params.extend(p.__args__)
params.extend(get_args(p))
else:
params.append(p)
return tuple(params)
Expand All @@ -311,7 +303,7 @@ def evaluate_annotation(
cache: dict[str, Any],
*,
implicit_str: bool = True,
):
) -> Any:
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals
Expand All @@ -329,8 +321,8 @@ def evaluate_annotation(
is_literal = False
args = tp.__args__
if not hasattr(tp, "__origin__"):
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
converted = Union[args] # type: ignore # noqa: UP007
if PY_310 and tp.__class__ is types.UnionType:
converted = Union[args] # noqa: UP007
return evaluate_annotation(converted, globals, locals, cache)

return tp
Expand Down Expand Up @@ -381,7 +373,7 @@ def resolve_annotation(
return evaluate_annotation(annotation, globalns, locals, cache)


def delay_task(delay: float, func: Coroutine):
def delay_task(delay: float, func: Coroutine[Any, Any, Any]):
async def inner_call():
await asyncio.sleep(delay)
try:
Expand All @@ -408,7 +400,7 @@ async def maybe_awaitable(f: Callable[P, T | Awaitable[T]], *args: P.args, **kwa
return value


async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Future[T]]:
async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Task[T]]:
ensured = [asyncio.ensure_future(fut) for fut in futures]
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)

Expand All @@ -418,7 +410,7 @@ async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> s
return done


class SnowflakeList(array.array):
class SnowflakeList(array.array[int]):
"""Internal data storage class to efficiently store a list of snowflakes.

This should have the following characteristics:
Expand All @@ -437,7 +429,7 @@ class SnowflakeList(array.array):
def __init__(self, data: Iterable[int], *, is_sorted: bool = False): ...

def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore
return super().__new__(cls, "Q", data if is_sorted else sorted(data))

def add(self, element: int) -> None:
i = bisect_left(self, element)
Expand All @@ -452,22 +444,27 @@ def has(self, element: int) -> bool:
return i != len(self) and self[i] == element


def copy_doc(original: Callable) -> Callable[[T], T]:
def copy_doc(original: Callable[..., object]) -> Callable[[T], T]:
def decorator(overridden: T) -> T:
overridden.__doc__ = original.__doc__
overridden.__signature__ = signature(original) # type: ignore
overridden.__signature__ = signature(original) # type: ignore[reportAttributeAccessIssue]
return overridden

return decorator


class SequenceProxy(collections.abc.Sequence, Generic[T_co]):
class SequenceProxy(collections.abc.Sequence[T_co], Generic[T_co]):
"""Read-only proxy of a Sequence."""

def __init__(self, proxied: Sequence[T_co]):
self.__proxied = proxied

def __getitem__(self, idx: int) -> T_co:
@overload
def __getitem__(self, idx: int) -> T_co: ...
@overload
def __getitem__(self, idx: slice) -> Sequence[T_co]: ...

def __getitem__(self, idx: int | slice) -> T_co | Sequence[T_co]:
return self.__proxied[idx]

def __len__(self) -> int:
Expand All @@ -482,7 +479,7 @@ def __iter__(self) -> Iterator[T_co]:
def __reversed__(self) -> Iterator[T_co]:
return reversed(self.__proxied)

def index(self, value: Any, *args, **kwargs) -> int:
def index(self, value: Any, *args: Any, **kwargs: Any) -> int:
return self.__proxied.index(value, *args, **kwargs)

def count(self, value: Any) -> int:
Expand Down Expand Up @@ -515,10 +512,10 @@ def __get__(self, instance: T | None, owner: type[T]) -> Any:

def get_slots(cls: type[Any]) -> Iterator[str]:
for mro in reversed(cls.__mro__):
try:
yield from mro.__slots__
except AttributeError:
slots = getattr(mro, "__slots__", None)
if slots is None:
continue
yield from slots


def cached_slot_property(
Expand All @@ -533,15 +530,15 @@ def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]:
try:
import msgspec

def to_json(obj: Any) -> str: # type: ignore
def to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction]
return msgspec.json.encode(obj).decode("utf-8")

from_json = msgspec.json.decode # type: ignore
from_json = msgspec.json.decode

except ModuleNotFoundError:
import json

def to_json(obj: Any) -> str:
def to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction]
return json.dumps(obj, separators=(",", ":"), ensure_ascii=True)

from_json = json.loads
Loading