Skip to content

Commit cbf262a

Browse files
committed
fix(application_commands): Fixed command checks not working
feat(application_commands): Added the ability to pass the values directly for slash command choices
1 parent e7ec61f commit cbf262a

File tree

1 file changed

+76
-48
lines changed

1 file changed

+76
-48
lines changed

discord/application_commands.py

Lines changed: 76 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,20 @@
2323
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
2424
DEALINGS IN THE SOFTWARE.
2525
"""
26+
from __future__ import annotations
27+
28+
from typing import (
29+
Union,
30+
Optional,
31+
List,
32+
Dict,
33+
Any,
34+
TYPE_CHECKING,
35+
Coroutine,
36+
Awaitable
37+
)
2638

39+
from typing_extensions import Literal
2740

2841
import re
2942
import copy
@@ -33,16 +46,14 @@
3346
from types import FunctionType
3447

3548
from .utils import async_all, find, get, snowflake_time
36-
from typing_extensions import Literal
3749
from .abc import GuildChannel
38-
from typing import Union, Optional, List, Dict, Any, TYPE_CHECKING, Coroutine, Awaitable
3950
from .enums import ApplicationCommandType, InteractionType, ChannelType, OptionType, Locale, try_enum
4051
from .permissions import Permissions
4152

4253
if TYPE_CHECKING:
43-
from .ext.commands import Cog, Converter
4454
from datetime import datetime
4555
from .guild import Guild
56+
from .ext.commands import Cog, Greedy, Converter
4657
from .interactions import BaseInteraction
4758

4859
__all__ = (
@@ -64,7 +75,7 @@
6475
api_docs = 'https://discord.com/developers/docs'
6576

6677

67-
# TODO: Add a (optional) feature for auto generated name_localizations & description_localizations by a translator
78+
# TODO: Add a (optional) feature for auto generated localizations by a translator
6879

6980
class Localizations:
7081
"""
@@ -160,7 +171,7 @@ def __getitem__(self, item) -> Optional[str]:
160171
raise KeyError(f'There is no locale value set for {locale.name}.')
161172

162173

163-
def __setitem__(self, key, value):
174+
def __setitem__(self, key, value) -> None:
164175
self.__languages_dict__[Locale[key].value] = value
165176

166177
def __bool__(self) -> bool:
@@ -174,10 +185,11 @@ def from_dict(cls, data: Dict[str, str]) -> 'Localizations':
174185
data = data or {}
175186
return cls(**{try_enum(Locale, key): value for key, value in data.items()})
176187

177-
def update(self, __m: 'Localizations'):
188+
def update(self, __m: 'Localizations') -> None:
189+
"""Similar to :meth:`dict.update`"""
178190
self.__languages_dict__.update(__m.__languages_dict__)
179191

180-
def from_target(self, target: Union['Guild', 'BaseInteraction'], *, default: Any = None):
192+
def from_target(self, target: Union[Guild, BaseInteraction], *, default: Any = None):
181193
"""
182194
Returns the value for the local of the object (if it's set), or :attr:`default`(:class:`None`)
183195
@@ -190,17 +202,16 @@ def from_target(self, target: Union['Guild', 'BaseInteraction'], *, default: Any
190202
The value or an object to return by default if there is no value for the locale of :attr:`target` set.
191203
Default to :class:`None` or :class:`~discord.Locale.english_US`/:class:`~discord.Locale.english_GB`
192204
193-
Return
194-
------
205+
Returns
206+
-------
195207
Union[:class:`str`, None]
196-
The value of the locale or :class:`None` if there is no value for the locale set.
208+
The value of the locale or :obj:`None` if there is no value for the locale set.
197209
198210
Raises
199211
------
200212
:exc:`TypeError`
201213
If :attr:`target` is of the wrong type.
202214
"""
203-
return_default = False
204215
if hasattr(target, 'preferred_locale'):
205216
try:
206217
return self[target.preferred_locale.value]
@@ -255,7 +266,7 @@ def __init__(self, type: int, *args, **kwargs):
255266
self.name_localizations: Localizations = kwargs.get('name_localizations', Localizations())
256267
self.description_localizations: Localizations = kwargs.get('description_localizations', Localizations())
257268

258-
def __getitem__(self, item):
269+
def __getitem__(self, item) -> Any:
259270
return getattr(self, item)
260271

261272
@property
@@ -268,22 +279,23 @@ def _state(self, value):
268279

269280
@property
270281
def cog(self) -> Optional['Cog']:
282+
"""Optional[:class:`~discord.ext.commands.Cog`]: The cog associated with this command if any."""
271283
return getattr(self, '_cog', None)
272284

273285
@cog.setter
274-
def cog(self, __cog: 'Cog'):
286+
def cog(self, __cog: 'Cog') -> None:
275287
setattr(self, '_cog', __cog)
276288

277-
def _set_cog(self, cog: 'Cog', recursive: bool = False):
289+
def _set_cog(self, cog: 'Cog', recursive: bool = False) -> None:
278290
self.cog = cog
279291

280292
def __call__(self, *args, **kwargs):
281293
return super().__init__(self, *args, **kwargs)
282294

283-
def __repr__(self):
295+
def __repr__(self) -> str:
284296
return '<%s name=%s, id=%s, disabled=%s>' % (self.__class__.__name__, self.name, self.id, self.disabled)
285297

286-
def __eq__(self, other):
298+
def __eq__(self, other) -> bool:
287299
if isinstance(other, self.__class__):
288300
other = other.to_dict()
289301
if isinstance(other, dict):
@@ -335,19 +347,19 @@ def check_options(_options: list, _other: list):
335347
and check_options(options, other.get('options', [])))
336348
return False
337349

338-
def __ne__(self, other):
350+
def __ne__(self, other) -> bool:
339351
return not self.__eq__(other)
340352

341-
def _fill_data(self, data):
353+
def _fill_data(self, data) -> ApplicationCommand:
342354
self._id = int(data.get('id', 0))
343355
self.application_id = int(data.get('application_id', 0))
344356
self._guild_id = int(data.get('guild_id', 0))
345357
self._permissions = data.get('permissions', {})
346358
return self
347359

348-
async def can_run(self, *args, **kwargs):
360+
async def can_run(self, *args, **kwargs) -> bool:
349361
check_func = kwargs.pop('__func', self)
350-
checks = getattr(check_func, '__command_checks__', getattr(self.func, '__command_checks__', None))
362+
checks = getattr(check_func, '__commands_checks__', getattr(self.func, '__commands_checks__', None))
351363
if not checks:
352364
return True
353365

@@ -370,13 +382,14 @@ async def invoke(self, interaction, *args, **kwargs):
370382
else:
371383
self._state.dispatch('application_command_error', self, interaction, exc)
372384

373-
def error(self, coro):
385+
def error(self, coro) -> Coroutine:
386+
"""A decorator to set an error handler for this command similar to :func:`on_application_command_error` but only for this command"""
374387
if not asyncio.iscoroutinefunction(coro):
375388
raise TypeError('The error handler must be a coroutine.')
376389
self.on_error = coro
377390
return coro
378391

379-
def to_dict(self):
392+
def to_dict(self) -> dict:
380393
base = {
381394
'type': int(self.type),
382395
'name': str(self.name),
@@ -394,22 +407,24 @@ def to_dict(self):
394407
return base
395408

396409
@property
397-
def id(self):
410+
def id(self) -> Optional[int]:
398411
"""Optional[:class:`int`]: The id of the command, only set if the bot is running"""
399412
return getattr(self, '_id', None)
400413

401414
@property
402-
def created_at(self) -> Optional['datetime']:
415+
def created_at(self) -> Optional[datetime]:
403416
"""Optional[:class:`datetime.datetime`]: The creation time of the command in UTC, only set if the bot is running"""
404417
if self.id:
405418
return snowflake_time(self.id)
406419

407420
@property
408421
def type(self) -> ApplicationCommandType:
422+
""":class:`ApplicationCommandType`: The type of the command"""
409423
return try_enum(ApplicationCommandType, self._type)
410424

411425
@property
412-
def guild_id(self):
426+
def guild_id(self) -> Optional[int]:
427+
"""Optional[:class:`int`]: Th id this command belongs to, if any"""
413428
return self._guild_id
414429

415430
@property
@@ -443,7 +458,7 @@ def _sorted_by_type(cls, commands):
443458
sorted_dict[predicate].append(cmd)
444459
return sorted_dict
445460

446-
async def delete(self):
461+
async def delete(self) -> None:
447462
"""|coro|
448463
449464
Deletes the application command
@@ -463,19 +478,23 @@ class SlashCommandOptionChoice:
463478
464479
Parameters
465480
-----------
466-
name: :class:`str`
481+
name: Union[:class:`str`, :class:`int`, :class:`float`]
467482
The 1-100 characters long name that will show in the client.
468-
value: Union[:class:`str`, :class:`int`, :class:`float`]
483+
value: Union[:class:`str`, :class:`int`, :class:`float`, :obj:`None`]
469484
The value that will send as the options value.
470-
Must be of the type the option is of (:class:`str`, :class:`int` or :class:`float`).
485+
Must be of the type the :class:`SlashCommandOption` is of (:class:`str`, :class:`int` or :class:`float`).
486+
487+
.. note::
488+
If this is left empty it takes the :attr:`~SlashCommandOption.name` as value.
489+
471490
name_localizations: Optional[:class:`Localizations`]
472491
Localized names for the choice.
473492
"""
474-
def __init__(self, name: str, value: Union[str, int, float] = None, name_localizations: Optional[Localizations] = Localizations()):
493+
def __init__(self, name: Union[str, int, float], value: Union[str, int, float] = None, name_localizations: Optional[Localizations] = Localizations()):
475494

476-
if 100 < len(name) < 1:
495+
if 100 < len(str(name)) < 1:
477496
raise ValueError('The name of a choice must bee between 1 and 100 characters long, got %s.' % len(name))
478-
self.name = name
497+
self.name = str(name)
479498
self.value = value if value is not None else name
480499
self.name_localizations = name_localizations
481500

@@ -519,19 +538,24 @@ class SlashCommandOption:
519538
required: Optional[:class:`bool`]
520539
Weather this option must be provided by the user, default ``True``.
521540
If ``False``, the parameter of the slash-command that takes this option needs a default value.
522-
choices: Optional[List[:class:`SlashCommandOptionChoice`]]
541+
choices: Optional[List[Union[:class:`SlashCommandOptionChoice`, :class:`str`, :class:`int`, :class:`float`]]]
523542
A list of up to 25 choices the user could select. Only valid if the :attr:`option_type` is one of
524-
:class:`OptionType.string`, :class:`OptionType.integer` or :class:`OptionType.number`.
525-
The :attr:`value`'s of the choices must be of the :attr:`~SlashCommandOption.option_type` of this option
543+
:attr:`~OptionType.string`, :attr:`~OptionType.integer` or :attr:`~OptionType.number`.
544+
545+
.. note::
546+
If you want to have values that are not the same as their name, you can use :class:`SlashCommandOptionChoice`
547+
548+
The :attr:`~SlashCommandOptionChoice.value`'s of the choices must be of the :attr:`~SlashCommandOption.option_type` of this option
526549
(e.g. :class:`str`, :class:`int` or :class:`float`).
527550
If choices are set they are the only options a user could pass.
528551
autocomplete: Optional[:class:`bool`]
529552
Whether to enable
530553
`autocomplete <https://discord.com/developers/docs/interactions/application-commands#autocomplete>`_
531554
interactions for this option, default ``False``.
532555
With autocomplete, you can check the user's input and send matching choices to the client.
533-
**Autocomplete can only be used with options of the type** ``string``, ``integer`` or ``number``.
534-
**If autocomplete is activated, the option cannot have** :attr:`~SlashCommandOption.choices` **.**
556+
.. note::
557+
Autocomplete can only be used with options of the type :attr:`~OptionType.string`, :attr:`~OptionType.integer` or :attr:`~OptionType.number`.
558+
**If autocomplete is activated, the option cannot have** :attr:`~SlashCommandOption.choices` **.**
535559
min_value: Optional[Union[:class:`int`, :class:`float`]]
536560
If the :attr:`~SlashCommandOption.option_type` is one of :attr:`~OptionType.integer` or :attr:`~OptionType.number`
537561
this is the minimum value the users input must be of.
@@ -558,7 +582,7 @@ def __init__(self,
558582
name_localizations: Optional[Localizations] = Localizations(),
559583
description_localizations: Optional[Localizations] = Localizations(),
560584
required: bool = True,
561-
choices: Optional[List[SlashCommandOptionChoice]] = [],
585+
choices: Optional[List[Union[SlashCommandOptionChoice, str, int, float]]] = [],
562586
autocomplete: bool = False,
563587
min_value: Optional[Union[int, float]] = None,
564588
max_value: Optional[Union[int, float]] = None,
@@ -585,27 +609,30 @@ def __init__(self,
585609
f'{api_docs}/interactions/application-commands#application-command-object-application-command-naming.'
586610
f'Got "{name}" with length {len(name)}.'
587611
)
588-
self.name = name
612+
self.name: str = name
589613
self.name_localizations: Localizations = name_localizations
590614
if 100 < len(description) < 1:
591615
raise ValueError('The description must be between 1 and 100 characters long, got %s.' % len(description))
592-
self.description = description
616+
self.description: str = description
593617
self.description_localizations: Localizations = description_localizations
594-
self.required = required
618+
self.required: bool = required
595619
options = kwargs.get('__options', [])
596620
if self.type == 2 and (not options):
597621
raise ValueError('You need to pass __options if the option_type is subcommand_group.')
598622
self._options = options
599623
self.autocomplete: bool = autocomplete
600624
self.min_value: Optional[Union[int, float]] = min_value
601625
self.max_value: Optional[Union[int, float]] = max_value
602-
self.choices: Optional[List[SlashCommandOptionChoice]] = choices
626+
for index, choice in enumerate(choices): # TODO: find a more efficient way to do this
627+
if not isinstance(choice, SlashCommandOptionChoice):
628+
choices[index] = SlashCommandOptionChoice(choice)
629+
self.choices: List[SlashCommandOptionChoice] = choices
603630
self.channel_types: Optional[List[Union[GuildChannel, ChannelType, int]]] = channel_types
604-
self.default = default
605-
self.converter = converter
606-
self.ignore_conversion_failures = ignore_conversion_failures
631+
self.default: Any = default
632+
self.converter: Union[Greedy, Converter] = converter
633+
self.ignore_conversion_failures: bool = ignore_conversion_failures
607634

608-
def __repr__(self):
635+
def __repr__(self) -> str:
609636
return '<SlashCommandOption type=%s, name=%s, description=%s, required=%s, choices=%s>'\
610637
% (self.type,
611638
self.name,
@@ -744,7 +771,7 @@ def to_dict(self) -> dict:
744771
return base
745772

746773
@classmethod
747-
def from_dict(cls, data):
774+
def from_dict(cls, data) -> SlashCommandOption:
748775
option_type: OptionType = try_enum(OptionType, data['type'])
749776
if option_type.sub_command_group:
750777
return SubCommandGroup.from_dict(data)
@@ -815,7 +842,8 @@ def to_dict(self):
815842
async def can_run(self, *args, **kwargs):
816843
if self.cog is not None:
817844
args = (self.cog, *args)
818-
checks = getattr(self, '__command_checks__', [])
845+
check_func = kwargs.pop('__func', self)
846+
checks = getattr(check_func, '__commands_checks__', getattr(self.func, '__commands_checks__', None))
819847
if not checks:
820848
return True
821849

0 commit comments

Comments
 (0)