diff --git a/CHANGELOG.md b/CHANGELOG.md index 33d9e70167..efe02a0bbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2564](https://github.com/Pycord-Development/pycord/pull/2564)) - Added ability to change the API's base URL with `Route.API_BASE_URL`. ([#2714](https://github.com/Pycord-Development/pycord/pull/2714)) +- Added the ability to use `functools.partial` as `Command` and `SlashCommand` + callbacks. ([#2724](https://github.com/Pycord-Development/pycord/pull/2724)) ### Fixed diff --git a/discord/cog.py b/discord/cog.py index e23def38e2..1376c52b6f 100644 --- a/discord/cog.py +++ b/discord/cog.py @@ -31,6 +31,7 @@ import pathlib import sys import types +from functools import partial from typing import Any, Callable, ClassVar, Generator, Mapping, TypeVar, overload import discord.utils @@ -256,8 +257,14 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: if not isinstance(command, SlashCommandGroup) and not hasattr( command, "add_to" ): + actual_callback = ( + command.callback.func + if isinstance(command.callback, partial) + else command.callback + ) + # ignore bridge commands - cmd = getattr(new_cls, command.callback.__name__, None) + cmd = getattr(new_cls, actual_callback.__name__, None) if hasattr(cmd, "add_to"): setattr( cmd, @@ -265,7 +272,7 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: command, ) else: - setattr(new_cls, command.callback.__name__, command) + setattr(new_cls, actual_callback.__name__, command) parent = command.parent if parent is not None: diff --git a/discord/commands/core.py b/discord/commands/core.py index 6dd1b0d636..24bade0619 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -192,6 +192,8 @@ class ApplicationCommand(_BaseCommand, Generic[CogT, P, T]): def __init__(self, func: Callable, **kwargs) -> None: from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency + actual_func = func if not isinstance(func, functools.partial) else func.func + cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown")) if cooldown is None: @@ -214,7 +216,7 @@ def __init__(self, func: Callable, **kwargs) -> None: self._callback = None self.module = None - self.name: str = kwargs.get("name", func.__name__) + self.name: str = kwargs.get("name", actual_func.__name__) try: checks = func.__commands_checks__ @@ -483,7 +485,13 @@ async def dispatch_error(self, ctx: ApplicationContext, error: Exception) -> Non ctx.bot.dispatch("application_command_error", ctx, error) def _get_signature_parameters(self): - return OrderedDict(inspect.signature(self.callback).parameters) + params = OrderedDict(inspect.signature(self.callback).parameters) + + if isinstance(self.callback, functools.partial): + for param in self.callback.keywords: + params.pop(param, None) + + return params def error(self, coro): """A decorator that registers a coroutine as a local error handler. diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 6f6ef1dffa..1f11b56825 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -153,7 +153,9 @@ def get_signature_parameters( raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") params[name] = parameter.replace(annotation=annotation) - + if isinstance(function, functools.partial): + for param in function.keywords: + params.pop(param, None) return params @@ -328,7 +330,9 @@ def __init__( if not asyncio.iscoroutinefunction(func): raise TypeError("Callback must be a coroutine.") - name = kwargs.get("name") or func.__name__ + actual_func = func if not isinstance(func, functools.partial) else func.func + + name = kwargs.get("name", actual_func.__name__) if not isinstance(name, str): raise TypeError("Name of a command must be a string.") self.name: str = name