Skip to content

Commit 30b1bbc

Browse files
committed
🏷️ Boring typing stuff
1 parent c2b4011 commit 30b1bbc

File tree

1 file changed

+28
-15
lines changed

1 file changed

+28
-15
lines changed

discord/commands/options.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@
2626

2727
import inspect
2828
import logging
29+
from collections.abc import Awaitable, Callable, Iterable
2930
from enum import Enum
30-
from typing import TYPE_CHECKING, Iterable, Literal, Optional, Type, Union
31+
from typing import TYPE_CHECKING, Any, Literal, Optional, Type, TypeVar, Union
32+
33+
import discord
3134

3235
from ..abc import GuildChannel, Mentionable
3336
from ..channel import (
@@ -46,6 +49,8 @@
4649
from ..utils import MISSING, basic_autocomplete
4750

4851
if TYPE_CHECKING:
52+
from typing_extensions import TypeAlias
53+
4954
from ..ext.commands import Converter
5055
from ..member import Member
5156
from ..message import Attachment
@@ -114,6 +119,19 @@ def __init__(self, thread_type: Literal["public", "private", "news"]):
114119
AutocompleteReturnType = Union[
115120
Iterable["OptionChoice"], Iterable[str], Iterable[int], Iterable[float]
116121
]
122+
T = TypeVar("T", bound=AutocompleteReturnType)
123+
MaybeAwaitable: TypeAlias = Union[T, Awaitable[T]]
124+
AutocompleteFunction = Union[
125+
Callable[[AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]],
126+
Callable[
127+
["discord.Cog", AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]
128+
],
129+
Callable[[AutocompleteContext, Any], MaybeAwaitable[AutocompleteReturnType]],
130+
Callable[
131+
["discord.Cog", AutocompleteContext, Any],
132+
MaybeAwaitable[AutocompleteReturnType],
133+
],
134+
]
117135

118136

119137
class Option:
@@ -268,7 +286,7 @@ def __init__(
268286
)
269287
self.default = kwargs.pop("default", None)
270288

271-
self._autocomplete = None
289+
self._autocomplete: AutocompleteFunction | None = None
272290
self.autocomplete = kwargs.pop("autocomplete", None)
273291
if len(enum_choices) > 25:
274292
self.choices: list[OptionChoice] = []
@@ -388,22 +406,17 @@ def __repr__(self):
388406
return f"<discord.commands.{self.__class__.__name__} name={self.name}>"
389407

390408
@property
391-
def autocomplete(self):
409+
def autocomplete(self) -> AutocompleteFunction | None:
392410
"""
393411
The autocomplete handler for the option. Accepts a callable (sync or async)
394-
that takes a single required argument of :class:`AutocompleteContext`.
412+
that takes a single required argument of :class:`AutocompleteContext` or two arguments
413+
of :class:`discord.Cog` (being the command's cog) and :class:`AutocompleteContext`.
395414
The callable must return an iterable of :class:`str` or :class:`OptionChoice`.
396415
Alternatively, :func:`discord.utils.basic_autocomplete` may be used in place of the callable.
397416
398417
Returns
399418
-------
400-
Union[
401-
Callable[[Self, AutocompleteContext, Any], AutocompleteReturnType],
402-
Callable[[AutocompleteContext, Any], AutocompleteReturnType],
403-
Callable[[Self, AutocompleteContext, Any], Awaitable[AutocompleteReturnType]],
404-
Callable[[AutocompleteContext, Any], Awaitable[AutocompleteReturnType]],
405-
None
406-
]
419+
Optional[AutocompleteFunction]
407420
408421
.. versionchanged:: 2.7
409422
@@ -413,17 +426,17 @@ def autocomplete(self):
413426
return self._autocomplete
414427

415428
@autocomplete.setter
416-
def autocomplete(self, value) -> None:
429+
def autocomplete(self, value: AutocompleteFunction | None) -> None:
417430
self._autocomplete = value
418431
# this is done here so it does not have to be computed every time the autocomplete is invoked
419432
if self._autocomplete is not None:
420-
self._autocomplete._is_instance_method = (
433+
self._autocomplete._is_instance_method = ( # pyright: ignore [reportFunctionMemberAccess]
421434
sum(
422435
1
423436
for param in inspect.signature(
424-
self.autocomplete
437+
self._autocomplete
425438
).parameters.values()
426-
if param.default == param.empty
439+
if param.default == param.empty # pyright: ignore[reportAny]
427440
and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
428441
)
429442
== 2

0 commit comments

Comments
 (0)