Skip to content

Commit e321fa3

Browse files
EmmmaTechBobDotComLulalaby
committed
feat(bot.py): fix type issues (#1534)
Co-authored-by: BobDotCom <[email protected]> Co-authored-by: Lala Sabathil <[email protected]>
1 parent 38ba10a commit e321fa3

File tree

2 files changed

+96
-81
lines changed

2 files changed

+96
-81
lines changed

discord/bot.py

Lines changed: 93 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import asyncio
2929
import collections
30+
import collections.abc
3031
import copy
3132
import inspect
3233
import logging
@@ -41,6 +42,7 @@
4142
Generator,
4243
List,
4344
Literal,
45+
Mapping,
4446
Optional,
4547
Type,
4648
TypeVar,
@@ -109,7 +111,7 @@ def pending_application_commands(self):
109111
def commands(self) -> List[Union[ApplicationCommand, Any]]:
110112
commands = self.application_commands
111113
if self._bot._supports_prefixed_commands and hasattr(self._bot, "prefixed_commands"):
112-
commands += self._bot.prefixed_commands
114+
commands += getattr(self._bot, "prefixed_commands")
113115
return commands
114116

115117
@property
@@ -217,7 +219,7 @@ def get_application_command(
217219
async def get_desynced_commands(
218220
self,
219221
guild_id: Optional[int] = None,
220-
prefetched: Optional[List[ApplicationCommand]] = None
222+
prefetched: Optional[List[interactions.ApplicationCommand]] = None
221223
) -> List[Dict[str, Any]]:
222224
"""|coro|
223225
@@ -248,7 +250,7 @@ async def get_desynced_commands(
248250

249251
# We can suggest the user to upsert, edit, delete, or bulk upsert the commands
250252

251-
def _check_command(cmd: ApplicationCommand, match: Dict) -> bool:
253+
def _check_command(cmd: ApplicationCommand, match: Mapping[str, Any]) -> bool:
252254
if isinstance(cmd, SlashCommandGroup):
253255
if len(cmd.subcommands) != len(match.get("options", [])):
254256
return True
@@ -300,24 +302,25 @@ def _check_command(cmd: ApplicationCommand, match: Dict) -> bool:
300302
# TODO: Remove for perms v2
301303
continue
302304
return True
303-
return False
305+
return False
304306

305307
return_value = []
306308
cmds = self.pending_application_commands.copy()
307309

308310
if guild_id is None:
309-
if prefetched is not None:
310-
registered_commands = prefetched
311-
else:
312-
registered_commands = await self._bot.http.get_global_commands(self.user.id)
313311
pending = [cmd for cmd in cmds if cmd.guild_ids is None]
314312
else:
315-
if prefetched is not None:
316-
registered_commands = prefetched
317-
else:
318-
registered_commands = await self._bot.http.get_guild_commands(self.user.id, guild_id)
319313
pending = [cmd for cmd in cmds if cmd.guild_ids is not None and guild_id in cmd.guild_ids]
320314

315+
registered_commands: List[interactions.ApplicationCommand] = []
316+
if prefetched is not None:
317+
registered_commands = prefetched
318+
elif self._bot.user:
319+
if guild_id is None:
320+
registered_commands = await self._bot.http.get_global_commands(self._bot.user.id)
321+
else:
322+
registered_commands = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
323+
321324
registered_commands_dict = {cmd["name"]: cmd for cmd in registered_commands}
322325
# First let's check if the commands we have locally are the same as the ones on discord
323326
for cmd in pending:
@@ -358,7 +361,7 @@ async def register_command(
358361
self,
359362
command: ApplicationCommand,
360363
force: bool = True,
361-
guild_ids: List[int] = None,
364+
guild_ids: Optional[List[int]] = None,
362365
) -> None:
363366
"""|coro|
364367
@@ -382,7 +385,7 @@ async def register_command(
382385
The command that was registered
383386
"""
384387
# TODO: Write this
385-
raise RuntimeError("This function has not been implemented yet")
388+
raise NotImplementedError
386389

387390
async def register_commands(
388391
self,
@@ -439,7 +442,7 @@ async def register_commands(
439442
}
440443

441444
def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs):
442-
return registration_methods[method](self._bot.user.id, *args, **kwargs)
445+
return registration_methods[method](self._bot.user and self._bot.user.id, *args, **kwargs)
443446

444447
else:
445448
pending = list(
@@ -456,27 +459,30 @@ def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwar
456459
}
457460

458461
def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs):
459-
return registration_methods[method](self._bot.user.id, guild_id, *args, **kwargs)
462+
return registration_methods[method](self._bot.user and self._bot.user.id, guild_id, *args, **kwargs)
460463

461464
def register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs):
462465
if kwargs.pop("_log", True):
463466
if method == "bulk":
464467
_log.debug(f"Bulk updating commands {[c['name'] for c in args[0]]} for guild {guild_id}")
468+
# TODO: Find where "cmd" is defined
465469
elif method == "upsert":
466-
_log.debug(f"Creating command {cmd['name']} for guild {guild_id}")
470+
_log.debug(f"Creating command {cmd['name']} for guild {guild_id}") # type: ignore
467471
elif method == "edit":
468-
_log.debug(f"Editing command {cmd['name']} for guild {guild_id}")
472+
_log.debug(f"Editing command {cmd['name']} for guild {guild_id}") # type: ignore
469473
elif method == "delete":
470-
_log.debug(f"Deleting command {cmd['name']} for guild {guild_id}")
474+
_log.debug(f"Deleting command {cmd['name']} for guild {guild_id}") # type: ignore
471475
return _register(method, *args, **kwargs)
472476

473477
pending_actions = []
474478

475479
if not force:
476-
if guild_id is None:
477-
prefetched_commands = await self.http.get_global_commands(self.user.id)
478-
else:
479-
prefetched_commands = await self.http.get_guild_commands(self.user.id, guild_id)
480+
prefetched_commands: List[interactions.ApplicationCommand] = []
481+
if self._bot.user:
482+
if guild_id is None:
483+
prefetched_commands = await self._bot.http.get_global_commands(self._bot.user.id)
484+
else:
485+
prefetched_commands = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
480486
desynced = await self.get_desynced_commands(guild_id=guild_id, prefetched=prefetched_commands)
481487

482488
for cmd in desynced:
@@ -549,10 +555,11 @@ def register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwarg
549555

550556
# TODO: Our lists dont work sometimes, see if that can be fixed so we can avoid this second API call
551557
if method != "bulk":
552-
if guild_id is None:
553-
registered = await self._bot.http.get_global_commands(self._bot.user.id)
554-
else:
555-
registered = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
558+
if self._bot.user:
559+
if guild_id is None:
560+
registered = await self._bot.http.get_global_commands(self._bot.user.id)
561+
else:
562+
registered = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
556563
else:
557564
data = [cmd.to_dict() for cmd in pending]
558565
registered = await register("bulk", data)
@@ -561,10 +568,10 @@ def register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwarg
561568
cmd = get(
562569
self.pending_application_commands,
563570
name=i["name"],
564-
type=i["type"],
571+
type=i.get("type"),
565572
)
566573
if not cmd:
567-
raise ValueError(f"Registered command {i['name']}, type {i['type']} not found in pending commands")
574+
raise ValueError(f"Registered command {i['name']}, type {i.get('type')} not found in pending commands")
568575
cmd.id = i["id"]
569576
self._application_commands[cmd.id] = cmd
570577

@@ -622,7 +629,7 @@ async def sync_commands(
622629
Whether to delete existing commands that are not in the list of commands to register. Defaults to True.
623630
"""
624631

625-
check_guilds = list(set((check_guilds or []) + (self.debug_guilds or [])))
632+
check_guilds = list(set((check_guilds or []) + (self._bot.debug_guilds or [])))
626633

627634
if commands is None:
628635
commands = self.pending_application_commands
@@ -636,48 +643,51 @@ async def sync_commands(
636643
global_commands, method=method, force=force, delete_existing=delete_existing
637644
)
638645

639-
registered_guild_commands = {}
646+
registered_guild_commands: Dict[int, List[interactions.ApplicationCommand]] = {}
640647

641648
if register_guild_commands:
642-
cmd_guild_ids = []
649+
cmd_guild_ids: List[int] = []
643650
for cmd in commands:
644651
if cmd.guild_ids is not None:
645652
cmd_guild_ids.extend(cmd.guild_ids)
646653
if check_guilds is not None:
647654
cmd_guild_ids.extend(check_guilds)
648655
for guild_id in set(cmd_guild_ids):
649656
guild_commands = [cmd for cmd in commands if cmd.guild_ids is not None and guild_id in cmd.guild_ids]
650-
registered_guild_commands[guild_id] = await self.register_commands(
657+
app_cmds = await self.register_commands(
651658
guild_commands, guild_id=guild_id, method=method, force=force, delete_existing=delete_existing
652659
)
660+
registered_guild_commands[guild_id] = app_cmds
653661

654662
for i in registered_commands:
655663
cmd = get(
656664
self.pending_application_commands,
657665
name=i["name"],
658666
guild_ids=None,
659-
type=i["type"],
667+
type=i.get("type"),
660668
)
661669
if cmd:
662670
cmd.id = i["id"]
663671
self._application_commands[cmd.id] = cmd
664672

665-
for guild_id, commands in registered_guild_commands.items():
666-
for i in commands:
667-
cmd = find(
668-
lambda cmd: cmd.name == i["name"]
669-
and cmd.type == i["type"]
670-
and cmd.guild_ids is not None
671-
and int(i["guild_id"]) in cmd.guild_ids,
672-
self.pending_application_commands,
673-
)
674-
if not cmd:
675-
# command has not been added yet
676-
continue
677-
cmd.id = i["id"]
678-
self._application_commands[cmd.id] = cmd
673+
if register_guild_commands and registered_guild_commands:
674+
for guild_id, guild_cmds in registered_guild_commands.items():
675+
for i in guild_cmds:
676+
cmd = find(
677+
lambda cmd: cmd.name == i["name"]
678+
and cmd.type == i.get("type")
679+
and cmd.guild_ids is not None
680+
# TODO: fix this type error (guild_id is not defined in ApplicationCommand Typed Dict)
681+
and int(i["guild_id"]) in cmd.guild_ids, # type: ignore
682+
self.pending_application_commands,
683+
)
684+
if not cmd:
685+
# command has not been added yet
686+
continue
687+
cmd.id = i["id"]
688+
self._application_commands[cmd.id] = cmd
679689

680-
async def process_application_commands(self, interaction: Interaction, auto_sync: bool = None) -> None:
690+
async def process_application_commands(self, interaction: Interaction, auto_sync: Optional[bool] = None) -> None:
681691
"""|coro|
682692
683693
This function processes the commands that have been registered
@@ -698,33 +708,37 @@ async def process_application_commands(self, interaction: Interaction, auto_sync
698708
-----------
699709
interaction: :class:`discord.Interaction`
700710
The interaction to process
701-
auto_sync: :class:`bool`
711+
auto_sync: Optional[:class:`bool`]
702712
Whether to automatically sync and unregister the command if it is not found in the internal cache. This will
703713
invoke the :meth:`~.Bot.sync_commands` method on the context of the command, either globally or per-guild,
704714
based on the type of the command, respectively. Defaults to :attr:`.Bot.auto_sync_commands`.
705715
"""
706716
if auto_sync is None:
707717
auto_sync = self._bot.auto_sync_commands
718+
# TODO: find out why the isinstance check below doesn't stop the type errors below
708719
if interaction.type not in (
709720
InteractionType.application_command,
710721
InteractionType.auto_complete,
711-
):
722+
) and isinstance(interaction.data, interactions.ComponentInteractionData):
712723
return
713724

725+
command: Optional[ApplicationCommand] = None
714726
try:
715-
command = self._application_commands[interaction.data["id"]]
727+
if interaction.data:
728+
command = self._application_commands[interaction.data["id"]] # type: ignore
716729
except KeyError:
717730
for cmd in self.application_commands + self.pending_application_commands:
718-
guild_id = interaction.data.get("guild_id")
719-
if guild_id:
720-
guild_id = int(guild_id)
721-
if cmd.name == interaction.data["name"] and (
722-
guild_id == cmd.guild_ids or (isinstance(cmd.guild_ids, list) and guild_id in cmd.guild_ids)
723-
):
724-
command = cmd
725-
break
731+
if interaction.data:
732+
guild_id = interaction.data.get("guild_id")
733+
if guild_id:
734+
guild_id = int(guild_id)
735+
if cmd.name == interaction.data["name"] and ( # type: ignore
736+
guild_id == cmd.guild_ids or (isinstance(cmd.guild_ids, list) and guild_id in cmd.guild_ids)
737+
):
738+
command = cmd
739+
break
726740
else:
727-
if auto_sync:
741+
if auto_sync and interaction.data:
728742
guild_id = interaction.data.get("guild_id")
729743
if guild_id is None:
730744
await self.sync_commands()
@@ -734,26 +748,28 @@ async def process_application_commands(self, interaction: Interaction, auto_sync
734748
return self._bot.dispatch("unknown_application_command", interaction)
735749

736750
if interaction.type is InteractionType.auto_complete:
737-
return self.dispatch("application_command_auto_complete", interaction, command)
751+
return self._bot.dispatch("application_command_auto_complete", interaction, command)
738752

739753
ctx = await self.get_application_context(interaction)
740-
ctx.command = command
754+
if command:
755+
ctx.command = command
741756
await self.invoke_application_command(ctx)
742757

743758
async def on_application_command_auto_complete(self, interaction: Interaction, command: ApplicationCommand) -> None:
744-
async def callback() -> None:
745-
ctx = await self.get_autocomplete_context(interaction)
746-
ctx.command = command
747-
return await command.invoke_autocomplete_callback(ctx)
759+
if isinstance(command, SlashCommand):
760+
async def callback() -> None:
761+
ctx = await self.get_autocomplete_context(interaction)
762+
ctx.command = command
763+
return await command.invoke_autocomplete_callback(ctx)
748764

749-
autocomplete_task = self.loop.create_task(callback())
750-
try:
751-
await self.wait_for("application_command_auto_complete", check=lambda i, c: c == command, timeout=3)
752-
except asyncio.TimeoutError:
753-
return
754-
else:
755-
if not autocomplete_task.done():
756-
autocomplete_task.cancel()
765+
autocomplete_task = self._bot.loop.create_task(callback())
766+
try:
767+
await self._bot.wait_for("application_command_auto_complete", check=lambda i, c: c == command, timeout=3)
768+
except asyncio.TimeoutError:
769+
return
770+
else:
771+
if not autocomplete_task.done():
772+
autocomplete_task.cancel()
757773

758774
def slash_command(self, **kwargs):
759775
"""A shortcut decorator that invokes :func:`command` and adds it to
@@ -924,7 +940,7 @@ def walk_application_commands(self) -> Generator[ApplicationCommand, None, None]
924940
yield from command.walk_commands()
925941
yield command
926942

927-
async def get_application_context(self, interaction: Interaction, cls=None) -> ApplicationContext:
943+
async def get_application_context(self, interaction: Interaction, cls: Any = ApplicationContext) -> ApplicationContext:
928944
r"""|coro|
929945
930946
Returns the invocation context from the interaction.
@@ -948,11 +964,9 @@ class be provided, it must be similar enough to
948964
The invocation context. The type of this can change via the
949965
``cls`` parameter.
950966
"""
951-
if cls is None:
952-
cls = ApplicationContext
953967
return cls(self, interaction)
954968

955-
async def get_autocomplete_context(self, interaction: Interaction, cls=None) -> AutocompleteContext:
969+
async def get_autocomplete_context(self, interaction: Interaction, cls: Any = AutocompleteContext) -> AutocompleteContext:
956970
r"""|coro|
957971
958972
Returns the autocomplete context from the interaction.
@@ -976,8 +990,6 @@ class be provided, it must be similar enough to
976990
The autocomplete context. The type of this can change via the
977991
``cls`` parameter.
978992
"""
979-
if cls is None:
980-
cls = AutocompleteContext
981993
return cls(self, interaction)
982994

983995
async def invoke_application_command(self, ctx: ApplicationContext) -> None:

discord/commands/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,9 @@ def qualified_name(self) -> str:
565565
else:
566566
return self.name
567567

568+
def to_dict(self) -> Dict[str, Any]:
569+
raise NotImplementedError
570+
568571
def __str__(self) -> str:
569572
return self.qualified_name
570573

0 commit comments

Comments
 (0)