Skip to content

Commit 32d932d

Browse files
VincentRPSplun1331
andauthored
refactor: port over v3 rate limit code (#1)
* refactor: port over v3 rate limit code * fix: route stuff * fix: AHHH * fix: append api url to path * fix: final batch * fix: oops put that in the wrong place * chore: take method into equality check --------- Signed-off-by: plun1331 <[email protected]> Co-authored-by: plun1331 <[email protected]>
1 parent b6512d2 commit 32d932d

File tree

3 files changed

+171
-171
lines changed

3 files changed

+171
-171
lines changed

discord/http.py

Lines changed: 165 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@
8787

8888
T = TypeVar("T")
8989
BE = TypeVar("BE", bound=BaseException)
90-
MU = TypeVar("MU", bound="MaybeUnlock")
9190
Response = Coroutine[Any, Any, T]
9291

9392
API_VERSION: int = 10
@@ -106,61 +105,92 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str
106105

107106

108107
class Route:
109-
API_BASE_URL: str = "https://discord.com/api/v{API_VERSION}"
110-
111-
def __init__(self, method: str, path: str, **parameters: Any) -> None:
112-
self.path: str = path
113-
self.method: str = method
114-
url = self.base + self.path
115-
if parameters:
116-
url = url.format_map(
117-
{
118-
k: _uriquote(v) if isinstance(v, str) else v
119-
for k, v in parameters.items()
120-
}
121-
)
122-
self.url: str = url
108+
def __init__(
109+
self,
110+
method: str,
111+
path: str,
112+
guild_id: str | None = None,
113+
channel_id: str | None = None,
114+
webhook_id: str | None = None,
115+
webhook_token: str | None = None,
116+
**parameters: str | int,
117+
):
118+
self.method = method
119+
self.path = path
123120

124-
# major parameters:
125-
self.channel_id: Snowflake | None = parameters.get("channel_id")
126-
self.guild_id: Snowflake | None = parameters.get("guild_id")
127-
self.webhook_id: Snowflake | None = parameters.get("webhook_id")
128-
self.webhook_token: str | None = parameters.get("webhook_token")
121+
# major parameters
122+
self.guild_id = guild_id
123+
self.channel_id = channel_id
124+
self.webhook_id = webhook_id
125+
self.webhook_token = webhook_token
129126

130-
@property
131-
def base(self) -> str:
132-
return self.API_BASE_URL.format(API_VERSION=API_VERSION)
127+
self.parameters = parameters
133128

134-
@property
135-
def bucket(self) -> str:
136-
# the bucket is just method + path w/ major parameters
137-
return f"{self.channel_id}:{self.guild_id}:{self.path}"
129+
def merge(self, url: str):
130+
return url + self.path.format(
131+
guild_id=self.guild_id,
132+
channel_id=self.channel_id,
133+
webhook_id=self.webhook_id,
134+
webhook_token=self.webhook_token,
135+
**self.parameters,
136+
)
138137

138+
def __eq__(self, route: 'Route') -> bool:
139+
return (
140+
route.channel_id == self.channel_id
141+
or route.guild_id == self.guild_id
142+
or route.webhook_id == self.webhook_id
143+
or route.webhook_token == self.webhook_token
144+
) and route.method == self.method
139145

140-
class MaybeUnlock:
141-
def __init__(self, lock: asyncio.Lock) -> None:
142-
self.lock: asyncio.Lock = lock
143-
self._unlock: bool = True
144146

145-
def __enter__(self: MU) -> MU:
146-
return self
147147

148-
def defer(self) -> None:
149-
self._unlock = False
148+
class Executor:
149+
def __init__(self, route: Route) -> None:
150+
self.route = route
151+
self.is_global: bool | None = None
152+
self._request_queue: asyncio.Queue[asyncio.Event] | None = None
153+
self.rate_limited: bool = False
150154

151-
def __exit__(
152-
self,
153-
exc_type: type[BE] | None,
154-
exc: BE | None,
155-
traceback: TracebackType | None,
155+
async def executed(
156+
self, reset_after: int | float, limit: int, is_global: bool
156157
) -> None:
157-
if self._unlock:
158-
self.lock.release()
158+
self.rate_limited = True
159+
self.is_global = is_global
160+
self._reset_after = reset_after
161+
self._request_queue = asyncio.Queue()
162+
163+
await asyncio.sleep(reset_after)
159164

165+
self.is_global = False
160166

161-
# For some reason, the Discord voice websocket expects this header to be
162-
# completely lowercase while aiohttp respects spec and does it as case-insensitive
163-
aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore
167+
# NOTE: This could break if someone did a second global rate limit somehow
168+
requests_passed: int = 0
169+
for _ in range(self._request_queue.qsize() - 1):
170+
if requests_passed == limit:
171+
requests_passed = 0
172+
if not is_global:
173+
await asyncio.sleep(reset_after)
174+
else:
175+
await asyncio.sleep(5)
176+
177+
requests_passed += 1
178+
e = await self._request_queue.get()
179+
e.set()
180+
181+
async def wait(self) -> None:
182+
if not self.rate_limited:
183+
return
184+
185+
event = asyncio.Event()
186+
187+
if self._request_queue:
188+
self._request_queue.put_nowait(event)
189+
else:
190+
raise ValueError(
191+
'Request queue does not exist, rate limit may have been solved.'
192+
)
193+
await event.wait()
164194

165195

166196
class HTTPClient:
@@ -174,20 +204,20 @@ def __init__(
174204
proxy_auth: aiohttp.BasicAuth | None = None,
175205
loop: asyncio.AbstractEventLoop | None = None,
176206
unsync_clock: bool = True,
207+
discord_api_url: str = "https://discord.com/api/v10"
177208
) -> None:
209+
self.api_url = discord_api_url
178210
self.loop: asyncio.AbstractEventLoop = (
179211
asyncio.get_event_loop() if loop is None else loop
180212
)
181213
self.connector = connector
182214
self.__session: aiohttp.ClientSession | utils.Undefined = MISSING # filled in static_login
183-
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
184-
self._global_over: asyncio.Event = asyncio.Event()
185-
self._global_over.set()
186215
self.token: str | None = None
187216
self.bot_token: bool = False
188217
self.proxy: str | None = proxy
189218
self.proxy_auth: aiohttp.BasicAuth | None = proxy_auth
190219
self.use_clock: bool = not unsync_clock
220+
self._executors: list[Executor] = []
191221

192222
user_agent = (
193223
"DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
@@ -226,15 +256,9 @@ async def request(
226256
form: Iterable[dict[str, Any]] | None = None,
227257
**kwargs: Any,
228258
) -> Any:
229-
bucket = route.bucket
259+
bucket = route.merge(self.api_url)
230260
method = route.method
231-
url = route.url
232-
233-
lock = self._locks.get(bucket)
234-
if lock is None:
235-
lock = asyncio.Lock()
236-
if bucket is not None:
237-
self._locks[bucket] = lock
261+
url = bucket
238262

239263
# header creation
240264
headers: dict[str, str] = {
@@ -266,123 +290,97 @@ async def request(
266290
if self.proxy_auth is not None:
267291
kwargs["proxy_auth"] = self.proxy_auth
268292

269-
if not self._global_over.is_set():
270-
# wait until the global lock is complete
271-
await self._global_over.wait()
272-
273293
response: aiohttp.ClientResponse | None = None
274294
data: dict[str, Any] | str | None = None
275-
await lock.acquire()
276-
with MaybeUnlock(lock) as maybe_lock:
277-
for tries in range(5):
278-
if files:
279-
for f in files:
280-
f.reset(seek=tries)
281-
282-
if form:
283-
form_data = aiohttp.FormData(quote_fields=False)
284-
for params in form:
285-
form_data.add_field(**params)
286-
kwargs["data"] = form_data
287-
288-
try:
289-
async with self.__session.request(
290-
method, url, **kwargs
291-
) as response:
292-
_log.debug(
293-
"%s %s with %s has returned %s",
294-
method,
295-
url,
296-
kwargs.get("data"),
297-
response.status,
295+
296+
for executor in self._executors:
297+
if executor.is_global or executor.route == route:
298+
_log.debug(f'Pausing request to {route}: Found rate limit executor')
299+
await executor.wait()
300+
301+
for tries in range(5):
302+
if files:
303+
for f in files:
304+
f.reset(seek=tries)
305+
306+
if form:
307+
form_data = aiohttp.FormData(quote_fields=False)
308+
for params in form:
309+
form_data.add_field(**params)
310+
kwargs["data"] = form_data
311+
312+
try:
313+
async with self.__session.request(
314+
method, url, **kwargs
315+
) as response:
316+
_log.debug(
317+
"%s %s with %s has returned %s",
318+
method,
319+
url,
320+
kwargs.get("data"),
321+
response.status,
322+
)
323+
324+
# even errors have text involved in them so this is safe to call
325+
data = await json_or_text(response)
326+
327+
# check if we have rate limit header information
328+
remaining = response.headers.get("X-Ratelimit-Remaining")
329+
if remaining == "0" and response.status != 429:
330+
_log.debug(f'Request to {route} failed: Request returned rate limit')
331+
executor = Executor(route=route)
332+
333+
self._executors.append(executor)
334+
await executor.executed(
335+
# NOTE: 5 is just a placeholder since this should always be present
336+
reset_after=float(response.headers.get('X-RateLimit-Reset-After', "5")),
337+
is_global=response.headers.get('X-RateLimit-Scope') == 'global',
338+
limit=int(response.headers.get('X-RateLimit-Limit', 10)),
298339
)
340+
self._executors.remove(executor)
341+
continue
299342

300-
# even errors have text involved in them so this is safe to call
301-
data = await json_or_text(response)
302-
303-
# check if we have rate limit header information
304-
remaining = response.headers.get("X-Ratelimit-Remaining")
305-
if remaining == "0" and response.status != 429:
306-
# we've depleted our current bucket
307-
delta = utils._parse_ratelimit_header(
308-
response, use_clock=self.use_clock
309-
)
310-
_log.debug(
311-
(
312-
"A rate limit bucket has been exhausted (bucket:"
313-
" %s, retry: %s)."
314-
),
315-
bucket,
316-
delta,
317-
)
318-
maybe_lock.defer()
319-
self.loop.call_later(delta, lock.release)
320-
321-
# the request was successful so just return the text/json
322-
if 300 > response.status >= 200:
323-
_log.debug("%s %s has received %s", method, url, data)
324-
return data
325-
326-
# we are being rate limited
327-
if response.status == 429:
328-
if not response.headers.get("Via") or isinstance(data, str):
329-
# Banned by Cloudflare more than likely.
330-
raise HTTPException(response, data)
331-
332-
fmt = (
333-
"We are being rate limited. Retrying in %.2f seconds."
334-
' Handled under the bucket "%s"'
335-
)
336-
337-
# sleep a bit
338-
retry_after: float = data["retry_after"]
339-
_log.warning(fmt, retry_after, bucket)
340-
341-
# check if it's a global rate limit
342-
is_global = data.get("global", False)
343-
if is_global:
344-
_log.warning(
345-
(
346-
"Global rate limit has been hit. Retrying in"
347-
" %.2f seconds."
348-
),
349-
retry_after,
350-
)
351-
self._global_over.clear()
352-
353-
await asyncio.sleep(retry_after)
354-
_log.debug("Done sleeping for the rate limit. Retrying...")
355-
356-
# release the global lock now that the
357-
# global rate limit has passed
358-
if is_global:
359-
self._global_over.set()
360-
_log.debug("Global rate limit is now over.")
361-
362-
continue
363-
364-
# we've received a 500, 502, 503, or 504, unconditional retry
365-
if response.status in {500, 502, 503, 504}:
366-
await asyncio.sleep(1 + tries * 2)
367-
continue
368-
369-
# the usual error cases
370-
if response.status == 403:
371-
raise Forbidden(response, data)
372-
elif response.status == 404:
373-
raise NotFound(response, data)
374-
elif response.status >= 500:
375-
raise DiscordServerError(response, data)
376-
else:
377-
raise HTTPException(response, data)
378-
379-
# This is handling exceptions from the request
380-
except OSError as e:
381-
# Connection reset by peer
382-
if tries < 4 and e.errno in (54, 10054):
343+
# the request was successful so just return the text/json
344+
if 300 > response.status >= 200:
345+
_log.debug("%s %s has received %s", method, url, data)
346+
return data
347+
348+
# we are being rate limited
349+
if response.status == 429:
350+
_log.debug(f'Request to {route} failed: Request returned rate limit')
351+
executor = Executor(route=route)
352+
353+
self._executors.append(executor)
354+
await executor.executed(
355+
reset_after=data['retry_after'],
356+
is_global=response.headers.get('X-RateLimit-Scope') == 'global',
357+
limit=int(response.headers.get('X-RateLimit-Limit', 10)),
358+
)
359+
self._executors.remove(executor)
360+
continue
361+
362+
# we've received a 500, 502, 503, or 504, unconditional retry
363+
if response.status in {500, 502, 503, 504}:
383364
await asyncio.sleep(1 + tries * 2)
384365
continue
385-
raise
366+
367+
# the usual error cases
368+
if response.status == 403:
369+
raise Forbidden(response, data)
370+
elif response.status == 404:
371+
raise NotFound(response, data)
372+
elif response.status >= 500:
373+
raise DiscordServerError(response, data)
374+
else:
375+
raise HTTPException(response, data)
376+
377+
# This is handling exceptions from the request
378+
except OSError as e:
379+
# Connection reset by peer
380+
if tries < 4 and e.errno in (54, 10054):
381+
await asyncio.sleep(1 + tries * 2)
382+
continue
383+
raise
386384

387385
if response is not None:
388386
# We've run out of retries, raise.

0 commit comments

Comments
 (0)