Skip to content

Commit 037c996

Browse files
committed
SHOULD_RATELIMIT: pass along group & check group for rlm
1 parent d3b89cf commit 037c996

File tree

4 files changed

+39
-14
lines changed

4 files changed

+39
-14
lines changed

aikido_zen/background_process/commands/should_ratelimit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ def process_should_ratelimit(connection_manager, data, queue=None):
1616
remote_address=data["remote_address"],
1717
user=data["user"],
1818
connection_manager=connection_manager,
19+
group=data["group"],
1920
)

aikido_zen/middleware/init_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def test_cache_comms_with_endpoints():
145145
"url": "http://localhost:4000",
146146
},
147147
"user": {"id": "456"},
148+
"group": None,
148149
"remote_address": "::1",
149150
},
150151
receive=True,

aikido_zen/middleware/should_block_request.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def should_block_request():
5252
obj={
5353
"route_metadata": route_metadata,
5454
"user": context.user,
55+
"group": context.rate_limit_group,
5556
"remote_address": context.remote_address,
5657
},
5758
receive=True,

aikido_zen/ratelimiting/__init__.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,69 @@
55
from .get_ratelimited_endpoint import get_ratelimited_endpoint
66

77

8-
def should_ratelimit_request(route_metadata, remote_address, user, connection_manager):
8+
def should_ratelimit_request(
9+
route_metadata, remote_address, user, connection_manager, group=None
10+
):
911
"""
10-
Checks if the request should be ratelimited or not
12+
Checks if the request should be rate-limited or not (checks user, group id & ip)
1113
route_metadata object includes route, url and method
1214
"""
1315
endpoints = connection_manager.conf.get_endpoints(route_metadata)
1416
endpoint = get_ratelimited_endpoint(endpoints, route_metadata["route"])
1517
if not endpoint:
1618
return {"block": False}
1719

20+
is_bypassed_ip = connection_manager.conf.is_bypassed_ip(remote_address)
21+
if is_bypassed_ip:
22+
return {"block": False}
23+
1824
max_requests = int(endpoint["rateLimiting"]["maxRequests"])
1925
windows_size_in_ms = int(endpoint["rateLimiting"]["windowSizeInMS"])
20-
is_bypassed_ip = connection_manager.conf.is_bypassed_ip(remote_address)
2126

22-
if is_bypassed_ip:
27+
if group:
28+
allowed = connection_manager.rate_limiter.is_allowed(
29+
get_key_for_group(endpoint, group),
30+
windows_size_in_ms,
31+
max_requests,
32+
)
33+
if not allowed:
34+
return {"block": True, "trigger": "group"}
35+
36+
# Do not check IP or user rate limit if group is set
2337
return {"block": False}
2438
if user:
25-
uid = user["id"]
26-
method = endpoint.get("method")
27-
route = endpoint.get("route")
28-
2939
allowed = connection_manager.rate_limiter.is_allowed(
30-
f"{method}:{route}:user:{uid}",
40+
get_key_for_user(endpoint, user),
3141
windows_size_in_ms,
3242
max_requests,
3343
)
3444
if not allowed:
3545
return {"block": True, "trigger": "user"}
3646
# Do not check IP rate limit if user is set
3747
return {"block": False}
38-
3948
if remote_address:
40-
method = endpoint.get("method")
41-
route = endpoint.get("route")
42-
4349
allowed = connection_manager.rate_limiter.is_allowed(
44-
f"{method}:{route}:ip:{remote_address}",
50+
get_key_for_ip(endpoint, remote_address),
4551
windows_size_in_ms,
4652
max_requests,
4753
)
4854
if not allowed:
4955
return {"block": True, "trigger": "ip"}
5056

5157
return {"block": False}
58+
59+
60+
def get_key_for_group(endpoint, group_id):
61+
method, route = endpoint.get("method"), endpoint.get("route")
62+
return f"{method}:{route}:group:{group_id}"
63+
64+
65+
def get_key_for_user(endpoint, user):
66+
method, route = endpoint.get("method"), endpoint.get("route")
67+
user_id = user.get("id")
68+
return f"{method}:{route}:user:{user_id}"
69+
70+
71+
def get_key_for_ip(endpoint, remote_address):
72+
method, route = endpoint.get("method"), endpoint.get("route")
73+
return f"{method}:{route}:ip:{remote_address}"

0 commit comments

Comments
 (0)