Skip to content

Commit ae423a0

Browse files
committed
fix: update cooldown handling to support async operations
1 parent 8619b69 commit ae423a0

File tree

4 files changed

+83
-63
lines changed

4 files changed

+83
-63
lines changed

discord/commands/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,10 +333,10 @@ def guild_only(self, value: bool) -> None:
333333
InteractionContextType.private_channel,
334334
}
335335

336-
def _prepare_cooldowns(self, ctx: ApplicationContext):
336+
async def _prepare_cooldowns(self, ctx: ApplicationContext):
337337
if self._buckets.valid:
338338
current = datetime.datetime.now().timestamp()
339-
bucket = self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message
339+
bucket = await self._buckets.get_bucket(ctx, current) # type: ignore # ctx instead of non-existent message
340340

341341
if bucket is not None:
342342
retry_after = bucket.update_rate_limit(current)
@@ -360,7 +360,7 @@ async def prepare(self, ctx: ApplicationContext) -> None:
360360
await self._max_concurrency.acquire(ctx) # type: ignore # ctx instead of non-existent message
361361

362362
try:
363-
self._prepare_cooldowns(ctx)
363+
await self._prepare_cooldowns(ctx)
364364
await self.call_before_hooks(ctx)
365365
except:
366366
if self._max_concurrency is not None:

discord/ext/commands/context.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,11 @@ def me(self) -> Member | ClientUser:
311311
message contexts, or when :meth:`Intents.guilds` is absent.
312312
"""
313313
# bot.user will never be None at this point.
314-
return self.guild.me if self.guild is not None and self.guild.me is not None else self.bot.user # type: ignore
314+
return (
315+
self.guild.me
316+
if self.guild is not None and self.guild.me is not None
317+
else self.bot.user
318+
) # type: ignore
315319

316320
@property
317321
def voice_client(self) -> VoiceProtocol | None:

discord/ext/commands/cooldowns.py

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
from __future__ import annotations
2727

2828
import asyncio
29+
import inspect
2930
import time
3031
from collections import deque
31-
from typing import TYPE_CHECKING, Any, Callable, Deque, TypeVar
32+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Deque, TypeVar
3233

3334
import discord.abc
3435
from discord.enums import Enum
@@ -37,6 +38,8 @@
3738
from .errors import MaxConcurrencyReached
3839

3940
if TYPE_CHECKING:
41+
from ...commands import ApplicationContext
42+
from ...ext.commands import Context
4043
from ...message import Message
4144

4245
__all__ = (
@@ -60,31 +63,35 @@ class BucketType(Enum):
6063
category = 5
6164
role = 6
6265

63-
def get_key(self, msg: Message) -> Any:
66+
def get_key(self, ctx: Context | ApplicationContext) -> Any:
6467
if self is BucketType.user:
65-
return msg.author.id
68+
return ctx.author.id
6669
elif self is BucketType.guild:
67-
return (msg.guild or msg.author).id
70+
return (ctx.guild or ctx.author).id
6871
elif self is BucketType.channel:
69-
return msg.channel.id
72+
return ctx.channel.id
7073
elif self is BucketType.member:
71-
return (msg.guild and msg.guild.id), msg.author.id
74+
return (ctx.guild and ctx.guild.id), ctx.author.id
7275
elif self is BucketType.category:
7376
return (
74-
msg.channel.category.id
75-
if isinstance(msg.channel, discord.abc.GuildChannel)
76-
and msg.channel.category
77-
else msg.channel.id
77+
ctx.channel.category.id
78+
if isinstance(ctx.channel, discord.abc.GuildChannel)
79+
and ctx.channel.category
80+
else ctx.channel.id
7881
)
7982
elif self is BucketType.role:
8083
# we return the channel id of a private-channel as there are only roles in guilds
8184
# and that yields the same result as for a guild with only the @everyone role
8285
# NOTE: PrivateChannel doesn't actually have an id attribute, but we assume we are
8386
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
84-
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
87+
return (
88+
ctx.channel
89+
if isinstance(ctx.channel, PrivateChannel)
90+
else ctx.author.top_role
91+
).id # type: ignore
8592

86-
def __call__(self, msg: Message) -> Any:
87-
return self.get_key(msg)
93+
def __call__(self, ctx: Context | ApplicationContext) -> Any:
94+
return self.get_key(ctx)
8895

8996

9097
class Cooldown:
@@ -208,14 +215,14 @@ class CooldownMapping:
208215
def __init__(
209216
self,
210217
original: Cooldown | None,
211-
type: Callable[[Message], Any],
218+
type: Callable[[Context | ApplicationContext], Any],
212219
) -> None:
213220
if not callable(type):
214221
raise TypeError("Cooldown type must be a BucketType or callable")
215222

216223
self._cache: dict[Any, Cooldown] = {}
217224
self._cooldown: Cooldown | None = original
218-
self._type: Callable[[Message], Any] = type
225+
self._type: Callable[[Context | ApplicationContext], Any] = type
219226

220227
def copy(self) -> CooldownMapping:
221228
ret = CooldownMapping(self._cooldown, self._type)
@@ -227,15 +234,15 @@ def valid(self) -> bool:
227234
return self._cooldown is not None
228235

229236
@property
230-
def type(self) -> Callable[[Message], Any]:
237+
def type(self) -> Callable[[Context | ApplicationContext], Any]:
231238
return self._type
232239

233240
@classmethod
234241
def from_cooldown(cls: type[C], rate, per, type) -> C:
235242
return cls(Cooldown(rate, per), type)
236243

237-
def _bucket_key(self, msg: Message) -> Any:
238-
return self._type(msg)
244+
def _bucket_key(self, ctx: Context | ApplicationContext) -> Any:
245+
return self._type(ctx)
239246

240247
def _verify_cache_integrity(self, current: float | None = None) -> None:
241248
# we want to delete all cache objects that haven't been used
@@ -246,37 +253,45 @@ def _verify_cache_integrity(self, current: float | None = None) -> None:
246253
for k in dead_keys:
247254
del self._cache[k]
248255

249-
def create_bucket(self, message: Message) -> Cooldown:
256+
async def create_bucket(self, ctx: Context | ApplicationContext) -> Cooldown:
250257
return self._cooldown.copy() # type: ignore
251258

252-
def get_bucket(self, message: Message, current: float | None = None) -> Cooldown:
259+
async def get_bucket(
260+
self, ctx: Context | ApplicationContext, current: float | None = None
261+
) -> Cooldown:
253262
if self._type is BucketType.default:
254263
return self._cooldown # type: ignore
255264

256265
self._verify_cache_integrity(current)
257-
key = self._bucket_key(message)
266+
key = self._bucket_key(ctx)
258267
if key not in self._cache:
259-
bucket = self.create_bucket(message)
268+
bucket = await self.create_bucket(ctx)
260269
if bucket is not None:
261270
self._cache[key] = bucket
262271
else:
263272
bucket = self._cache[key]
264273

265274
return bucket
266275

267-
def update_rate_limit(
268-
self, message: Message, current: float | None = None
276+
async def update_rate_limit(
277+
self, ctx: Context | ApplicationContext, current: float | None = None
269278
) -> float | None:
270-
bucket = self.get_bucket(message, current)
279+
bucket = await self.get_bucket(ctx, current)
271280
return bucket.update_rate_limit(current)
272281

273282

274283
class DynamicCooldownMapping(CooldownMapping):
275284
def __init__(
276-
self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]
285+
self,
286+
factory: Callable[
287+
[Context | ApplicationContext], Cooldown | Awaitable[Cooldown]
288+
],
289+
type: Callable[[Context | ApplicationContext], Any],
277290
) -> None:
278291
super().__init__(None, type)
279-
self._factory: Callable[[Message], Cooldown] = factory
292+
self._factory: Callable[
293+
[Context | ApplicationContext], Cooldown | Awaitable[Cooldown]
294+
] = factory
280295

281296
def copy(self) -> DynamicCooldownMapping:
282297
ret = DynamicCooldownMapping(self._factory, self._type)
@@ -287,8 +302,16 @@ def copy(self) -> DynamicCooldownMapping:
287302
def valid(self) -> bool:
288303
return True
289304

290-
def create_bucket(self, message: Message) -> Cooldown:
291-
return self._factory(message)
305+
async def create_bucket(self, ctx: Context | ApplicationContext) -> Cooldown:
306+
from ...ext.commands import Context
307+
308+
if isinstance(ctx, Context):
309+
result = self._factory(ctx.message)
310+
else:
311+
result = self._factory(ctx)
312+
if inspect.isawaitable(result):
313+
return await result
314+
return result
292315

293316

294317
class _Semaphore:

discord/ext/commands/core.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from typing import (
3434
TYPE_CHECKING,
3535
Any,
36+
Awaitable,
3637
Callable,
3738
Generator,
3839
Generic,
@@ -69,6 +70,7 @@
6970
if TYPE_CHECKING:
7071
from typing_extensions import Concatenate, ParamSpec, TypeGuard
7172

73+
from discord import ApplicationContext
7274
from discord.message import Message
7375

7476
from ._types import Check, Coro, CoroFunc, Error, Hook
@@ -397,7 +399,9 @@ def __init__(
397399

398400
# bandaid for the fact that sometimes parent can be the bot instance
399401
parent = kwargs.get("parent")
400-
self.parent: GroupMixin | None = parent if isinstance(parent, _BaseCommand) else None # type: ignore
402+
self.parent: GroupMixin | None = (
403+
parent if isinstance(parent, _BaseCommand) else None
404+
) # type: ignore
401405

402406
self._before_invoke: Hook | None = None
403407
try:
@@ -850,11 +854,11 @@ async def call_after_hooks(self, ctx: Context) -> None:
850854
if hook is not None:
851855
await hook(ctx)
852856

853-
def _prepare_cooldowns(self, ctx: Context) -> None:
857+
async def _prepare_cooldowns(self, ctx: Context) -> None:
854858
if self._buckets.valid:
855859
dt = ctx.message.edited_at or ctx.message.created_at
856860
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
857-
bucket = self._buckets.get_bucket(ctx.message, current)
861+
bucket = await self._buckets.get_bucket(ctx, current)
858862
if bucket is not None:
859863
retry_after = bucket.update_rate_limit(current)
860864
if retry_after:
@@ -875,9 +879,9 @@ async def prepare(self, ctx: Context) -> None:
875879
try:
876880
if self.cooldown_after_parsing:
877881
await self._parse_arguments(ctx)
878-
self._prepare_cooldowns(ctx)
882+
await self._prepare_cooldowns(ctx)
879883
else:
880-
self._prepare_cooldowns(ctx)
884+
await self._prepare_cooldowns(ctx)
881885
await self._parse_arguments(ctx)
882886

883887
await self.call_before_hooks(ctx)
@@ -1204,7 +1208,9 @@ async def can_run(self, ctx: Context) -> bool:
12041208
# since we have no checks, then we just return True.
12051209
return True
12061210

1207-
return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore
1211+
return await discord.utils.async_all(
1212+
predicate(ctx) for predicate in predicates
1213+
) # type: ignore
12081214
finally:
12091215
ctx.command = original
12101216

@@ -2353,36 +2359,23 @@ def decorator(func: Command | CoroFunc) -> Command | CoroFunc:
23532359

23542360

23552361
def dynamic_cooldown(
2356-
cooldown: BucketType | Callable[[Message], Any],
2362+
cooldown: Callable[
2363+
[Context | ApplicationContext], Cooldown | Awaitable[Cooldown] | None
2364+
],
23572365
type: BucketType = BucketType.default,
23582366
) -> Callable[[T], T]:
2359-
"""A decorator that adds a dynamic cooldown to a command
2360-
2361-
This differs from :func:`.cooldown` in that it takes a function that
2362-
accepts a single parameter of type :class:`.discord.Message` and must
2363-
return a :class:`.Cooldown` or ``None``. If ``None`` is returned then
2364-
that cooldown is effectively bypassed.
2365-
2366-
A cooldown allows a command to only be used a specific amount
2367-
of times in a specific time frame. These cooldowns can be based
2368-
either on a per-guild, per-channel, per-user, per-role or global basis.
2369-
Denoted by the third argument of ``type`` which must be of enum
2370-
type :class:`.BucketType`.
2371-
2372-
If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in
2373-
:func:`.on_command_error` and the local error handler.
2374-
2375-
A command can only have a single cooldown.
2367+
"""A decorator that adds a dynamic cooldown to a command.
23762368
2377-
.. versionadded:: 2.0
2369+
This supports both sync and async cooldown factories and accepts either
2370+
a :class:`discord.Message` or :class:`discord.ApplicationContext`.
23782371
23792372
Parameters
23802373
----------
2381-
cooldown: Callable[[:class:`.discord.Message`], Optional[:class:`.Cooldown`]]
2382-
A function that takes a message and returns a cooldown that will
2383-
apply to this invocation or ``None`` if the cooldown should be bypassed.
2384-
type: :class:`.BucketType`
2385-
The type of cooldown to have.
2374+
cooldown: Callable[[Union[Message, ApplicationContext]], Union[Cooldown, Awaitable[Cooldown], None]]
2375+
A function that takes a message or context and returns a cooldown
2376+
to apply for that invocation or ``None`` to bypass.
2377+
type: :class:`BucketType`
2378+
The cooldown bucket type (e.g. per-user, per-channel).
23862379
"""
23872380
if not callable(cooldown):
23882381
raise TypeError("A callable must be provided")

0 commit comments

Comments
 (0)