Skip to content

Commit 3b1a281

Browse files
authored
Fix required parameters validation error (#1589)
1 parent bb46862 commit 3b1a281

File tree

3 files changed

+30
-14
lines changed

3 files changed

+30
-14
lines changed

discord/bot.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def add_application_command(self, command: ApplicationCommand) -> None:
132132
if isinstance(command, SlashCommand) and command.is_subcommand:
133133
raise TypeError("The provided command is a sub-command of group")
134134

135+
if command.cog is MISSING:
136+
command.cog = None
137+
135138
if self._bot.debug_guilds and command.guild_ids is None:
136139
command.guild_ids = self._bot.debug_guilds
137140

discord/cog.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,9 @@ def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
241241

242242
# Update the Command instances dynamically as well
243243
for command in new_cls.__cog_commands__:
244-
if (
245-
isinstance(command, ApplicationCommand)
246-
and command.guild_ids is None
247-
and len(new_cls.__cog_guild_ids__) != 0
248-
):
244+
if isinstance(command, ApplicationCommand) and not command.guild_ids and new_cls.__cog_guild_ids__:
249245
command.guild_ids = new_cls.__cog_guild_ids__
246+
250247
if not isinstance(command, SlashCommandGroup):
251248
setattr(new_cls, command.callback.__name__, command)
252249
parent = command.parent

discord/commands/core.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
from ..role import Role
6464
from ..threads import Thread
6565
from ..user import User
66-
from ..utils import async_all, find, utcnow, maybe_coroutine
66+
from ..utils import async_all, find, utcnow, maybe_coroutine, MISSING
6767
from .context import ApplicationContext, AutocompleteContext
6868
from .options import Option, OptionChoice
6969

@@ -186,6 +186,7 @@ def __init__(self, func: Callable, **kwargs) -> None:
186186
buckets = cooldown
187187
else:
188188
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
189+
189190
self._buckets: CooldownMapping = buckets
190191

191192
max_concurrency = getattr(func, "__commands_max_concurrency__", kwargs.get("max_concurrency"))
@@ -646,13 +647,7 @@ def __init__(self, func: Callable, *args, **kwargs) -> None:
646647

647648
self.attached_to_group: bool = False
648649

649-
self.cog = None
650-
651-
params = self._get_signature_parameters()
652-
if kwop := kwargs.get("options", None):
653-
self.options: List[Option] = self._match_option_param_names(params, kwop)
654-
else:
655-
self.options: List[Option] = self._parse_options(params)
650+
self.options: List[Option] = kwargs.get("options", [])
656651

657652
try:
658653
checks = func.__commands_checks__
@@ -665,13 +660,21 @@ def __init__(self, func: Callable, *args, **kwargs) -> None:
665660
self._before_invoke = None
666661
self._after_invoke = None
667662

663+
self._cog = MISSING
664+
665+
def _validate_parameters(self):
666+
params = self._get_signature_parameters()
667+
if kwop := self.options:
668+
self.options: List[Option] = self._match_option_param_names(params, kwop)
669+
else:
670+
self.options: List[Option] = self._parse_options(params)
671+
668672
def _check_required_params(self, params):
669673
params = iter(params.items())
670674
required_params = (
671675
["self", "context"]
672676
if self.attached_to_group
673677
or self.cog
674-
or len(self.callback.__qualname__.split(".")) > 1
675678
else ["context"]
676679
)
677680
for p in required_params:
@@ -764,6 +767,15 @@ def _is_typing_union(self, annotation):
764767
def _is_typing_optional(self, annotation):
765768
return self._is_typing_union(annotation) and type(None) in annotation.__args__ # type: ignore
766769

770+
@property
771+
def cog(self):
772+
return self._cog
773+
774+
@cog.setter
775+
def cog(self, val):
776+
self._cog = val
777+
self._validate_parameters()
778+
767779
@property
768780
def is_subcommand(self) -> bool:
769781
return self.parent is not None
@@ -956,6 +968,10 @@ def _update_copy(self, kwargs: Dict[str, Any]):
956968
else:
957969
return self.copy()
958970

971+
def _set_cog(self, cog):
972+
super()._set_cog(cog)
973+
self._validate_parameters()
974+
959975

960976
class SlashCommandGroup(ApplicationCommand):
961977
r"""A class that implements the protocol for a slash command group.

0 commit comments

Comments
 (0)