Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ These changes are available on the `master` branch, but have not yet been releas
([#2564](https://github.com/Pycord-Development/pycord/pull/2564))
- Added ability to change the API's base URL with `Route.API_BASE_URL`.
([#2714](https://github.com/Pycord-Development/pycord/pull/2714))
- Added the ability to use `functools.partial` as `Command` and `SlashCommand`
callbacks. ([#2724](https://github.com/Pycord-Development/pycord/pull/2724))

### Fixed

Expand Down
11 changes: 9 additions & 2 deletions discord/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import pathlib
import sys
import types
from functools import partial
from typing import Any, Callable, ClassVar, Generator, Mapping, TypeVar, overload

import discord.utils
Expand Down Expand Up @@ -256,16 +257,22 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
if not isinstance(command, SlashCommandGroup) and not hasattr(
command, "add_to"
):
actual_callback = (
command.callback.func
if isinstance(command.callback, partial)
else command.callback
)

# ignore bridge commands
cmd = getattr(new_cls, command.callback.__name__, None)
cmd = getattr(new_cls, actual_callback.__name__, None)
if hasattr(cmd, "add_to"):
setattr(
cmd,
f"{name_filter(command).replace('app', 'slash')}_variant",
command,
)
else:
setattr(new_cls, command.callback.__name__, command)
setattr(new_cls, actual_callback.__name__, command)

parent = command.parent
if parent is not None:
Expand Down
12 changes: 10 additions & 2 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ class ApplicationCommand(_BaseCommand, Generic[CogT, P, T]):
def __init__(self, func: Callable, **kwargs) -> None:
from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency

actual_func = func if not isinstance(func, functools.partial) else func.func

cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown"))

if cooldown is None:
Expand All @@ -214,7 +216,7 @@ def __init__(self, func: Callable, **kwargs) -> None:
self._callback = None
self.module = None

self.name: str = kwargs.get("name", func.__name__)
self.name: str = kwargs.get("name", actual_func.__name__)

try:
checks = func.__commands_checks__
Expand Down Expand Up @@ -483,7 +485,13 @@ async def dispatch_error(self, ctx: ApplicationContext, error: Exception) -> Non
ctx.bot.dispatch("application_command_error", ctx, error)

def _get_signature_parameters(self):
return OrderedDict(inspect.signature(self.callback).parameters)
params = OrderedDict(inspect.signature(self.callback).parameters)

if isinstance(self.callback, functools.partial):
for param in self.callback.keywords:
params.pop(param, None)

return params

def error(self, coro):
"""A decorator that registers a coroutine as a local error handler.
Expand Down
8 changes: 6 additions & 2 deletions discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def get_signature_parameters(
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")

params[name] = parameter.replace(annotation=annotation)

if isinstance(function, functools.partial):
for param in function.keywords:
params.pop(param, None)
return params


Expand Down Expand Up @@ -328,7 +330,9 @@ def __init__(
if not asyncio.iscoroutinefunction(func):
raise TypeError("Callback must be a coroutine.")

name = kwargs.get("name") or func.__name__
actual_func = func if not isinstance(func, functools.partial) else func.func

name = kwargs.get("name", actual_func.__name__)
if not isinstance(name, str):
raise TypeError("Name of a command must be a string.")
self.name: str = name
Expand Down