Skip to content

Commit 67f1e13

Browse files
DorukyumLulalaby
andauthored
Cooldowns and max concurrency for application commands (#674)
Co-authored-by: Lala Sabathil <[email protected]>
1 parent 3220cec commit 67f1e13

File tree

2 files changed

+86
-8
lines changed

2 files changed

+86
-8
lines changed

discord/commands/commands.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from __future__ import annotations
2727

2828
import asyncio
29+
import datetime
2930
import functools
3031
import inspect
3132
import re
@@ -111,6 +112,29 @@ class _BaseCommand:
111112
class ApplicationCommand(_BaseCommand, Generic[CogT, P, T]):
112113
cog = None
113114

115+
def __init__(self, func: Callable, **kwargs) -> None:
116+
from ..ext.commands.cooldowns import CooldownMapping, BucketType, MaxConcurrency
117+
118+
try:
119+
cooldown = func.__commands_cooldown__
120+
except AttributeError:
121+
cooldown = kwargs.get('cooldown')
122+
123+
if cooldown is None:
124+
buckets = CooldownMapping(cooldown, BucketType.default)
125+
elif isinstance(cooldown, CooldownMapping):
126+
buckets = cooldown
127+
else:
128+
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
129+
self._buckets: CooldownMapping = buckets
130+
131+
try:
132+
max_concurrency = func.__commands_max_concurrency__
133+
except AttributeError:
134+
max_concurrency = kwargs.get('max_concurrency')
135+
136+
self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
137+
114138
def __repr__(self):
115139
return f"<discord.commands.{self.__class__.__name__} name={self.name}>"
116140

@@ -127,16 +151,48 @@ async def __call__(self, ctx, *args, **kwargs):
127151
"""
128152
return await self.callback(ctx, *args, **kwargs)
129153

154+
def _prepare_cooldowns(self, ctx: ApplicationContext):
155+
if self._buckets.valid:
156+
current = datetime.datetime.now().timestamp()
157+
bucket = self._buckets.get_bucket(ctx, current) # type: ignore (ctx instead of non-existent message)
158+
159+
if bucket is not None:
160+
retry_after = bucket.update_rate_limit(current)
161+
162+
if retry_after:
163+
from ..ext.commands.errors import CommandOnCooldown
164+
raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore
165+
130166
async def prepare(self, ctx: ApplicationContext) -> None:
131167
# This should be same across all 3 types
132168
ctx.command = self
133169

134170
if not await self.can_run(ctx):
135171
raise CheckFailure(f'The check functions for the command {self.name} failed')
136172

137-
# TODO: Add cooldown
173+
if self._max_concurrency is not None:
174+
# For this application, context can be duck-typed as a Message
175+
await self._max_concurrency.acquire(ctx) # type: ignore (ctx instead of non-existent message)
138176

139-
await self.call_before_hooks(ctx)
177+
try:
178+
self._prepare_cooldowns(ctx)
179+
await self.call_before_hooks(ctx)
180+
except:
181+
if self._max_concurrency is not None:
182+
await self._max_concurrency.release(ctx) # type: ignore (ctx instead of non-existent message)
183+
raise
184+
185+
def reset_cooldown(self, ctx: ApplicationContext) -> None:
186+
"""Resets the cooldown on this command.
187+
188+
Parameters
189+
-----------
190+
ctx: :class:`.ApplicationContext`
191+
The invocation context to reset the cooldown under.
192+
"""
193+
if self._buckets.valid:
194+
bucket = self._buckets.get_bucket(ctx) # type: ignore (ctx instead of non-existent message)
195+
bucket.reset()
140196

141197
async def invoke(self, ctx: ApplicationContext) -> None:
142198
await self.prepare(ctx)
@@ -299,6 +355,10 @@ async def call_after_hooks(self, ctx: ApplicationContext) -> None:
299355
if hook is not None:
300356
await hook(ctx)
301357

358+
@property
359+
def cooldown(self):
360+
return self._buckets._cooldown
361+
302362
@property
303363
def full_parent_name(self) -> str:
304364
""":class:`str`: Retrieves the fully qualified parent command name.
@@ -371,6 +431,11 @@ class SlashCommand(ApplicationCommand):
371431
:exc:`.CommandError` should be used. Note that if the checks fail then
372432
:exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error`
373433
event.
434+
cooldown: Optional[:class:`~discord.ext.commands.Cooldown`]
435+
The cooldown applied when the command is invoked. ``None`` if the command
436+
doesn't have a cooldown.
437+
438+
.. versionadded:: 2.0
374439
"""
375440
type = 1
376441

@@ -381,6 +446,7 @@ def __new__(cls, *args, **kwargs) -> SlashCommand:
381446
return self
382447

383448
def __init__(self, func: Callable, *args, **kwargs) -> None:
449+
super().__init__(func, **kwargs)
384450
if not asyncio.iscoroutinefunction(func):
385451
raise TypeError("Callback must be a coroutine.")
386452
self.callback = func
@@ -963,6 +1029,11 @@ class ContextMenuCommand(ApplicationCommand):
9631029
:exc:`.CommandError` should be used. Note that if the checks fail then
9641030
:exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error`
9651031
event.
1032+
cooldown: Optional[:class:`~discord.ext.commands.Cooldown`]
1033+
The cooldown applied when the command is invoked. ``None`` if the command
1034+
doesn't have a cooldown.
1035+
1036+
.. versionadded:: 2.0
9661037
"""
9671038
def __new__(cls, *args, **kwargs) -> ContextMenuCommand:
9681039
self = super().__new__(cls)
@@ -971,6 +1042,7 @@ def __new__(cls, *args, **kwargs) -> ContextMenuCommand:
9711042
return self
9721043

9731044
def __init__(self, func: Callable, *args, **kwargs) -> None:
1045+
super().__init__(func, **kwargs)
9741046
if not asyncio.iscoroutinefunction(func):
9751047
raise TypeError("Callback must be a coroutine.")
9761048
self.callback = func

discord/ext/commands/core.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,13 @@
5353
from ...errors import *
5454
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
5555
from .converter import run_converters, get_converter, Greedy
56-
from ...commands import _BaseCommand, slash_command, user_command, message_command
56+
from ...commands import (
57+
ApplicationCommand,
58+
_BaseCommand,
59+
slash_command,
60+
user_command,
61+
message_command,
62+
)
5763
from .cog import Cog
5864
from .context import Context
5965

@@ -2147,7 +2153,7 @@ def pred(ctx: Context) -> bool:
21472153
return check(pred)
21482154

21492155
def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default) -> Callable[[T], T]:
2150-
"""A decorator that adds a cooldown to a :class:`.Command`
2156+
"""A decorator that adds a cooldown to a command
21512157
21522158
A cooldown allows a command to only be used a specific amount
21532159
of times in a specific time frame. These cooldowns can be based
@@ -2174,15 +2180,15 @@ def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message],
21742180
"""
21752181

21762182
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
2177-
if isinstance(func, Command):
2183+
if isinstance(func, (Command, ApplicationCommand)):
21782184
func._buckets = CooldownMapping(Cooldown(rate, per), type)
21792185
else:
21802186
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
21812187
return func
21822188
return decorator # type: ignore
21832189

21842190
def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]:
2185-
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
2191+
"""A decorator that adds a dynamic cooldown to a command
21862192
21872193
This differs from :func:`.cooldown` in that it takes a function that
21882194
accepts a single parameter of type :class:`.discord.Message` and must
@@ -2222,7 +2228,7 @@ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
22222228
return decorator # type: ignore
22232229

22242230
def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]:
2225-
"""A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
2231+
"""A decorator that adds a maximum concurrency to a command
22262232
22272233
This enables you to only allow a certain number of command invocations at the same time,
22282234
for example if a command takes too long or if only one user can use it at a time. This
@@ -2247,7 +2253,7 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
22472253

22482254
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
22492255
value = MaxConcurrency(number, per=per, wait=wait)
2250-
if isinstance(func, Command):
2256+
if isinstance(func, (Command, ApplicationCommand)):
22512257
func._max_concurrency = value
22522258
else:
22532259
func.__commands_max_concurrency__ = value

0 commit comments

Comments
 (0)