26
26
from __future__ import annotations
27
27
28
28
import asyncio
29
+ import inspect
29
30
import time
30
31
from collections import deque
31
- from typing import TYPE_CHECKING , Any , Callable , Deque , TypeVar
32
+ from typing import TYPE_CHECKING , Any , Awaitable , Callable , Deque , TypeVar
32
33
33
34
import discord .abc
34
35
from discord .enums import Enum
37
38
from .errors import MaxConcurrencyReached
38
39
39
40
if TYPE_CHECKING :
41
+ from ...commands import ApplicationContext
42
+ from ...ext .commands import Context
40
43
from ...message import Message
41
44
42
45
__all__ = (
@@ -60,31 +63,35 @@ class BucketType(Enum):
60
63
category = 5
61
64
role = 6
62
65
63
- def get_key (self , msg : Message ) -> Any :
66
+ def get_key (self , ctx : Context | ApplicationContext ) -> Any :
64
67
if self is BucketType .user :
65
- return msg .author .id
68
+ return ctx .author .id
66
69
elif self is BucketType .guild :
67
- return (msg .guild or msg .author ).id
70
+ return (ctx .guild or ctx .author ).id
68
71
elif self is BucketType .channel :
69
- return msg .channel .id
72
+ return ctx .channel .id
70
73
elif self is BucketType .member :
71
- return (msg .guild and msg .guild .id ), msg .author .id
74
+ return (ctx .guild and ctx .guild .id ), ctx .author .id
72
75
elif self is BucketType .category :
73
76
return (
74
- msg .channel .category .id
75
- if isinstance (msg .channel , discord .abc .GuildChannel )
76
- and msg .channel .category
77
- else msg .channel .id
77
+ ctx .channel .category .id
78
+ if isinstance (ctx .channel , discord .abc .GuildChannel )
79
+ and ctx .channel .category
80
+ else ctx .channel .id
78
81
)
79
82
elif self is BucketType .role :
80
83
# we return the channel id of a private-channel as there are only roles in guilds
81
84
# and that yields the same result as for a guild with only the @everyone role
82
85
# NOTE: PrivateChannel doesn't actually have an id attribute, but we assume we are
83
86
# receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
84
- return (msg .channel if isinstance (msg .channel , PrivateChannel ) else msg .author .top_role ).id # type: ignore
87
+ return (
88
+ ctx .channel
89
+ if isinstance (ctx .channel , PrivateChannel )
90
+ else ctx .author .top_role
91
+ ).id # type: ignore
85
92
86
- def __call__ (self , msg : Message ) -> Any :
87
- return self .get_key (msg )
93
+ def __call__ (self , ctx : Context | ApplicationContext ) -> Any :
94
+ return self .get_key (ctx )
88
95
89
96
90
97
class Cooldown :
@@ -208,14 +215,14 @@ class CooldownMapping:
208
215
def __init__ (
209
216
self ,
210
217
original : Cooldown | None ,
211
- type : Callable [[Message ], Any ],
218
+ type : Callable [[Context | ApplicationContext ], Any ],
212
219
) -> None :
213
220
if not callable (type ):
214
221
raise TypeError ("Cooldown type must be a BucketType or callable" )
215
222
216
223
self ._cache : dict [Any , Cooldown ] = {}
217
224
self ._cooldown : Cooldown | None = original
218
- self ._type : Callable [[Message ], Any ] = type
225
+ self ._type : Callable [[Context | ApplicationContext ], Any ] = type
219
226
220
227
def copy (self ) -> CooldownMapping :
221
228
ret = CooldownMapping (self ._cooldown , self ._type )
@@ -227,15 +234,15 @@ def valid(self) -> bool:
227
234
return self ._cooldown is not None
228
235
229
236
@property
230
- def type (self ) -> Callable [[Message ], Any ]:
237
+ def type (self ) -> Callable [[Context | ApplicationContext ], Any ]:
231
238
return self ._type
232
239
233
240
@classmethod
234
241
def from_cooldown (cls : type [C ], rate , per , type ) -> C :
235
242
return cls (Cooldown (rate , per ), type )
236
243
237
- def _bucket_key (self , msg : Message ) -> Any :
238
- return self ._type (msg )
244
+ def _bucket_key (self , ctx : Context | ApplicationContext ) -> Any :
245
+ return self ._type (ctx )
239
246
240
247
def _verify_cache_integrity (self , current : float | None = None ) -> None :
241
248
# we want to delete all cache objects that haven't been used
@@ -246,37 +253,45 @@ def _verify_cache_integrity(self, current: float | None = None) -> None:
246
253
for k in dead_keys :
247
254
del self ._cache [k ]
248
255
249
- def create_bucket (self , message : Message ) -> Cooldown :
256
+ async def create_bucket (self , ctx : Context | ApplicationContext ) -> Cooldown :
250
257
return self ._cooldown .copy () # type: ignore
251
258
252
- def get_bucket (self , message : Message , current : float | None = None ) -> Cooldown :
259
+ async def get_bucket (
260
+ self , ctx : Context | ApplicationContext , current : float | None = None
261
+ ) -> Cooldown :
253
262
if self ._type is BucketType .default :
254
263
return self ._cooldown # type: ignore
255
264
256
265
self ._verify_cache_integrity (current )
257
- key = self ._bucket_key (message )
266
+ key = self ._bucket_key (ctx )
258
267
if key not in self ._cache :
259
- bucket = self .create_bucket (message )
268
+ bucket = await self .create_bucket (ctx )
260
269
if bucket is not None :
261
270
self ._cache [key ] = bucket
262
271
else :
263
272
bucket = self ._cache [key ]
264
273
265
274
return bucket
266
275
267
- def update_rate_limit (
268
- self , message : Message , current : float | None = None
276
+ async def update_rate_limit (
277
+ self , ctx : Context | ApplicationContext , current : float | None = None
269
278
) -> float | None :
270
- bucket = self .get_bucket (message , current )
279
+ bucket = await self .get_bucket (ctx , current )
271
280
return bucket .update_rate_limit (current )
272
281
273
282
274
283
class DynamicCooldownMapping (CooldownMapping ):
275
284
def __init__ (
276
- self , factory : Callable [[Message ], Cooldown ], type : Callable [[Message ], Any ]
285
+ self ,
286
+ factory : Callable [
287
+ [Context | ApplicationContext ], Cooldown | Awaitable [Cooldown ]
288
+ ],
289
+ type : Callable [[Context | ApplicationContext ], Any ],
277
290
) -> None :
278
291
super ().__init__ (None , type )
279
- self ._factory : Callable [[Message ], Cooldown ] = factory
292
+ self ._factory : Callable [
293
+ [Context | ApplicationContext ], Cooldown | Awaitable [Cooldown ]
294
+ ] = factory
280
295
281
296
def copy (self ) -> DynamicCooldownMapping :
282
297
ret = DynamicCooldownMapping (self ._factory , self ._type )
@@ -287,8 +302,16 @@ def copy(self) -> DynamicCooldownMapping:
287
302
def valid (self ) -> bool :
288
303
return True
289
304
290
- def create_bucket (self , message : Message ) -> Cooldown :
291
- return self ._factory (message )
305
+ async def create_bucket (self , ctx : Context | ApplicationContext ) -> Cooldown :
306
+ from ...ext .commands import Context
307
+
308
+ if isinstance (ctx , Context ):
309
+ result = self ._factory (ctx .message )
310
+ else :
311
+ result = self ._factory (ctx )
312
+ if inspect .isawaitable (result ):
313
+ return await result
314
+ return result
292
315
293
316
294
317
class _Semaphore :
0 commit comments