Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aikido_zen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Re-export functions :
from aikido_zen.context.users import set_user
from aikido_zen.middleware import should_block_request
from aikido_zen.middleware.set_rate_limit_group import set_rate_limit_group

# Import logger
from aikido_zen.helpers.logging import logger
Expand Down
1 change: 1 addition & 0 deletions aikido_zen/background_process/commands/should_ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ def process_should_ratelimit(connection_manager, data, queue=None):
remote_address=data["remote_address"],
user=data["user"],
connection_manager=connection_manager,
group=data["group"],
)
10 changes: 10 additions & 0 deletions aikido_zen/background_process/commands/should_ratelimit_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from tokenize import group

import pytest
from unittest.mock import Mock, patch
from .should_ratelimit import process_should_ratelimit
Expand All @@ -22,6 +24,7 @@ def test_process_should_ratelimit(should_ratelimit, expected_call):
},
"remote_address": "192.168.1.1",
"user": {"id": 1, "name": "Test User"},
"group": "123",
}

with patch(
Expand All @@ -37,6 +40,7 @@ def test_process_should_ratelimit(should_ratelimit, expected_call):
remote_address=data["remote_address"],
user=data["user"],
connection_manager=connection_manager,
group="123",
)


Expand All @@ -51,6 +55,7 @@ def test_process_should_ratelimit_no_connection_manager():
},
"remote_address": "192.168.1.1",
"user": {"id": 1, "name": "Test User"},
"group": None,
}

# Act
Expand All @@ -72,6 +77,7 @@ def test_process_should_ratelimit_multiple_calls():
},
"remote_address": "192.168.1.1",
"user": {"id": 1, "name": "Test User"},
"group": None,
}

with patch(
Expand All @@ -88,6 +94,7 @@ def test_process_should_ratelimit_multiple_calls():
remote_address=data["remote_address"],
user=data["user"],
connection_manager=connection_manager,
group=None,
)


Expand All @@ -104,6 +111,7 @@ def test_process_should_ratelimit_with_different_connection_manager():
},
"remote_address": "192.168.1.1",
"user": {"id": 1, "name": "Test User"},
"group": None,
}

with patch(
Expand All @@ -120,10 +128,12 @@ def test_process_should_ratelimit_with_different_connection_manager():
remote_address=data["remote_address"],
user=data["user"],
connection_manager=connection_manager1,
group=None,
)
mock_should_ratelimit.assert_any_call(
route_metadata=data["route_metadata"],
remote_address=data["remote_address"],
user=data["user"],
connection_manager=connection_manager2,
group=None,
)
2 changes: 2 additions & 0 deletions aikido_zen/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None):
# Define emtpy variables/Properties :
self.source = source
self.user = None
self.rate_limit_group = None
self.parsed_userinput = {}
self.xml = {}
self.outgoing_req_redirects = []
Expand Down Expand Up @@ -84,6 +85,7 @@ def __reduce__(self):
"route": self.route,
"subdomains": self.subdomains,
"user": self.user,
"rate_limit_group": self.rate_limit_group,
"xml": self.xml,
"outgoing_req_redirects": self.outgoing_req_redirects,
"executed_middleware": self.executed_middleware,
Expand Down
2 changes: 2 additions & 0 deletions aikido_zen/context/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_wsgi_context_1():
"route": "/hello",
"subdomains": [],
"user": None,
"rate_limit_group": None,
"remote_address": "198.51.100.23",
"parsed_userinput": {},
"xml": {},
Expand Down Expand Up @@ -95,6 +96,7 @@ def test_wsgi_context_2():
"route": "/hello",
"subdomains": [],
"user": None,
"rate_limit_group": None,
"remote_address": "198.51.100.23",
"parsed_userinput": {},
"xml": {},
Expand Down
4 changes: 4 additions & 0 deletions aikido_zen/middleware/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from aikido_zen.context import current_context, Context, get_current_context
from aikido_zen.thread.thread_cache import ThreadCache, get_cache
from . import should_block_request
from .. import set_rate_limit_group


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -36,6 +37,7 @@ def set_context(user=None, executed_middleware=False):
"source": "flask",
"route": "/posts/:id",
"user": user,
"rate_limit_group": None,
"executed_middleware": executed_middleware,
}
).set_as_current_context()
Expand Down Expand Up @@ -75,6 +77,7 @@ def test_with_context_with_cache():

def test_cache_comms_with_endpoints():
set_context(user={"id": "456"})
set_rate_limit_group("my_group")
thread_cache = get_cache()
thread_cache.config.blocked_uids = ["123"]
thread_cache.config.endpoints = [
Expand Down Expand Up @@ -145,6 +148,7 @@ def test_cache_comms_with_endpoints():
"url": "http://localhost:4000",
},
"user": {"id": "456"},
"group": "my_group",
"remote_address": "::1",
},
receive=True,
Expand Down
23 changes: 23 additions & 0 deletions aikido_zen/middleware/set_rate_limit_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from aikido_zen.context import get_current_context
from aikido_zen.helpers.logging import logger


def set_rate_limit_group(group_id: str):
if not group_id or not isinstance(group_id, str):
logger.warning("Group ID cannot be empty.")
return

context = get_current_context()
if not context:
logger.warning(
"set_rate_limit_group(...) was called without a context. Make sure to call set_rate_limit_group(...) within an HTTP request."
)
return

if context.executed_middleware:
logger.warning(
"set_rate_limit_group(...) must be called before the Zen middleware is executed."
)

context.rate_limit_group = group_id
context.set_as_current_context()
97 changes: 97 additions & 0 deletions aikido_zen/middleware/set_rate_limit_group_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pytest
from aikido_zen.context import get_current_context, Context
from aikido_zen.thread.thread_cache import get_cache
from .set_rate_limit_group import set_rate_limit_group


@pytest.fixture(autouse=True)
def run_around_tests():
get_cache().reset()
yield
# Reset context and cache after every test
from aikido_zen.context import current_context

current_context.set(None)
get_cache().reset()


def set_context_and_lifecycle():
wsgi_request = {
"REQUEST_METHOD": "GET",
"HTTP_HEADER_1": "header 1 value",
"HTTP_HEADER_2": "Header 2 value",
"RANDOM_VALUE": "Random value",
"HTTP_COOKIE": "sessionId=abc123xyz456;",
"wsgi.url_scheme": "http",
"HTTP_HOST": "localhost:8080",
"PATH_INFO": "/hello",
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
context = Context(
req=wsgi_request,
body=None,
source="flask",
)
context.set_as_current_context()
return context


def test_set_rate_limit_group_valid_group_id(caplog):
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group("group1")
assert context1.rate_limit_group == "group1"
assert "Group ID cannot be empty." not in caplog.text
assert "was called without a context" not in caplog.text
assert "must be called before the Zen middleware is executed" not in caplog.text


def test_set_rate_limit_group_empty_group_id(caplog):
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group("")
assert context1.rate_limit_group is None
assert "Group ID cannot be empty." in caplog.text


def test_set_rate_limit_group_none_group_id(caplog):
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group(None)
assert context1.rate_limit_group is None
assert "Group ID cannot be empty." in caplog.text


def test_set_rate_limit_group_no_context(caplog):
from aikido_zen.context import current_context

current_context.set(None)
set_rate_limit_group("group1")
assert "was called without a context" in caplog.text


def test_set_rate_limit_group_middleware_already_executed(caplog):
context1 = set_context_and_lifecycle()
context1.executed_middleware = True
set_rate_limit_group("group1")
assert "must be called before the Zen middleware is executed" in caplog.text
assert context1.rate_limit_group is "group1"


def test_set_rate_limit_group_non_string_group_id(caplog):
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group(123)
assert context1.rate_limit_group is None
assert "Group ID cannot be empty." in caplog.text


def test_set_rate_limit_group_overwrite_existing_group():
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group("group1")
assert context1.rate_limit_group == "group1"
set_rate_limit_group("group2")
assert context1.rate_limit_group == "group2"
21 changes: 14 additions & 7 deletions aikido_zen/middleware/should_block_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,28 @@
def should_block_request():
"""
Checks for rate-limiting and checks if the current user is blocked.
Returns: {
block: boolean,
data: {
type: "blocked" | "ratelimited",
trigger: "ip" | "user" | "group",
ip?: string
}
}
"""
try:
context = get_current_context()
cache = get_cache()
if not context or not cache:
return {"block": False}

context.executed_middleware = (
True # Update context with middleware execution set to true
)
context.set_as_current_context()

# Make sure we set middleware installed to true (reports back to core) :
# These indicators allow us to check in core whether the middleware is installed correctly,
# and to display a warning when a user or group is set after this has run.
cache.middleware_installed = True
context.executed_middleware = True
context.set_as_current_context()

# Blocked users:
# User blocking allows customers to easily take action when attacks are coming from specific accounts
if context.user and cache.is_user_blocked(context.user["id"]):
return {"block": True, "type": "blocked", "trigger": "user"}

Expand All @@ -46,6 +52,7 @@ def should_block_request():
obj={
"route_metadata": route_metadata,
"user": context.user,
"group": context.rate_limit_group,
"remote_address": context.remote_address,
},
receive=True,
Expand Down
50 changes: 36 additions & 14 deletions aikido_zen/ratelimiting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,69 @@
from .get_ratelimited_endpoint import get_ratelimited_endpoint


def should_ratelimit_request(route_metadata, remote_address, user, connection_manager):
def should_ratelimit_request(
route_metadata, remote_address, user, connection_manager, group=None
):
"""
Checks if the request should be ratelimited or not
Checks if the request should be rate-limited or not (checks user, group id & ip)
route_metadata object includes route, url and method
"""
endpoints = connection_manager.conf.get_endpoints(route_metadata)
endpoint = get_ratelimited_endpoint(endpoints, route_metadata["route"])
if not endpoint:
return {"block": False}

is_bypassed_ip = connection_manager.conf.is_bypassed_ip(remote_address)
if is_bypassed_ip:
return {"block": False}

max_requests = int(endpoint["rateLimiting"]["maxRequests"])
windows_size_in_ms = int(endpoint["rateLimiting"]["windowSizeInMS"])
is_bypassed_ip = connection_manager.conf.is_bypassed_ip(remote_address)

if is_bypassed_ip:
if group:
allowed = connection_manager.rate_limiter.is_allowed(
get_key_for_group(endpoint, group),
windows_size_in_ms,
max_requests,
)
if not allowed:
return {"block": True, "trigger": "group"}

# Do not check IP or user rate limit if group is set
return {"block": False}
if user:
uid = user["id"]
method = endpoint.get("method")
route = endpoint.get("route")

allowed = connection_manager.rate_limiter.is_allowed(
f"{method}:{route}:user:{uid}",
get_key_for_user(endpoint, user),
windows_size_in_ms,
max_requests,
)
if not allowed:
return {"block": True, "trigger": "user"}
# Do not check IP rate limit if user is set
return {"block": False}

if remote_address:
method = endpoint.get("method")
route = endpoint.get("route")

allowed = connection_manager.rate_limiter.is_allowed(
f"{method}:{route}:ip:{remote_address}",
get_key_for_ip(endpoint, remote_address),
windows_size_in_ms,
max_requests,
)
if not allowed:
return {"block": True, "trigger": "ip"}

return {"block": False}


def get_key_for_group(endpoint, group_id):
method, route = endpoint.get("method"), endpoint.get("route")
return f"{method}:{route}:group:{group_id}"


def get_key_for_user(endpoint, user):
method, route = endpoint.get("method"), endpoint.get("route")
user_id = user.get("id")
return f"{method}:{route}:user:{user_id}"


def get_key_for_ip(endpoint, remote_address):
method, route = endpoint.get("method"), endpoint.get("route")
return f"{method}:{route}:ip:{remote_address}"
Loading
Loading