Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 118 additions & 3 deletions discord/app_commands/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@

if TYPE_CHECKING:
from ..types.interactions import ApplicationCommandInteractionData, ApplicationCommandInteractionDataOption
from ..types.command import ApplicationCommand
from ..interactions import Interaction
from ..abc import Snowflake
from .commands import ContextMenuCallback, CommandCallback, P, T
Expand Down Expand Up @@ -161,6 +162,9 @@ def __init__(
# it's uncommon and N=5 anyway.
self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {}

self._global_command_ids: Dict[str, int] = {}
self._guild_command_ids: Dict[Tuple[str, int], int] = {}

async def fetch_command(self, command_id: int, /, *, guild: Optional[Snowflake] = None) -> AppCommand:
"""|coro|

Expand Down Expand Up @@ -198,7 +202,10 @@ async def fetch_command(self, command_id: int, /, *, guild: Optional[Snowflake]
else:
command = await self._http.get_guild_command(self.client.application_id, guild.id, command_id)

return AppCommand(data=command, state=self._state)
res = AppCommand(data=command, state=self._state)
# self._store_command_id((res, res.id))
self._store_command_from_data(command)
return res

async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
"""|coro|
Expand Down Expand Up @@ -238,7 +245,77 @@ async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[App
else:
commands = await self._http.get_guild_commands(self.client.application_id, guild.id)

return [AppCommand(data=data, state=self._state) for data in commands]
res = [AppCommand(data=command, state=self._state) for command in commands]
# self._store_command_id(*((cmd, cmd.id) for cmd in res))
self._store_command_from_data(*commands)
return res

def get_command_id(
self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None
) -> Optional[int]:
"""Gets the command ID for a command.

Parameters
-----------
name: Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.ContextMenu`, :class:`~discord.app_commands.Group`, :class:`str`]
The name of the command to get the ID for.
guild: Optional[:class:`~discord.abc.Snowflake`]
The guild to get the command ID for. If not passed then the global command
ID is fetched instead.

Returns
--------
Optional[:class:`~discord.app_commands.CommandID`]
The command ID if found, otherwise ``None``.

.. note::

Group commands will return the ID of the root command. Subcommands do not have their own IDs.
"""
name: Optional[str] = None

if isinstance(command, AppCommand):
return command.id

if isinstance(command, (Command, Group, ContextMenu)):
name = (command.root_parent or command).name if not isinstance(command, ContextMenu) else command.name
elif isinstance(command, str):
name = command.split()[0]

return self._global_command_ids.get(name) if guild is None else self._guild_command_ids.get((name, guild.id))

def get_command_mention(
self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None
) -> str | None:
"""Gets the mention string for a command.

Parameters
-----------
command: Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.ContextMenu`, :class:`~discord.app_commands.Group`, :class:`str`]
The command to get the mention string for.

Returns
--------
Optional[:class:`str`]
The mention string for the command if found, otherwise ``None``.

.. note::

Remember that groups cannot be mentioned, only with a subcommand.
"""
if isinstance(command, AppCommand):
return command.mention

command_id = self.get_command_id(command, guild=guild)
if command_id is None:
return None

if isinstance(command, (Command, Group, ContextMenu)):
full_name = command.qualified_name
elif isinstance(command, str):
full_name = command

return f'</{full_name}:{command_id}>'

def copy_global_to(self, *, guild: Snowflake) -> None:
"""Copies all global commands to the specified guild.
Expand Down Expand Up @@ -1134,7 +1211,40 @@ async def sync(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
raise CommandSyncFailure(e, commands) from None
raise

return [AppCommand(data=d, state=self._state) for d in data]
res = [AppCommand(data=d, state=self._state) for d in data]
# self._store_command_id(*((cmd, cmd.id) for cmd in res))
self._store_command_from_data(*data)
return res

def _store_command_id(self, *commands: Tuple[AppCommand | ContextMenu | Command[Any, ..., Any] | Group, int]) -> None:
for command, command_id in commands:
if isinstance(command, AppCommand):
guild_id = command.guild_id
if guild_id is None:
self._global_command_ids[command.name] = command_id
else:
key = (command.name, guild_id)
self._guild_command_ids[key] = command_id
else:
guild_ids = command._guild_ids
name = (command.root_parent or command).name if not isinstance(command, ContextMenu) else command.name

if not guild_ids:
self._global_command_ids[name] = command_id
else:
for guild_id in guild_ids:
key = (name, guild_id)
self._guild_command_ids[key] = command_id

def _store_command_from_data(self, *data: ApplicationCommandInteractionData | ApplicationCommand) -> None:
for d in data:
command_id = int(d['id'])
name = d['name']
guild_id = _get_as_snowflake(d, 'guild_id')
if guild_id is None:
self._global_command_ids[name] = command_id
else:
self._guild_command_ids[(name, guild_id)] = command_id

async def _dispatch_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None:
command = interaction.command
Expand Down Expand Up @@ -1231,6 +1341,9 @@ async def _call_context_menu(
if ctx_menu is None:
raise CommandNotFound(name, [], AppCommandType(type))

# self._store_command_id((ctx_menu, int(data['id'])))
self._store_command_from_data(data)

resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {}))

# This is annotated as str | int but realistically this will always be str
Expand Down Expand Up @@ -1281,6 +1394,8 @@ async def _call(self, interaction: Interaction[ClientT]) -> None:
return

command, options = self._get_app_command_options(data)
# self._store_command_id((command, int(data['id'])))
self._store_command_from_data(data)

# Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command
Expand Down