diff --git a/discord/cog.py b/discord/cog.py index bf898d77da..5fbe6e61a4 100644 --- a/discord/cog.py +++ b/discord/cog.py @@ -31,7 +31,17 @@ import pathlib import sys import types -from typing import Any, Callable, ClassVar, Generator, Mapping, TypeVar, overload +from collections.abc import Generator, Mapping +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + TypeVar, + overload, +) + +from typing_extensions import TypeGuard import discord.utils @@ -43,6 +53,10 @@ _BaseCommand, ) +if TYPE_CHECKING: + from .ext.bridge import BridgeCommand + + __all__ = ( "CogMeta", "Cog", @@ -59,6 +73,118 @@ def _is_submodule(parent: str, child: str) -> bool: return parent == child or child.startswith(f"{parent}.") +def _is_bridge_command(command: Any) -> TypeGuard[BridgeCommand]: + return getattr(command, "__bridge__", False) + + +def _name_filter(c: Any) -> str: + return ( + "app" + if isinstance(c, ApplicationCommand) + else ("bridge" if not _is_bridge_command(c) else "ext") + ) + + +def _validate_name_prefix(base_class: type, name: str) -> None: + if name.startswith(("cog_", "bot_")): + raise TypeError( + f"Commands or listeners must not start with cog_ or bot_ (in method {base_class}.{name})" + ) + + +def _process_attributes( + base: type, +) -> tuple[dict[str, Any], dict[str, Any]]: # pyright: ignore[reportExplicitAny] + commands: dict[str, _BaseCommand | BridgeCommand] = {} + listeners: dict[str, Callable[..., Any]] = {} + + for attr_name, attr_value in base.__dict__.items(): + if attr_name in commands: + del commands[attr_name] + if attr_name in listeners: + del listeners[attr_name] + + if getattr(attr_value, "parent", None) and isinstance( + attr_value, ApplicationCommand + ): + # Skip application commands if they are a part of a group + # Since they are already added when the group is added + continue + + is_static_method = isinstance(attr_value, staticmethod) + if is_static_method: + attr_value = attr_value.__func__ + + if inspect.iscoroutinefunction(attr_value) and getattr( + attr_value, "__cog_listener__", False + ): + _validate_name_prefix(base, attr_name) + listeners[attr_name] = attr_value + continue + + if isinstance(attr_value, _BaseCommand) or _is_bridge_command(attr_value): + if is_static_method: + raise TypeError( + f"Command in method {base}.{attr_name!r} must not be staticmethod." + ) + _validate_name_prefix(base, attr_name) + + if isinstance(attr_value, _BaseCommand): + commands[attr_name] = attr_value + + if _is_bridge_command(attr_value) and not attr_value.parent: + commands[f"ext_{attr_name}"] = attr_value.ext_variant + commands[f"app_{attr_name}"] = attr_value.slash_variant + commands[attr_name] = attr_value + for cmd in getattr(attr_value, "subcommands", []): + commands[f"ext_{cmd.ext_variant.qualified_name}"] = cmd.ext_variant + + return commands, listeners + + +def _update_command( + command: _BaseCommand | BridgeCommand, + guild_ids: list[int], + lookup_table: dict[str, _BaseCommand | BridgeCommand], + new_cls: type[Cog], +) -> None: + if isinstance(command, ApplicationCommand) and not command.guild_ids and guild_ids: + command.guild_ids = guild_ids + + if not isinstance(command, SlashCommandGroup) and not _is_bridge_command(command): + # ignore bridge commands + cmd: BridgeCommand | _BaseCommand | None = getattr( + new_cls, + command.callback.__name__, + None, # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportAttributeAccessIssue] + ) + if _is_bridge_command(cmd): + setattr( + cmd, + f"{_name_filter(command).replace('app', 'slash')}_variant", + command, + ) + else: + setattr( + new_cls, + command.callback.__name__, + command, # pyright: ignore [reportAttributeAccessIssue, reportUnknownArgumentType, reportUnknownMemberType] + ) + + parent: ( + BridgeCommand | _BaseCommand | None + ) = ( # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType] + command.parent # pyright: ignore [reportAttributeAccessIssue] + ) + if parent is not None: + # Get the latest parent reference + parent = lookup_table[f"{_name_filter(command)}_{parent.qualified_name}"] # type: ignore # pyright: ignore[reportUnknownMemberType] + + # Update the parent's reference to our self + parent.remove_command(command.name) # type: ignore # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType] + parent.add_command(command) # type: ignore # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType] + + class CogMeta(type): """A metaclass for defining a cog. @@ -127,7 +253,7 @@ async def bar(self, ctx): __cog_name__: str __cog_settings__: dict[str, Any] - __cog_commands__: list[ApplicationCommand] + __cog_commands__: list[_BaseCommand | BridgeCommand] __cog_listeners__: list[tuple[str, str]] __cog_guild_ids__: list[int] @@ -142,128 +268,38 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: description = inspect.cleandoc(attrs.get("__doc__", "")) attrs["__cog_description__"] = description - commands = {} - listeners = {} - no_bot_cog = ( - "Commands or listeners must not start with cog_ or bot_ (in method" - " {0.__name__}.{1})" - ) + commands: dict[str, _BaseCommand | BridgeCommand] = {} + listeners: dict[str, Callable[..., Any]] = {} new_cls = super().__new__(cls, name, bases, attrs, **kwargs) for base in reversed(new_cls.__mro__): - for elem, value in base.__dict__.items(): - if elem in commands: - del commands[elem] - if elem in listeners: - del listeners[elem] - - if getattr(value, "parent", None) and isinstance( - value, ApplicationCommand - ): - # Skip commands if they are a part of a group - continue - - is_static_method = isinstance(value, staticmethod) - if is_static_method: - value = value.__func__ - if isinstance(value, _BaseCommand): - if is_static_method: - raise TypeError( - f"Command in method {base}.{elem!r} must not be" - " staticmethod." - ) - if elem.startswith(("cog_", "bot_")): - raise TypeError(no_bot_cog.format(base, elem)) - commands[elem] = value - - # a test to see if this value is a BridgeCommand - if hasattr(value, "add_to") and not getattr(value, "parent", None): - if is_static_method: - raise TypeError( - f"Command in method {base}.{elem!r} must not be" - " staticmethod." - ) - if elem.startswith(("cog_", "bot_")): - raise TypeError(no_bot_cog.format(base, elem)) - - commands[f"ext_{elem}"] = value.ext_variant - commands[f"app_{elem}"] = value.slash_variant - commands[elem] = value - for cmd in getattr(value, "subcommands", []): - commands[f"ext_{cmd.ext_variant.qualified_name}"] = ( - cmd.ext_variant - ) - - if inspect.iscoroutinefunction(value): - try: - getattr(value, "__cog_listener__") - except AttributeError: - continue - else: - if elem.startswith(("cog_", "bot_")): - raise TypeError(no_bot_cog.format(base, elem)) - listeners[elem] = value + new_commands, new_listeners = _process_attributes(base) + commands.update(new_commands) + listeners.update(new_listeners) new_cls.__cog_commands__ = list(commands.values()) - listeners_as_list = [] - for listener in listeners.values(): - for listener_name in listener.__cog_listener_names__: - # I use __name__ instead of just storing the value, so I can inject - # the self attribute when the time comes to add them to the bot - listeners_as_list.append((listener_name, listener.__name__)) - - new_cls.__cog_listeners__ = listeners_as_list + new_cls.__cog_listeners__ = [ + (listener_name, listener.__name__) + for listener in listeners.values() + for listener_name in listener.__cog_listener_names__ + ] cmd_attrs = new_cls.__cog_settings__ # Either update the command with the cog provided defaults or copy it. # r.e type ignore, type-checker complains about overriding a ClassVar - new_cls.__cog_commands__ = tuple(c._update_copy(cmd_attrs) if not hasattr(c, "add_to") else c for c in new_cls.__cog_commands__) # type: ignore - - name_filter = lambda c: ( - "app" - if isinstance(c, ApplicationCommand) - else ("bridge" if not hasattr(c, "add_to") else "ext") - ) + new_cls.__cog_commands__ = list(tuple(c._update_copy(cmd_attrs) if not _is_bridge_command(c) else c for c in new_cls.__cog_commands__)) # type: ignore lookup = { - f"{name_filter(cmd)}_{cmd.qualified_name}": cmd + f"{_name_filter(cmd)}_{cmd.qualified_name}": cmd for cmd in new_cls.__cog_commands__ } # Update the Command instances dynamically as well for command in new_cls.__cog_commands__: - if ( - isinstance(command, ApplicationCommand) - and not command.guild_ids - and new_cls.__cog_guild_ids__ - ): - command.guild_ids = new_cls.__cog_guild_ids__ - - if not isinstance(command, SlashCommandGroup) and not hasattr( - command, "add_to" - ): - # ignore bridge commands - cmd = getattr(new_cls, command.callback.__name__, None) - if hasattr(cmd, "add_to"): - setattr( - cmd, - f"{name_filter(command).replace('app', 'slash')}_variant", - command, - ) - else: - setattr(new_cls, command.callback.__name__, command) - - parent = command.parent - if parent is not None: - # Get the latest parent reference - parent = lookup[f"{name_filter(command)}_{parent.qualified_name}"] # type: ignore - - # Update our parent's reference to our self - parent.remove_command(command.name) # type: ignore - parent.add_command(command) # type: ignore + _update_command(command, new_cls.__cog_guild_ids__, lookup, new_cls) return new_cls @@ -537,7 +573,7 @@ def _inject(self: CogT, bot) -> CogT: # we've added so far for some form of atomic loading. for index, command in enumerate(self.__cog_commands__): - if hasattr(command, "add_to"): + if _is_bridge_command(command): bot.bridge_commands.append(command) continue @@ -582,7 +618,7 @@ def _eject(self, bot) -> None: try: for command in self.__cog_commands__: - if hasattr(command, "add_to"): + if _is_bridge_command(command): bot.bridge_commands.remove(command) continue elif isinstance(command, ApplicationCommand): diff --git a/discord/commands/core.py b/discord/commands/core.py index bee43341fd..82fb65dbf4 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -726,6 +726,8 @@ class SlashCommand(ApplicationCommand): type = 1 + parent: SlashCommandGroup | None + def __new__(cls, *args, **kwargs) -> SlashCommand: self = super().__new__(cls) diff --git a/discord/ext/bridge/core.py b/discord/ext/bridge/core.py index 9dc58fb009..ff3a3a8d63 100644 --- a/discord/ext/bridge/core.py +++ b/discord/ext/bridge/core.py @@ -184,6 +184,8 @@ class BridgeCommand: The prefix-based version of this bridge command. """ + __bridge__: bool = True + __special_attrs__ = ["slash_variant", "ext_variant", "parent"] def __init__(self, callback, **kwargs):