87
87
88
88
T = TypeVar ("T" )
89
89
BE = TypeVar ("BE" , bound = BaseException )
90
- MU = TypeVar ("MU" , bound = "MaybeUnlock" )
91
90
Response = Coroutine [Any , Any , T ]
92
91
93
92
API_VERSION : int = 10
@@ -106,61 +105,92 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str
106
105
107
106
108
107
class Route :
109
- API_BASE_URL : str = "https://discord.com/api/v{API_VERSION}"
110
-
111
- def __init__ (self , method : str , path : str , ** parameters : Any ) -> None :
112
- self .path : str = path
113
- self .method : str = method
114
- url = self .base + self .path
115
- if parameters :
116
- url = url .format_map (
117
- {
118
- k : _uriquote (v ) if isinstance (v , str ) else v
119
- for k , v in parameters .items ()
120
- }
121
- )
122
- self .url : str = url
108
+ def __init__ (
109
+ self ,
110
+ method : str ,
111
+ path : str ,
112
+ guild_id : str | None = None ,
113
+ channel_id : str | None = None ,
114
+ webhook_id : str | None = None ,
115
+ webhook_token : str | None = None ,
116
+ ** parameters : str | int ,
117
+ ):
118
+ self .method = method
119
+ self .path = path
123
120
124
- # major parameters:
125
- self .channel_id : Snowflake | None = parameters . get ( "channel_id" )
126
- self .guild_id : Snowflake | None = parameters . get ( "guild_id" )
127
- self .webhook_id : Snowflake | None = parameters . get ( " webhook_id" )
128
- self .webhook_token : str | None = parameters . get ( " webhook_token" )
121
+ # major parameters
122
+ self .guild_id = guild_id
123
+ self .channel_id = channel_id
124
+ self .webhook_id = webhook_id
125
+ self .webhook_token = webhook_token
129
126
130
- @property
131
- def base (self ) -> str :
132
- return self .API_BASE_URL .format (API_VERSION = API_VERSION )
127
+ self .parameters = parameters
133
128
134
- @property
135
- def bucket (self ) -> str :
136
- # the bucket is just method + path w/ major parameters
137
- return f"{ self .channel_id } :{ self .guild_id } :{ self .path } "
129
+ def merge (self , url : str ):
130
+ return url + self .path .format (
131
+ guild_id = self .guild_id ,
132
+ channel_id = self .channel_id ,
133
+ webhook_id = self .webhook_id ,
134
+ webhook_token = self .webhook_token ,
135
+ ** self .parameters ,
136
+ )
138
137
138
+ def __eq__ (self , route : 'Route' ) -> bool :
139
+ return (
140
+ route .channel_id == self .channel_id
141
+ or route .guild_id == self .guild_id
142
+ or route .webhook_id == self .webhook_id
143
+ or route .webhook_token == self .webhook_token
144
+ ) and route .method == self .method
139
145
140
- class MaybeUnlock :
141
- def __init__ (self , lock : asyncio .Lock ) -> None :
142
- self .lock : asyncio .Lock = lock
143
- self ._unlock : bool = True
144
146
145
- def __enter__ (self : MU ) -> MU :
146
- return self
147
147
148
- def defer (self ) -> None :
149
- self ._unlock = False
148
+ class Executor :
149
+ def __init__ (self , route : Route ) -> None :
150
+ self .route = route
151
+ self .is_global : bool | None = None
152
+ self ._request_queue : asyncio .Queue [asyncio .Event ] | None = None
153
+ self .rate_limited : bool = False
150
154
151
- def __exit__ (
152
- self ,
153
- exc_type : type [BE ] | None ,
154
- exc : BE | None ,
155
- traceback : TracebackType | None ,
155
+ async def executed (
156
+ self , reset_after : int | float , limit : int , is_global : bool
156
157
) -> None :
157
- if self ._unlock :
158
- self .lock .release ()
158
+ self .rate_limited = True
159
+ self .is_global = is_global
160
+ self ._reset_after = reset_after
161
+ self ._request_queue = asyncio .Queue ()
162
+
163
+ await asyncio .sleep (reset_after )
159
164
165
+ self .is_global = False
160
166
161
- # For some reason, the Discord voice websocket expects this header to be
162
- # completely lowercase while aiohttp respects spec and does it as case-insensitive
163
- aiohttp .hdrs .WEBSOCKET = "websocket" # type: ignore
167
+ # NOTE: This could break if someone did a second global rate limit somehow
168
+ requests_passed : int = 0
169
+ for _ in range (self ._request_queue .qsize () - 1 ):
170
+ if requests_passed == limit :
171
+ requests_passed = 0
172
+ if not is_global :
173
+ await asyncio .sleep (reset_after )
174
+ else :
175
+ await asyncio .sleep (5 )
176
+
177
+ requests_passed += 1
178
+ e = await self ._request_queue .get ()
179
+ e .set ()
180
+
181
+ async def wait (self ) -> None :
182
+ if not self .rate_limited :
183
+ return
184
+
185
+ event = asyncio .Event ()
186
+
187
+ if self ._request_queue :
188
+ self ._request_queue .put_nowait (event )
189
+ else :
190
+ raise ValueError (
191
+ 'Request queue does not exist, rate limit may have been solved.'
192
+ )
193
+ await event .wait ()
164
194
165
195
166
196
class HTTPClient :
@@ -174,20 +204,20 @@ def __init__(
174
204
proxy_auth : aiohttp .BasicAuth | None = None ,
175
205
loop : asyncio .AbstractEventLoop | None = None ,
176
206
unsync_clock : bool = True ,
207
+ discord_api_url : str = "https://discord.com/api/v10"
177
208
) -> None :
209
+ self .api_url = discord_api_url
178
210
self .loop : asyncio .AbstractEventLoop = (
179
211
asyncio .get_event_loop () if loop is None else loop
180
212
)
181
213
self .connector = connector
182
214
self .__session : aiohttp .ClientSession | utils .Undefined = MISSING # filled in static_login
183
- self ._locks : weakref .WeakValueDictionary = weakref .WeakValueDictionary ()
184
- self ._global_over : asyncio .Event = asyncio .Event ()
185
- self ._global_over .set ()
186
215
self .token : str | None = None
187
216
self .bot_token : bool = False
188
217
self .proxy : str | None = proxy
189
218
self .proxy_auth : aiohttp .BasicAuth | None = proxy_auth
190
219
self .use_clock : bool = not unsync_clock
220
+ self ._executors : list [Executor ] = []
191
221
192
222
user_agent = (
193
223
"DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
@@ -226,15 +256,9 @@ async def request(
226
256
form : Iterable [dict [str , Any ]] | None = None ,
227
257
** kwargs : Any ,
228
258
) -> Any :
229
- bucket = route .bucket
259
+ bucket = route .merge ( self . api_url )
230
260
method = route .method
231
- url = route .url
232
-
233
- lock = self ._locks .get (bucket )
234
- if lock is None :
235
- lock = asyncio .Lock ()
236
- if bucket is not None :
237
- self ._locks [bucket ] = lock
261
+ url = bucket
238
262
239
263
# header creation
240
264
headers : dict [str , str ] = {
@@ -266,123 +290,97 @@ async def request(
266
290
if self .proxy_auth is not None :
267
291
kwargs ["proxy_auth" ] = self .proxy_auth
268
292
269
- if not self ._global_over .is_set ():
270
- # wait until the global lock is complete
271
- await self ._global_over .wait ()
272
-
273
293
response : aiohttp .ClientResponse | None = None
274
294
data : dict [str , Any ] | str | None = None
275
- await lock .acquire ()
276
- with MaybeUnlock (lock ) as maybe_lock :
277
- for tries in range (5 ):
278
- if files :
279
- for f in files :
280
- f .reset (seek = tries )
281
-
282
- if form :
283
- form_data = aiohttp .FormData (quote_fields = False )
284
- for params in form :
285
- form_data .add_field (** params )
286
- kwargs ["data" ] = form_data
287
-
288
- try :
289
- async with self .__session .request (
290
- method , url , ** kwargs
291
- ) as response :
292
- _log .debug (
293
- "%s %s with %s has returned %s" ,
294
- method ,
295
- url ,
296
- kwargs .get ("data" ),
297
- response .status ,
295
+
296
+ for executor in self ._executors :
297
+ if executor .is_global or executor .route == route :
298
+ _log .debug (f'Pausing request to { route } : Found rate limit executor' )
299
+ await executor .wait ()
300
+
301
+ for tries in range (5 ):
302
+ if files :
303
+ for f in files :
304
+ f .reset (seek = tries )
305
+
306
+ if form :
307
+ form_data = aiohttp .FormData (quote_fields = False )
308
+ for params in form :
309
+ form_data .add_field (** params )
310
+ kwargs ["data" ] = form_data
311
+
312
+ try :
313
+ async with self .__session .request (
314
+ method , url , ** kwargs
315
+ ) as response :
316
+ _log .debug (
317
+ "%s %s with %s has returned %s" ,
318
+ method ,
319
+ url ,
320
+ kwargs .get ("data" ),
321
+ response .status ,
322
+ )
323
+
324
+ # even errors have text involved in them so this is safe to call
325
+ data = await json_or_text (response )
326
+
327
+ # check if we have rate limit header information
328
+ remaining = response .headers .get ("X-Ratelimit-Remaining" )
329
+ if remaining == "0" and response .status != 429 :
330
+ _log .debug (f'Request to { route } failed: Request returned rate limit' )
331
+ executor = Executor (route = route )
332
+
333
+ self ._executors .append (executor )
334
+ await executor .executed (
335
+ # NOTE: 5 is just a placeholder since this should always be present
336
+ reset_after = float (response .headers .get ('X-RateLimit-Reset-After' , "5" )),
337
+ is_global = response .headers .get ('X-RateLimit-Scope' ) == 'global' ,
338
+ limit = int (response .headers .get ('X-RateLimit-Limit' , 10 )),
298
339
)
340
+ self ._executors .remove (executor )
341
+ continue
299
342
300
- # even errors have text involved in them so this is safe to call
301
- data = await json_or_text (response )
302
-
303
- # check if we have rate limit header information
304
- remaining = response .headers .get ("X-Ratelimit-Remaining" )
305
- if remaining == "0" and response .status != 429 :
306
- # we've depleted our current bucket
307
- delta = utils ._parse_ratelimit_header (
308
- response , use_clock = self .use_clock
309
- )
310
- _log .debug (
311
- (
312
- "A rate limit bucket has been exhausted (bucket:"
313
- " %s, retry: %s)."
314
- ),
315
- bucket ,
316
- delta ,
317
- )
318
- maybe_lock .defer ()
319
- self .loop .call_later (delta , lock .release )
320
-
321
- # the request was successful so just return the text/json
322
- if 300 > response .status >= 200 :
323
- _log .debug ("%s %s has received %s" , method , url , data )
324
- return data
325
-
326
- # we are being rate limited
327
- if response .status == 429 :
328
- if not response .headers .get ("Via" ) or isinstance (data , str ):
329
- # Banned by Cloudflare more than likely.
330
- raise HTTPException (response , data )
331
-
332
- fmt = (
333
- "We are being rate limited. Retrying in %.2f seconds."
334
- ' Handled under the bucket "%s"'
335
- )
336
-
337
- # sleep a bit
338
- retry_after : float = data ["retry_after" ]
339
- _log .warning (fmt , retry_after , bucket )
340
-
341
- # check if it's a global rate limit
342
- is_global = data .get ("global" , False )
343
- if is_global :
344
- _log .warning (
345
- (
346
- "Global rate limit has been hit. Retrying in"
347
- " %.2f seconds."
348
- ),
349
- retry_after ,
350
- )
351
- self ._global_over .clear ()
352
-
353
- await asyncio .sleep (retry_after )
354
- _log .debug ("Done sleeping for the rate limit. Retrying..." )
355
-
356
- # release the global lock now that the
357
- # global rate limit has passed
358
- if is_global :
359
- self ._global_over .set ()
360
- _log .debug ("Global rate limit is now over." )
361
-
362
- continue
363
-
364
- # we've received a 500, 502, 503, or 504, unconditional retry
365
- if response .status in {500 , 502 , 503 , 504 }:
366
- await asyncio .sleep (1 + tries * 2 )
367
- continue
368
-
369
- # the usual error cases
370
- if response .status == 403 :
371
- raise Forbidden (response , data )
372
- elif response .status == 404 :
373
- raise NotFound (response , data )
374
- elif response .status >= 500 :
375
- raise DiscordServerError (response , data )
376
- else :
377
- raise HTTPException (response , data )
378
-
379
- # This is handling exceptions from the request
380
- except OSError as e :
381
- # Connection reset by peer
382
- if tries < 4 and e .errno in (54 , 10054 ):
343
+ # the request was successful so just return the text/json
344
+ if 300 > response .status >= 200 :
345
+ _log .debug ("%s %s has received %s" , method , url , data )
346
+ return data
347
+
348
+ # we are being rate limited
349
+ if response .status == 429 :
350
+ _log .debug (f'Request to { route } failed: Request returned rate limit' )
351
+ executor = Executor (route = route )
352
+
353
+ self ._executors .append (executor )
354
+ await executor .executed (
355
+ reset_after = data ['retry_after' ],
356
+ is_global = response .headers .get ('X-RateLimit-Scope' ) == 'global' ,
357
+ limit = int (response .headers .get ('X-RateLimit-Limit' , 10 )),
358
+ )
359
+ self ._executors .remove (executor )
360
+ continue
361
+
362
+ # we've received a 500, 502, 503, or 504, unconditional retry
363
+ if response .status in {500 , 502 , 503 , 504 }:
383
364
await asyncio .sleep (1 + tries * 2 )
384
365
continue
385
- raise
366
+
367
+ # the usual error cases
368
+ if response .status == 403 :
369
+ raise Forbidden (response , data )
370
+ elif response .status == 404 :
371
+ raise NotFound (response , data )
372
+ elif response .status >= 500 :
373
+ raise DiscordServerError (response , data )
374
+ else :
375
+ raise HTTPException (response , data )
376
+
377
+ # This is handling exceptions from the request
378
+ except OSError as e :
379
+ # Connection reset by peer
380
+ if tries < 4 and e .errno in (54 , 10054 ):
381
+ await asyncio .sleep (1 + tries * 2 )
382
+ continue
383
+ raise
386
384
387
385
if response is not None :
388
386
# We've run out of retries, raise.
0 commit comments