Skip to content

Commit e081b9c

Browse files
feat: Improve typing of utils (#118)
Co-authored-by: Paillat <[email protected]>
1 parent 2e1025a commit e081b9c

File tree

6 files changed

+60
-129
lines changed

6 files changed

+60
-129
lines changed

discord/client.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,12 @@
6363
from .ui.view import View
6464
from .user import ClientUser, User
6565
from .utils import MISSING
66-
from .utils.private import SequenceProxy, bytes_to_base64_data, resolve_invite, resolve_template
66+
from .utils.private import (
67+
SequenceProxy,
68+
bytes_to_base64_data,
69+
resolve_invite,
70+
resolve_template,
71+
)
6772
from .voice_client import VoiceClient
6873
from .webhook import Webhook
6974
from .widget import Widget
@@ -629,7 +634,10 @@ async def login(self, token: str) -> None:
629634
data = await self.http.static_login(token.strip())
630635
self._connection.user = ClientUser(state=self._connection, data=data)
631636

632-
print_banner(bot_name=self._connection.user.display_name, module=self._banner_module or "discord")
637+
print_banner(
638+
bot_name=self._connection.user.display_name,
639+
module=self._banner_module or "discord",
640+
)
633641
start_logging(self._flavor, debug=self._debug)
634642

635643
async def connect(self, *, reconnect: bool = True) -> None:
@@ -1130,24 +1138,6 @@ def get_all_members(self) -> Generator[Member]:
11301138
for guild in self.guilds:
11311139
yield from guild.members
11321140

1133-
async def get_or_fetch_user(self, id: int, /) -> User | None:
1134-
"""|coro|
1135-
1136-
Looks up a user in the user cache or fetches if not found.
1137-
1138-
Parameters
1139-
----------
1140-
id: :class:`int`
1141-
The ID to search for.
1142-
1143-
Returns
1144-
-------
1145-
Optional[:class:`~discord.User`]
1146-
The user or ``None`` if not found.
1147-
"""
1148-
1149-
return await utils.get_or_fetch(obj=self, attr="user", id=id, default=None)
1150-
11511141
# listeners/waiters
11521142

11531143
async def wait_until_ready(self) -> None:

discord/guild.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,10 @@ def get_member_named(self, name: str, /) -> Member | None:
10201020

10211021
# do the actual lookup and return if found
10221022
# if it isn't found then we'll do a full name lookup below.
1023-
result = utils.find(lambda m: m.name == name[:-5] and discriminator == potential_discriminator, members)
1023+
result = utils.find(
1024+
lambda m: m.name == name[:-5] and discriminator == potential_discriminator,
1025+
members,
1026+
)
10241027
if result is not None:
10251028
return result
10261029

@@ -3312,7 +3315,10 @@ async def widget(self) -> Widget:
33123315
return Widget(state=self._state, data=data)
33133316

33143317
async def edit_widget(
3315-
self, *, enabled: bool | utils.Undefined = MISSING, channel: Snowflake | None | utils.Undefined = MISSING
3318+
self,
3319+
*,
3320+
enabled: bool | utils.Undefined = MISSING,
3321+
channel: Snowflake | None | utils.Undefined = MISSING,
33163322
) -> None:
33173323
"""|coro|
33183324

discord/utils/__init__.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@
2525

2626
from __future__ import annotations
2727

28-
from typing import (
29-
Any,
30-
)
31-
32-
from ..errors import HTTPException
3328
from .public import (
3429
MISSING,
3530
UNICODE_EMOJIS,
@@ -56,7 +51,6 @@
5651
"oauth_url",
5752
"snowflake_time",
5853
"find",
59-
"get_or_fetch",
6054
"utcnow",
6155
"remove_markdown",
6256
"escape_markdown",
@@ -71,63 +65,3 @@
7165
"MISSING",
7266
"UNICODE_EMOJIS",
7367
)
74-
75-
76-
async def get_or_fetch(obj, attr: str, id: int, *, default: Any = MISSING) -> Any:
77-
"""|coro|
78-
79-
Attempts to get an attribute from the object in cache. If it fails, it will attempt to fetch it.
80-
If the fetch also fails, an error will be raised.
81-
82-
Parameters
83-
----------
84-
obj: Any
85-
The object to use the get or fetch methods in
86-
attr: :class:`str`
87-
The attribute to get or fetch. Note the object must have both a ``get_`` and ``fetch_`` method for this attribute.
88-
id: :class:`int`
89-
The ID of the object
90-
default: Any
91-
The default value to return if the object is not found, instead of raising an error.
92-
93-
Returns
94-
-------
95-
Any
96-
The object found or the default value.
97-
98-
Raises
99-
------
100-
:exc:`AttributeError`
101-
The object is missing a ``get_`` or ``fetch_`` method
102-
:exc:`NotFound`
103-
Invalid ID for the object
104-
:exc:`HTTPException`
105-
An error occurred fetching the object
106-
:exc:`Forbidden`
107-
You do not have permission to fetch the object
108-
109-
Examples
110-
--------
111-
112-
Getting a guild from a guild ID: ::
113-
114-
guild = await utils.get_or_fetch(client, "guild", guild_id)
115-
116-
Getting a channel from the guild. If the channel is not found, return None: ::
117-
118-
channel = await utils.get_or_fetch(guild, "channel", channel_id, default=None)
119-
"""
120-
getter = getattr(obj, f"get_{attr}")(id)
121-
if getter is None:
122-
try:
123-
getter = await getattr(obj, f"fetch_{attr}")(id)
124-
except AttributeError as e:
125-
getter = await getattr(obj, f"_fetch_{attr}")(id)
126-
if getter is None:
127-
raise ValueError(f"Could not find {attr} with id {id} on {obj}") from e
128-
except (HTTPException, ValueError):
129-
if default is not MISSING:
130-
return default
131-
else:
132-
raise
133-
return getter

discord/utils/private.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Sequence,
2929
TypeVar,
3030
Union,
31+
get_args,
3132
overload,
3233
)
3334

@@ -69,9 +70,6 @@ def resolve_invite(invite: Invite | str) -> str:
6970
return invite
7071

7172

72-
__all__ = ("resolve_invite",)
73-
74-
7573
def get_as_snowflake(data: Any, key: str) -> int | None:
7674
try:
7775
value = data[key]
@@ -111,7 +109,7 @@ def parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
111109
return (reset - now).total_seconds()
112110

113111

114-
def string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
112+
def string_width(string: str, *, _IS_ASCII: re.Pattern[str] = _IS_ASCII) -> int:
115113
"""Returns string's width."""
116114
match = _IS_ASCII.match(string)
117115
if match:
@@ -138,7 +136,7 @@ def resolve_template(code: Template | str) -> str:
138136
:class:`str`
139137
The template code.
140138
"""
141-
from ..template import Template # noqa: PLC0415 # circular import
139+
from ..template import Template # noqa: PLC0415
142140

143141
if isinstance(code, Template):
144142
return code.code
@@ -167,10 +165,6 @@ def parse_time(timestamp: None) -> None: ...
167165
def parse_time(timestamp: str) -> datetime.datetime: ...
168166

169167

170-
@overload
171-
def parse_time(timestamp: str | None) -> datetime.datetime | None: ...
172-
173-
174168
def parse_time(timestamp: str | None) -> datetime.datetime | None:
175169
"""A helper function to convert an ISO 8601 timestamp to a datetime object.
176170
@@ -218,7 +212,6 @@ def warn_deprecated(
218212
stacklevel: :class:`int`
219213
The stacklevel kwarg passed to :func:`warnings.warn`. Defaults to 3.
220214
"""
221-
warnings.simplefilter("always", DeprecationWarning) # turn off filter
222215
message = f"{name} is deprecated"
223216
if since:
224217
message += f" since version {since}"
@@ -231,7 +224,6 @@ def warn_deprecated(
231224
message += f" See {reference} for more information."
232225

233226
warnings.warn(message, stacklevel=stacklevel, category=DeprecationWarning)
234-
warnings.simplefilter("default", DeprecationWarning) # reset filter
235227

236228

237229
def deprecated(
@@ -242,7 +234,7 @@ def deprecated(
242234
stacklevel: int = 3,
243235
*,
244236
use_qualname: bool = True,
245-
) -> Callable[[Callable[[P], T]], Callable[[P], T]]:
237+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
246238
"""A decorator implementation of :func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when
247239
the decorated function is called.
248240
@@ -267,7 +259,7 @@ def deprecated(
267259
will display as ``login``. Defaults to ``True``.
268260
"""
269261

270-
def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]:
262+
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
271263
@functools.wraps(func)
272264
def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
273265
warn_deprecated(
@@ -289,11 +281,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
289281

290282

291283
def flatten_literal_params(parameters: Iterable[Any]) -> tuple[Any, ...]:
292-
params = []
284+
params: list[Any] = []
293285
literal_cls = type(Literal[0])
294286
for p in parameters:
295287
if isinstance(p, literal_cls):
296-
params.extend(p.__args__)
288+
params.extend(get_args(p))
297289
else:
298290
params.append(p)
299291
return tuple(params)
@@ -311,7 +303,7 @@ def evaluate_annotation(
311303
cache: dict[str, Any],
312304
*,
313305
implicit_str: bool = True,
314-
):
306+
) -> Any:
315307
if isinstance(tp, ForwardRef):
316308
tp = tp.__forward_arg__
317309
# ForwardRefs always evaluate their internals
@@ -329,8 +321,8 @@ def evaluate_annotation(
329321
is_literal = False
330322
args = tp.__args__
331323
if not hasattr(tp, "__origin__"):
332-
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
333-
converted = Union[args] # type: ignore # noqa: UP007
324+
if PY_310 and tp.__class__ is types.UnionType:
325+
converted = Union[args] # noqa: UP007
334326
return evaluate_annotation(converted, globals, locals, cache)
335327

336328
return tp
@@ -381,7 +373,7 @@ def resolve_annotation(
381373
return evaluate_annotation(annotation, globalns, locals, cache)
382374

383375

384-
def delay_task(delay: float, func: Coroutine):
376+
def delay_task(delay: float, func: Coroutine[Any, Any, Any]):
385377
async def inner_call():
386378
await asyncio.sleep(delay)
387379
try:
@@ -408,7 +400,7 @@ async def maybe_awaitable(f: Callable[P, T | Awaitable[T]], *args: P.args, **kwa
408400
return value
409401

410402

411-
async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Future[T]]:
403+
async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> set[asyncio.Task[T]]:
412404
ensured = [asyncio.ensure_future(fut) for fut in futures]
413405
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
414406

@@ -418,7 +410,7 @@ async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: float) -> s
418410
return done
419411

420412

421-
class SnowflakeList(array.array):
413+
class SnowflakeList(array.array[int]):
422414
"""Internal data storage class to efficiently store a list of snowflakes.
423415
424416
This should have the following characteristics:
@@ -437,7 +429,7 @@ class SnowflakeList(array.array):
437429
def __init__(self, data: Iterable[int], *, is_sorted: bool = False): ...
438430

439431
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
440-
return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore
432+
return super().__new__(cls, "Q", data if is_sorted else sorted(data))
441433

442434
def add(self, element: int) -> None:
443435
i = bisect_left(self, element)
@@ -452,22 +444,27 @@ def has(self, element: int) -> bool:
452444
return i != len(self) and self[i] == element
453445

454446

455-
def copy_doc(original: Callable) -> Callable[[T], T]:
447+
def copy_doc(original: Callable[..., object]) -> Callable[[T], T]:
456448
def decorator(overridden: T) -> T:
457449
overridden.__doc__ = original.__doc__
458-
overridden.__signature__ = signature(original) # type: ignore
450+
overridden.__signature__ = signature(original) # type: ignore[reportAttributeAccessIssue]
459451
return overridden
460452

461453
return decorator
462454

463455

464-
class SequenceProxy(collections.abc.Sequence, Generic[T_co]):
456+
class SequenceProxy(collections.abc.Sequence[T_co], Generic[T_co]):
465457
"""Read-only proxy of a Sequence."""
466458

467459
def __init__(self, proxied: Sequence[T_co]):
468460
self.__proxied = proxied
469461

470-
def __getitem__(self, idx: int) -> T_co:
462+
@overload
463+
def __getitem__(self, idx: int) -> T_co: ...
464+
@overload
465+
def __getitem__(self, idx: slice) -> Sequence[T_co]: ...
466+
467+
def __getitem__(self, idx: int | slice) -> T_co | Sequence[T_co]:
471468
return self.__proxied[idx]
472469

473470
def __len__(self) -> int:
@@ -482,7 +479,7 @@ def __iter__(self) -> Iterator[T_co]:
482479
def __reversed__(self) -> Iterator[T_co]:
483480
return reversed(self.__proxied)
484481

485-
def index(self, value: Any, *args, **kwargs) -> int:
482+
def index(self, value: Any, *args: Any, **kwargs: Any) -> int:
486483
return self.__proxied.index(value, *args, **kwargs)
487484

488485
def count(self, value: Any) -> int:
@@ -515,10 +512,10 @@ def __get__(self, instance: T | None, owner: type[T]) -> Any:
515512

516513
def get_slots(cls: type[Any]) -> Iterator[str]:
517514
for mro in reversed(cls.__mro__):
518-
try:
519-
yield from mro.__slots__
520-
except AttributeError:
515+
slots = getattr(mro, "__slots__", None)
516+
if slots is None:
521517
continue
518+
yield from slots
522519

523520

524521
def cached_slot_property(
@@ -533,15 +530,15 @@ def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]:
533530
try:
534531
import msgspec
535532

536-
def to_json(obj: Any) -> str: # type: ignore
533+
def to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction]
537534
return msgspec.json.encode(obj).decode("utf-8")
538535

539-
from_json = msgspec.json.decode # type: ignore
536+
from_json = msgspec.json.decode
540537

541538
except ModuleNotFoundError:
542539
import json
543540

544-
def to_json(obj: Any) -> str:
541+
def to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction]
545542
return json.dumps(obj, separators=(",", ":"), ensure_ascii=True)
546543

547544
from_json = json.loads

0 commit comments

Comments
 (0)