Skip to content

Commit 07e86ce

Browse files
committed
feat: Add optional check parameter in utils.basic_autocomplete
1 parent a2117ad commit 07e86ce

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

discord/utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,9 +1306,12 @@ def generate_snowflake(dt: datetime.datetime | None = None) -> int:
13061306
AV = Awaitable[V]
13071307
Values = Union[V, Callable[[AutocompleteContext], Union[V, AV]], AV]
13081308
AutocompleteFunc = Callable[[AutocompleteContext], AV]
1309+
CheckFunc = Callable[[AutocompleteContext, Any], Union[bool, Awaitable[bool]]]
13091310

13101311

1311-
def basic_autocomplete(values: Values) -> AutocompleteFunc:
1312+
def basic_autocomplete(
1313+
values: Values, *, check: CheckFunc | None = None
1314+
) -> AutocompleteFunc:
13121315
"""A helper function to make a basic autocomplete for slash commands. This is a pretty standard autocomplete and
13131316
will return any options that start with the value from the user, case-insensitive. If the ``values`` parameter is
13141317
callable, it will be called with the AutocompleteContext.
@@ -1320,6 +1323,9 @@ def basic_autocomplete(values: Values) -> AutocompleteFunc:
13201323
values: Union[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Callable[[:class:`.AutocompleteContext`], Union[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]]
13211324
Possible values for the option. Accepts an iterable of :class:`str`, a callable (sync or async) that takes a
13221325
single argument of :class:`.AutocompleteContext`, or a coroutine. Must resolve to an iterable of :class:`str`.
1326+
check: Optional[Callable[[:class:`.AutocompleteContext`, Any], Union[:class:`bool`, Awaitable[:class:`bool`]]]]
1327+
Predicate callable (sync or async) used to filter the autocomplete options. This function should accept two arguments:
1328+
the :class:`.AutocompleteContext` and an item from ``values``. If ``None`` is provided, a default check is used that includes items whose string representation starts with the user's input value, case-insensitive.
13231329
13241330
Returns
13251331
-------
@@ -1355,11 +1361,23 @@ async def autocomplete_callback(ctx: AutocompleteContext) -> V:
13551361
if asyncio.iscoroutine(_values):
13561362
_values = await _values
13571363

1358-
def check(item: Any) -> bool:
1359-
item = getattr(item, "name", item)
1360-
return str(item).lower().startswith(str(ctx.value or "").lower())
1364+
if check is None:
1365+
1366+
def _check(ctx: AutocompleteContext, item: Any) -> bool:
1367+
item = getattr(item, "name", item)
1368+
return str(item).lower().startswith(str(ctx.value or "").lower())
1369+
1370+
gen = (val for val in _values if _check(ctx, val))
1371+
1372+
elif asyncio.iscoroutinefunction(check):
1373+
gen = (val for val in _values if await check(ctx, val))
1374+
1375+
elif callable(check):
1376+
gen = (val for val in _values if check(ctx, val))
1377+
1378+
else:
1379+
raise TypeError("``check`` must be callable.")
13611380

1362-
gen = (val for val in _values if check(val))
13631381
return iter(itertools.islice(gen, 25))
13641382

13651383
return autocomplete_callback

0 commit comments

Comments
 (0)