26
26
from __future__ import annotations
27
27
28
28
import asyncio
29
+ import datetime
29
30
import functools
30
31
import inspect
31
32
import re
@@ -111,6 +112,29 @@ class _BaseCommand:
111
112
class ApplicationCommand (_BaseCommand , Generic [CogT , P , T ]):
112
113
cog = None
113
114
115
+ def __init__ (self , func : Callable , ** kwargs ) -> None :
116
+ from ..ext .commands .cooldowns import CooldownMapping , BucketType , MaxConcurrency
117
+
118
+ try :
119
+ cooldown = func .__commands_cooldown__
120
+ except AttributeError :
121
+ cooldown = kwargs .get ('cooldown' )
122
+
123
+ if cooldown is None :
124
+ buckets = CooldownMapping (cooldown , BucketType .default )
125
+ elif isinstance (cooldown , CooldownMapping ):
126
+ buckets = cooldown
127
+ else :
128
+ raise TypeError ("Cooldown must be a an instance of CooldownMapping or None." )
129
+ self ._buckets : CooldownMapping = buckets
130
+
131
+ try :
132
+ max_concurrency = func .__commands_max_concurrency__
133
+ except AttributeError :
134
+ max_concurrency = kwargs .get ('max_concurrency' )
135
+
136
+ self ._max_concurrency : Optional [MaxConcurrency ] = max_concurrency
137
+
114
138
def __repr__ (self ):
115
139
return f"<discord.commands.{ self .__class__ .__name__ } name={ self .name } >"
116
140
@@ -127,16 +151,48 @@ async def __call__(self, ctx, *args, **kwargs):
127
151
"""
128
152
return await self .callback (ctx , * args , ** kwargs )
129
153
154
+ def _prepare_cooldowns (self , ctx : ApplicationContext ):
155
+ if self ._buckets .valid :
156
+ current = datetime .datetime .now ().timestamp ()
157
+ bucket = self ._buckets .get_bucket (ctx , current ) # type: ignore (ctx instead of non-existent message)
158
+
159
+ if bucket is not None :
160
+ retry_after = bucket .update_rate_limit (current )
161
+
162
+ if retry_after :
163
+ from ..ext .commands .errors import CommandOnCooldown
164
+ raise CommandOnCooldown (bucket , retry_after , self ._buckets .type ) # type: ignore
165
+
130
166
async def prepare (self , ctx : ApplicationContext ) -> None :
131
167
# This should be same across all 3 types
132
168
ctx .command = self
133
169
134
170
if not await self .can_run (ctx ):
135
171
raise CheckFailure (f'The check functions for the command { self .name } failed' )
136
172
137
- # TODO: Add cooldown
173
+ if self ._max_concurrency is not None :
174
+ # For this application, context can be duck-typed as a Message
175
+ await self ._max_concurrency .acquire (ctx ) # type: ignore (ctx instead of non-existent message)
138
176
139
- await self .call_before_hooks (ctx )
177
+ try :
178
+ self ._prepare_cooldowns (ctx )
179
+ await self .call_before_hooks (ctx )
180
+ except :
181
+ if self ._max_concurrency is not None :
182
+ await self ._max_concurrency .release (ctx ) # type: ignore (ctx instead of non-existent message)
183
+ raise
184
+
185
+ def reset_cooldown (self , ctx : ApplicationContext ) -> None :
186
+ """Resets the cooldown on this command.
187
+
188
+ Parameters
189
+ -----------
190
+ ctx: :class:`.ApplicationContext`
191
+ The invocation context to reset the cooldown under.
192
+ """
193
+ if self ._buckets .valid :
194
+ bucket = self ._buckets .get_bucket (ctx ) # type: ignore (ctx instead of non-existent message)
195
+ bucket .reset ()
140
196
141
197
async def invoke (self , ctx : ApplicationContext ) -> None :
142
198
await self .prepare (ctx )
@@ -299,6 +355,10 @@ async def call_after_hooks(self, ctx: ApplicationContext) -> None:
299
355
if hook is not None :
300
356
await hook (ctx )
301
357
358
+ @property
359
+ def cooldown (self ):
360
+ return self ._buckets ._cooldown
361
+
302
362
@property
303
363
def full_parent_name (self ) -> str :
304
364
""":class:`str`: Retrieves the fully qualified parent command name.
@@ -371,6 +431,11 @@ class SlashCommand(ApplicationCommand):
371
431
:exc:`.CommandError` should be used. Note that if the checks fail then
372
432
:exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error`
373
433
event.
434
+ cooldown: Optional[:class:`~discord.ext.commands.Cooldown`]
435
+ The cooldown applied when the command is invoked. ``None`` if the command
436
+ doesn't have a cooldown.
437
+
438
+ .. versionadded:: 2.0
374
439
"""
375
440
type = 1
376
441
@@ -381,6 +446,7 @@ def __new__(cls, *args, **kwargs) -> SlashCommand:
381
446
return self
382
447
383
448
def __init__ (self , func : Callable , * args , ** kwargs ) -> None :
449
+ super ().__init__ (func , ** kwargs )
384
450
if not asyncio .iscoroutinefunction (func ):
385
451
raise TypeError ("Callback must be a coroutine." )
386
452
self .callback = func
@@ -963,6 +1029,11 @@ class ContextMenuCommand(ApplicationCommand):
963
1029
:exc:`.CommandError` should be used. Note that if the checks fail then
964
1030
:exc:`.CheckFailure` exception is raised to the :func:`.on_application_command_error`
965
1031
event.
1032
+ cooldown: Optional[:class:`~discord.ext.commands.Cooldown`]
1033
+ The cooldown applied when the command is invoked. ``None`` if the command
1034
+ doesn't have a cooldown.
1035
+
1036
+ .. versionadded:: 2.0
966
1037
"""
967
1038
def __new__ (cls , * args , ** kwargs ) -> ContextMenuCommand :
968
1039
self = super ().__new__ (cls )
@@ -971,6 +1042,7 @@ def __new__(cls, *args, **kwargs) -> ContextMenuCommand:
971
1042
return self
972
1043
973
1044
def __init__ (self , func : Callable , * args , ** kwargs ) -> None :
1045
+ super ().__init__ (func , ** kwargs )
974
1046
if not asyncio .iscoroutinefunction (func ):
975
1047
raise TypeError ("Callback must be a coroutine." )
976
1048
self .callback = func
0 commit comments