2626from __future__ import annotations
2727
2828import asyncio
29+ import inspect
2930import time
3031from collections import deque
31- from typing import TYPE_CHECKING , Any , Callable , Deque , TypeVar
32+ from typing import TYPE_CHECKING , Any , Awaitable , Callable , Deque , TypeVar
3233
3334import discord .abc
3435from discord .enums import Enum
3738from .errors import MaxConcurrencyReached
3839
3940if TYPE_CHECKING :
41+ from ...commands import ApplicationContext
42+ from ...ext .commands import Context
4043 from ...message import Message
4144
4245__all__ = (
@@ -60,31 +63,35 @@ class BucketType(Enum):
6063 category = 5
6164 role = 6
6265
63- def get_key (self , msg : Message ) -> Any :
66+ def get_key (self , ctx : Context | ApplicationContext ) -> Any :
6467 if self is BucketType .user :
65- return msg .author .id
68+ return ctx .author .id
6669 elif self is BucketType .guild :
67- return (msg .guild or msg .author ).id
70+ return (ctx .guild or ctx .author ).id
6871 elif self is BucketType .channel :
69- return msg .channel .id
72+ return ctx .channel .id
7073 elif self is BucketType .member :
71- return (msg .guild and msg .guild .id ), msg .author .id
74+ return (ctx .guild and ctx .guild .id ), ctx .author .id
7275 elif self is BucketType .category :
7376 return (
74- msg .channel .category .id
75- if isinstance (msg .channel , discord .abc .GuildChannel )
76- and msg .channel .category
77- else msg .channel .id
77+ ctx .channel .category .id
78+ if isinstance (ctx .channel , discord .abc .GuildChannel )
79+ and ctx .channel .category
80+ else ctx .channel .id
7881 )
7982 elif self is BucketType .role :
8083 # we return the channel id of a private-channel as there are only roles in guilds
8184 # and that yields the same result as for a guild with only the @everyone role
8285 # NOTE: PrivateChannel doesn't actually have an id attribute, but we assume we are
8386 # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
84- return (msg .channel if isinstance (msg .channel , PrivateChannel ) else msg .author .top_role ).id # type: ignore
87+ return (
88+ ctx .channel
89+ if isinstance (ctx .channel , PrivateChannel )
90+ else ctx .author .top_role
91+ ).id # type: ignore
8592
86- def __call__ (self , msg : Message ) -> Any :
87- return self .get_key (msg )
93+ def __call__ (self , ctx : Context | ApplicationContext ) -> Any :
94+ return self .get_key (ctx )
8895
8996
9097class Cooldown :
@@ -208,14 +215,14 @@ class CooldownMapping:
208215 def __init__ (
209216 self ,
210217 original : Cooldown | None ,
211- type : Callable [[Message ], Any ],
218+ type : Callable [[Context | ApplicationContext ], Any ],
212219 ) -> None :
213220 if not callable (type ):
214221 raise TypeError ("Cooldown type must be a BucketType or callable" )
215222
216223 self ._cache : dict [Any , Cooldown ] = {}
217224 self ._cooldown : Cooldown | None = original
218- self ._type : Callable [[Message ], Any ] = type
225+ self ._type : Callable [[Context | ApplicationContext ], Any ] = type
219226
220227 def copy (self ) -> CooldownMapping :
221228 ret = CooldownMapping (self ._cooldown , self ._type )
@@ -227,15 +234,15 @@ def valid(self) -> bool:
227234 return self ._cooldown is not None
228235
229236 @property
230- def type (self ) -> Callable [[Message ], Any ]:
237+ def type (self ) -> Callable [[Context | ApplicationContext ], Any ]:
231238 return self ._type
232239
233240 @classmethod
234241 def from_cooldown (cls : type [C ], rate , per , type ) -> C :
235242 return cls (Cooldown (rate , per ), type )
236243
237- def _bucket_key (self , msg : Message ) -> Any :
238- return self ._type (msg )
244+ def _bucket_key (self , ctx : Context | ApplicationContext ) -> Any :
245+ return self ._type (ctx )
239246
240247 def _verify_cache_integrity (self , current : float | None = None ) -> None :
241248 # we want to delete all cache objects that haven't been used
@@ -246,37 +253,45 @@ def _verify_cache_integrity(self, current: float | None = None) -> None:
246253 for k in dead_keys :
247254 del self ._cache [k ]
248255
249- def create_bucket (self , message : Message ) -> Cooldown :
256+ async def create_bucket (self , ctx : Context | ApplicationContext ) -> Cooldown :
250257 return self ._cooldown .copy () # type: ignore
251258
252- def get_bucket (self , message : Message , current : float | None = None ) -> Cooldown :
259+ async def get_bucket (
260+ self , ctx : Context | ApplicationContext , current : float | None = None
261+ ) -> Cooldown :
253262 if self ._type is BucketType .default :
254263 return self ._cooldown # type: ignore
255264
256265 self ._verify_cache_integrity (current )
257- key = self ._bucket_key (message )
266+ key = self ._bucket_key (ctx )
258267 if key not in self ._cache :
259- bucket = self .create_bucket (message )
268+ bucket = await self .create_bucket (ctx )
260269 if bucket is not None :
261270 self ._cache [key ] = bucket
262271 else :
263272 bucket = self ._cache [key ]
264273
265274 return bucket
266275
267- def update_rate_limit (
268- self , message : Message , current : float | None = None
276+ async def update_rate_limit (
277+ self , ctx : Context | ApplicationContext , current : float | None = None
269278 ) -> float | None :
270- bucket = self .get_bucket (message , current )
279+ bucket = await self .get_bucket (ctx , current )
271280 return bucket .update_rate_limit (current )
272281
273282
274283class DynamicCooldownMapping (CooldownMapping ):
275284 def __init__ (
276- self , factory : Callable [[Message ], Cooldown ], type : Callable [[Message ], Any ]
285+ self ,
286+ factory : Callable [
287+ [Context | ApplicationContext ], Cooldown | Awaitable [Cooldown ]
288+ ],
289+ type : Callable [[Context | ApplicationContext ], Any ],
277290 ) -> None :
278291 super ().__init__ (None , type )
279- self ._factory : Callable [[Message ], Cooldown ] = factory
292+ self ._factory : Callable [
293+ [Context | ApplicationContext ], Cooldown | Awaitable [Cooldown ]
294+ ] = factory
280295
281296 def copy (self ) -> DynamicCooldownMapping :
282297 ret = DynamicCooldownMapping (self ._factory , self ._type )
@@ -287,8 +302,16 @@ def copy(self) -> DynamicCooldownMapping:
287302 def valid (self ) -> bool :
288303 return True
289304
290- def create_bucket (self , message : Message ) -> Cooldown :
291- return self ._factory (message )
305+ async def create_bucket (self , ctx : Context | ApplicationContext ) -> Cooldown :
306+ from ...ext .commands import Context
307+
308+ if isinstance (ctx , Context ):
309+ result = self ._factory (ctx .message )
310+ else :
311+ result = self ._factory (ctx )
312+ if inspect .isawaitable (result ):
313+ return await result
314+ return result
292315
293316
294317class _Semaphore :
0 commit comments