Skip to content

Commit f314302

Browse files
authored
Fix extensions in discord.Bot (#838)
1 parent 3eeeda9 commit f314302

File tree

4 files changed

+103
-27
lines changed

4 files changed

+103
-27
lines changed

discord/bot.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ def __init__(self, *args, **kwargs) -> None:
8989
self._pending_application_commands = []
9090
self._application_commands = {}
9191

92+
@property
93+
def all_commands(self):
94+
return self._application_commands
95+
9296
@property
9397
def pending_application_commands(self):
9498
return self._pending_application_commands
@@ -149,7 +153,14 @@ def remove_application_command(
149153
The command that was removed. If the name is not valid then
150154
``None`` is returned instead.
151155
"""
152-
return self._application_commands.pop(command.id)
156+
if command.id is None:
157+
try:
158+
index = self._pending_application_commands.index(command)
159+
except ValueError:
160+
return None
161+
return self._pending_application_commands.pop(index)
162+
163+
return self._application_commands.pop(int(command.id), None)
153164

154165
@property
155166
def get_command(self):

discord/cog.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def get_commands(self) -> List[ApplicationCommand]:
251251
This does not include subcommands.
252252
"""
253253
return [c for c in self.__cog_commands__ if isinstance(c, ApplicationCommand) and c.parent is None]
254+
254255
@property
255256
def qualified_name(self) -> str:
256257
""":class:`str`: Returns the cog's specified name, not the class name."""
@@ -611,11 +612,17 @@ def _remove_module_references(self, name: str) -> None:
611612
self.remove_cog(cogname)
612613

613614
# remove all the commands from the module
614-
for cmd in self.all_commands.copy().values():
615+
if self._supports_prefixed_commands:
616+
for cmd in self.prefixed_commands.copy().values():
617+
if cmd.module is not None and _is_submodule(name, cmd.module):
618+
# if isinstance(cmd, GroupMixin):
619+
# cmd.recursively_remove_all_commands()
620+
self.remove_command(cmd.name)
621+
for cmd in self._application_commands.copy().values():
615622
if cmd.module is not None and _is_submodule(name, cmd.module):
616623
# if isinstance(cmd, GroupMixin):
617624
# cmd.recursively_remove_all_commands()
618-
self.remove_command(cmd.name)
625+
self.remove_application_command(cmd)
619626

620627
# remove all the listeners from the module
621628
for event_list in self.extra_events.copy().values():

discord/commands/core.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,22 @@
3232
import re
3333
import types
3434
from collections import OrderedDict
35-
from typing import Any, Callable, Dict, Generator, Generic, List, Optional, Type, TypeVar, Union, TYPE_CHECKING
35+
from typing import (
36+
Any,
37+
Callable,
38+
Dict,
39+
List,
40+
Optional,
41+
Union,
42+
TYPE_CHECKING,
43+
Awaitable,
44+
overload,
45+
TypeVar,
46+
Generic,
47+
Type,
48+
Generator,
49+
Coroutine,
50+
)
3651

3752
from .context import ApplicationContext, AutocompleteContext
3853
from .errors import ApplicationCommandError, CheckFailure, ApplicationCommandInvokeError
@@ -61,12 +76,13 @@
6176
)
6277

6378
if TYPE_CHECKING:
64-
from typing_extensions import ParamSpec
79+
from typing_extensions import ParamSpec, Concatenate
6580

6681
from ..cog import Cog
6782

6883
T = TypeVar('T')
69-
CogT = TypeVar('CogT', bound='Cog')
84+
CogT = TypeVar("CogT", bound="Cog")
85+
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
7086

7187
if TYPE_CHECKING:
7288
P = ParamSpec('P')
@@ -105,6 +121,16 @@ async def wrapped(arg):
105121
return ret
106122
return wrapped
107123

124+
def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
125+
partial = functools.partial
126+
while True:
127+
if hasattr(function, '__wrapped__'):
128+
function = function.__wrapped__
129+
elif isinstance(function, partial):
130+
function = function.func
131+
else:
132+
return function
133+
108134
class _BaseCommand:
109135
__slots__ = ()
110136

@@ -118,7 +144,7 @@ def __init__(self, func: Callable, **kwargs) -> None:
118144
cooldown = func.__commands_cooldown__
119145
except AttributeError:
120146
cooldown = kwargs.get('cooldown')
121-
147+
122148
if cooldown is None:
123149
buckets = CooldownMapping(cooldown, BucketType.default)
124150
elif isinstance(cooldown, CooldownMapping):
@@ -134,7 +160,10 @@ def __init__(self, func: Callable, **kwargs) -> None:
134160

135161
self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
136162

137-
def __repr__(self):
163+
self._callback = None
164+
self.module = None
165+
166+
def __repr__(self) -> str:
138167
return f"<discord.commands.{self.__class__.__name__} name={self.name}>"
139168

140169
def __eq__(self, other) -> bool:
@@ -161,6 +190,22 @@ async def __call__(self, ctx, *args, **kwargs):
161190
"""
162191
return await self.callback(ctx, *args, **kwargs)
163192

193+
@property
194+
def callback(self) -> Union[
195+
Callable[Concatenate[CogT, ApplicationContext, P], Coro[T]],
196+
Callable[Concatenate[ApplicationContext, P], Coro[T]],
197+
]:
198+
return self._callback
199+
200+
@callback.setter
201+
def callback(self, function: Union[
202+
Callable[Concatenate[CogT, ApplicationContext, P], Coro[T]],
203+
Callable[Concatenate[ApplicationContext, P], Coro[T]],
204+
]) -> None:
205+
self._callback = function
206+
unwrap = unwrap_function(function)
207+
self.module = unwrap.__module__
208+
164209
def _prepare_cooldowns(self, ctx: ApplicationContext):
165210
if self._buckets.valid:
166211
current = datetime.datetime.now().timestamp()
@@ -640,7 +685,7 @@ def _match_option_param_names(self, params, options):
640685
)
641686
p_obj = p_obj.annotation
642687

643-
if not any(c(o, p_obj) for c in check_annotations):
688+
if not any(c(o, p_obj) for c in check_annotations):
644689
raise TypeError(f"Parameter {p_name} does not match input type of {o.name}.")
645690
o._parameter_name = p_name
646691

@@ -743,7 +788,7 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext):
743788

744789
if asyncio.iscoroutinefunction(option.autocomplete):
745790
result = await result
746-
791+
747792
choices = [
748793
o if isinstance(o, OptionChoice) else OptionChoice(o)
749794
for o in result
@@ -863,13 +908,18 @@ def __init__(
863908
self._before_invoke = None
864909
self._after_invoke = None
865910
self.cog = None
911+
self.id = None
866912

867913
# Permissions
868914
self.default_permission = kwargs.get("default_permission", True)
869915
self.permissions: List[CommandPermission] = kwargs.get("permissions", [])
870916
if self.permissions and self.default_permission:
871917
self.default_permission = False
872918

919+
@property
920+
def module(self) -> Optional[str]:
921+
return self.__module__
922+
873923
def to_dict(self) -> Dict:
874924
as_dict = {
875925
"name": self.name,
@@ -989,7 +1039,7 @@ def _ensure_assignment_on_copy(self, other):
9891039

9901040
if self.subcommands != other.subcommands:
9911041
other.subcommands = self.subcommands.copy()
992-
1042+
9931043
if self.checks != other.checks:
9941044
other.checks = self.checks.copy()
9951045

@@ -1069,6 +1119,7 @@ def __init__(self, func: Callable, *args, **kwargs) -> None:
10691119
raise TypeError("Name of a command must be a string.")
10701120

10711121
self.cog = None
1122+
self.id = None
10721123

10731124
try:
10741125
checks = func.__commands_checks__
@@ -1189,7 +1240,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None:
11891240

11901241
if self.cog is not None:
11911242
await self.callback(self.cog, ctx, target)
1192-
else:
1243+
else:
11931244
await self.callback(ctx, target)
11941245

11951246
def copy(self):

discord/ext/commands/core.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,17 +1167,24 @@ class GroupMixin(Generic[CogT]):
11671167
"""
11681168
def __init__(self, *args: Any, **kwargs: Any) -> None:
11691169
case_insensitive = kwargs.get('case_insensitive', False)
1170-
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
1170+
self.prefixed_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
11711171
self.case_insensitive: bool = case_insensitive
11721172
super().__init__(*args, **kwargs)
11731173

1174+
@property
1175+
def all_commands(self):
1176+
# merge app and prefixed commands
1177+
if hasattr(self, "_application_commands"):
1178+
return {**self._application_commands, **self.prefixed_commands}
1179+
return self.prefixed_commands
1180+
11741181
@property
11751182
def commands(self) -> Set[Command[CogT, Any, Any]]:
11761183
"""Set[:class:`.Command`]: A unique set of commands without aliases that are registered."""
1177-
return set(self.all_commands.values())
1184+
return set(self.prefixed_commands.values())
11781185

11791186
def recursively_remove_all_commands(self) -> None:
1180-
for command in self.all_commands.copy().values():
1187+
for command in self.prefixed_commands.copy().values():
11811188
if isinstance(command, GroupMixin):
11821189
command.recursively_remove_all_commands()
11831190
self.remove_command(command.name)
@@ -1210,15 +1217,15 @@ def add_command(self, command: Command[CogT, Any, Any]) -> None:
12101217
if isinstance(self, Command):
12111218
command.parent = self
12121219

1213-
if command.name in self.all_commands:
1220+
if command.name in self.prefixed_commands:
12141221
raise CommandRegistrationError(command.name)
12151222

1216-
self.all_commands[command.name] = command
1223+
self.prefixed_commands[command.name] = command
12171224
for alias in command.aliases:
1218-
if alias in self.all_commands:
1225+
if alias in self.prefixed_commands:
12191226
self.remove_command(command.name)
12201227
raise CommandRegistrationError(alias, alias_conflict=True)
1221-
self.all_commands[alias] = command
1228+
self.prefixed_commands[alias] = command
12221229

12231230
def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:
12241231
"""Remove a :class:`.Command` from the internal list
@@ -1237,7 +1244,7 @@ def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:
12371244
The command that was removed. If the name is not valid then
12381245
``None`` is returned instead.
12391246
"""
1240-
command = self.all_commands.pop(name, None)
1247+
command = self.prefixed_commands.pop(name, None)
12411248

12421249
# does not exist
12431250
if command is None:
@@ -1249,12 +1256,12 @@ def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:
12491256

12501257
# we're not removing the alias so let's delete the rest of them.
12511258
for alias in command.aliases:
1252-
cmd = self.all_commands.pop(alias, None)
1259+
cmd = self.prefixed_commands.pop(alias, None)
12531260
# in the case of a CommandRegistrationError, an alias might conflict
12541261
# with an already existing command. If this is the case, we want to
12551262
# make sure the pre-existing command is not removed.
12561263
if cmd is not None and cmd != command:
1257-
self.all_commands[alias] = cmd
1264+
self.prefixed_commands[alias] = cmd
12581265
return command
12591266

12601267
def walk_commands(self) -> Generator[Command[CogT, Any, Any], None, None]:
@@ -1296,18 +1303,18 @@ def get_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:
12961303

12971304
# fast path, no space in name.
12981305
if ' ' not in name:
1299-
return self.all_commands.get(name)
1306+
return self.prefixed_commands.get(name)
13001307

13011308
names = name.split()
13021309
if not names:
13031310
return None
1304-
obj = self.all_commands.get(names[0])
1311+
obj = self.prefixed_commands.get(names[0])
13051312
if not isinstance(obj, GroupMixin):
13061313
return obj
13071314

13081315
for name in names[1:]:
13091316
try:
1310-
obj = obj.all_commands[name] # type: ignore
1317+
obj = obj.prefixed_commands[name] # type: ignore
13111318
except (AttributeError, KeyError):
13121319
return None
13131320

@@ -1463,7 +1470,7 @@ async def invoke(self, ctx: Context) -> None:
14631470

14641471
if trigger:
14651472
ctx.subcommand_passed = trigger
1466-
ctx.invoked_subcommand = self.all_commands.get(trigger, None)
1473+
ctx.invoked_subcommand = self.prefixed_commands.get(trigger, None)
14671474

14681475
if early_invoke:
14691476
injected = hooked_wrapped_callback(self, ctx, self.callback)
@@ -1497,7 +1504,7 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
14971504

14981505
if trigger:
14991506
ctx.subcommand_passed = trigger
1500-
ctx.invoked_subcommand = self.all_commands.get(trigger, None)
1507+
ctx.invoked_subcommand = self.prefixed_commands.get(trigger, None)
15011508

15021509
if early_invoke:
15031510
try:

0 commit comments

Comments
 (0)