Skip to content

Commit fef0500

Browse files
authored
🐛 Fixes Retry-After header in global_rate_limit_route (ITISFoundation#3379)
1 parent 4fc72ac commit fef0500

File tree

2 files changed

+83
-42
lines changed

2 files changed

+83
-42
lines changed
Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
import json
2+
from dataclasses import dataclass
23
from datetime import datetime, timedelta
34
from functools import wraps
5+
from math import ceil
6+
from typing import NamedTuple
47

5-
import attr
68
from aiohttp.web_exceptions import HTTPTooManyRequests
79

810

9-
def global_rate_limit_route(number_of_requests: int, interval_seconds: int):
11+
class RateLimitSetup(NamedTuple):
12+
number_of_requests: int
13+
interval_seconds: float
14+
15+
16+
def global_rate_limit_route(number_of_requests: int, interval_seconds: float):
1017
"""
1118
Limits the requests per given interval to this endpoint
1219
from all incoming sources.
@@ -19,14 +26,14 @@ def global_rate_limit_route(number_of_requests: int, interval_seconds: int):
1926
"""
2027

2128
# compute the amount of requests per
22-
def decorating_function(decorated_function):
23-
@attr.s(auto_attribs=True)
24-
class Context:
29+
def decorator(decorated_function):
30+
@dataclass
31+
class _Context:
2532
max_allowed: int # maximum allowed requests per interval
2633
remaining: int # remaining requests
27-
rate_limit_reset: int # utc timestamp
34+
rate_limit_reset: float # utc timestamp
2835

29-
context = Context(
36+
context = _Context(
3037
max_allowed=number_of_requests,
3138
remaining=number_of_requests,
3239
rate_limit_reset=0,
@@ -35,24 +42,24 @@ class Context:
3542
@wraps(decorated_function)
3643
async def wrapper(*args, **kwargs):
3744
utc_now = datetime.utcnow()
38-
current_utc_timestamp = datetime.timestamp(utc_now)
45+
utc_now_timestamp = datetime.timestamp(utc_now)
3946

4047
# reset counter & first time initialization
41-
if current_utc_timestamp >= context.rate_limit_reset:
48+
if utc_now_timestamp >= context.rate_limit_reset:
4249
context.rate_limit_reset = datetime.timestamp(
4350
utc_now + timedelta(seconds=interval_seconds)
4451
)
4552
context.remaining = context.max_allowed
4653

47-
if (
48-
current_utc_timestamp <= context.rate_limit_reset
49-
and context.remaining <= 0
50-
):
51-
# show error and return from here
54+
if utc_now_timestamp <= context.rate_limit_reset and context.remaining <= 0:
55+
# SEE https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/429
56+
retry_after_sec = int(
57+
ceil(context.rate_limit_reset - utc_now_timestamp)
58+
)
5259
raise HTTPTooManyRequests(
5360
headers={
5461
"Content-Type": "application/json",
55-
"Retry-After": str(int(context.rate_limit_reset)),
62+
"Retry-After": f"{retry_after_sec}",
5663
},
5764
text=json.dumps(
5865
{
@@ -68,7 +75,7 @@ async def wrapper(*args, **kwargs):
6875
context.remaining -= 1
6976
return await decorated_function(*args, **kwargs)
7077

71-
wrapper.rate_limit_setup = (number_of_requests, interval_seconds)
78+
wrapper.rate_limit_setup = RateLimitSetup(number_of_requests, interval_seconds)
7279
return wrapper
7380

74-
return decorating_function
81+
return decorator

services/web/server/tests/unit/isolated/test_utils_rate_limiting.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44

55
import asyncio
66
import time
7+
from typing import Callable
78

89
import pytest
910
from aiohttp import web
10-
from aiohttp.web_exceptions import HTTPTooManyRequests
11+
from aiohttp.test_utils import TestClient
12+
from aiohttp.web_exceptions import HTTPOk, HTTPTooManyRequests
13+
from pydantic import ValidationError, conint, parse_obj_as
1114
from simcore_service_webserver.utils_rate_limiting import global_rate_limit_route
1215

16+
TOTAL_TEST_TIME = 1 # secs
1317
MAX_NUM_REQUESTS = 3
1418
MEASURE_INTERVAL = 0.5
1519
MAX_REQUEST_RATE = MAX_NUM_REQUESTS / MEASURE_INTERVAL
@@ -22,35 +26,43 @@ async def get_ok_handler(_request: web.Request):
2226
return web.json_response({"value": 1})
2327

2428

25-
@pytest.mark.parametrize(
26-
"requests_per_second",
27-
[0.5 * MAX_REQUEST_RATE, MAX_REQUEST_RATE, 2 * MAX_REQUEST_RATE],
28-
)
29-
async def test_global_rate_limit_route(requests_per_second, aiohttp_client):
30-
#
29+
@pytest.fixture
30+
def client(event_loop, aiohttp_client: Callable) -> TestClient:
3131
app = web.Application()
3232
app.router.add_get("/", get_ok_handler)
3333

34-
client = await aiohttp_client(app)
35-
# ---
34+
return event_loop.run_until_complete(aiohttp_client(app))
3635

36+
37+
def test_rate_limit_route_decorator():
3738
# decorated function keeps setup
3839
assert get_ok_handler.rate_limit_setup == (MAX_NUM_REQUESTS, MEASURE_INTERVAL)
3940

41+
42+
@pytest.mark.parametrize(
43+
"requests_per_second",
44+
[0.5 * MAX_REQUEST_RATE, MAX_REQUEST_RATE, 2 * MAX_REQUEST_RATE],
45+
)
46+
async def test_global_rate_limit_route(requests_per_second: float, client: TestClient):
47+
# WARNING: this test has some timings and might fail when using breakpoints
48+
4049
# Creates desired stream of requests for 1 second
41-
TOTAL_TEST_TIME = 1 # secs
4250
num_requests = int(requests_per_second * TOTAL_TEST_TIME)
4351
time_between_requests = 1.0 / requests_per_second
4452

45-
futures = []
53+
tasks = []
4654
t0 = time.time()
47-
while len(futures) < num_requests:
55+
while len(tasks) < num_requests:
4856
t1 = time.time()
49-
futures.append(asyncio.create_task(client.get("/")))
50-
time.sleep(time_between_requests - (time.time() - t1))
57+
tasks.append(asyncio.create_task(client.get("/")))
58+
elapsed_on_creation = time.time() - t1 # ANE is really precise here ;-)
59+
60+
# NOTE: I am not sure why using asyncio.sleep here would make some tests fail the check "after"
61+
# await asyncio.sleep(time_between_requests - create_gap)
62+
time.sleep(time_between_requests - elapsed_on_creation)
5163

5264
elapsed = time.time() - t0
53-
count = len(futures)
65+
count = len(tasks)
5466
print(
5567
count,
5668
"requests in",
@@ -63,20 +75,42 @@ async def test_global_rate_limit_route(requests_per_second, aiohttp_client):
6375
assert count == num_requests
6476
assert elapsed == pytest.approx(TOTAL_TEST_TIME, abs=0.1)
6577

66-
for i, fut in enumerate(futures):
67-
while not fut.done():
68-
await asyncio.sleep(0.1)
69-
assert not fut.cancelled()
70-
assert not fut.exception()
71-
print("%2d" % i, fut.result().status)
72-
73-
expected_status = 200
78+
msg = []
79+
for i, task in enumerate(tasks):
80+
while not task.done():
81+
await asyncio.sleep(0.01)
82+
assert not task.cancelled()
83+
assert not task.exception()
84+
msg.append(
85+
(
86+
"request # %2d" % i,
87+
f"status={task.result().status}",
88+
f"retry-after={task.result().headers.get('Retry-After')}",
89+
)
90+
)
91+
print(*msg[-1])
92+
93+
expected_status = HTTPOk.status_code
7494

7595
# first requests are OK
76-
assert all(f.result().status == expected_status for f in futures[:MAX_NUM_REQUESTS])
96+
assert all(
97+
t.result().status == expected_status for t in tasks[:MAX_NUM_REQUESTS]
98+
), f" Failed with { msg[:MAX_NUM_REQUESTS]}"
7799

78100
if requests_per_second >= MAX_REQUEST_RATE:
79101
expected_status = HTTPTooManyRequests.status_code
80102

81103
# after ...
82-
assert all(f.result().status == expected_status for f in futures[MAX_NUM_REQUESTS:])
104+
assert all(
105+
t.result().status == expected_status for t in tasks[MAX_NUM_REQUESTS:]
106+
), f" Failed with { msg[MAX_NUM_REQUESTS:]}"
107+
108+
# checks Retry-After header
109+
failed = []
110+
for t in tasks:
111+
if retry_after := t.result().headers.get("Retry-After"):
112+
try:
113+
parse_obj_as(conint(ge=1), retry_after)
114+
except ValidationError as err:
115+
failed.append((retry_after, f"{err}"))
116+
assert not failed

0 commit comments

Comments
 (0)