Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aikido_zen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Re-export functions :
from aikido_zen.context.users import set_user
from aikido_zen.middleware import should_block_request
from aikido_zen.middleware.set_rate_limit_group import set_rate_limit_group

# Import logger
from aikido_zen.helpers.logging import logger
Expand Down
1 change: 1 addition & 0 deletions aikido_zen/background_process/commands/should_ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ def process_should_ratelimit(connection_manager, data, queue=None):
remote_address=data["remote_address"],
user=data["user"],
connection_manager=connection_manager,
group=data["group"],
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_process_should_ratelimit(should_ratelimit, expected_call):
},
"remote_address": "192.168.1.1",
"user": {"id": 1, "name": "Test User"},
"group": "123",
}

with patch(
Expand All @@ -37,6 +38,7 @@ def test_process_should_ratelimit(should_ratelimit, expected_call):
remote_address=data["remote_address"],
user=data["user"],
connection_manager=connection_manager,
group="123",
)


Expand All @@ -51,6 +53,7 @@ def test_process_should_ratelimit_no_connection_manager():
},
"remote_address": "192.168.1.1",
"user": {"id": 1, "name": "Test User"},
"group": None,
}

# Act
Expand All @@ -72,6 +75,7 @@ def test_process_should_ratelimit_multiple_calls():
},
"remote_address": "192.168.1.1",
"user": {"id": 1, "name": "Test User"},
"group": None,
}

with patch(
Expand All @@ -88,6 +92,7 @@ def test_process_should_ratelimit_multiple_calls():
remote_address=data["remote_address"],
user=data["user"],
connection_manager=connection_manager,
group=None,
)


Expand All @@ -104,6 +109,7 @@ def test_process_should_ratelimit_with_different_connection_manager():
},
"remote_address": "192.168.1.1",
"user": {"id": 1, "name": "Test User"},
"group": None,
}

with patch(
Expand All @@ -120,10 +126,12 @@ def test_process_should_ratelimit_with_different_connection_manager():
remote_address=data["remote_address"],
user=data["user"],
connection_manager=connection_manager1,
group=None,
)
mock_should_ratelimit.assert_any_call(
route_metadata=data["route_metadata"],
remote_address=data["remote_address"],
user=data["user"],
connection_manager=connection_manager2,
group=None,
)
2 changes: 2 additions & 0 deletions aikido_zen/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None):
# Define emtpy variables/Properties :
self.source = source
self.user = None
self.rate_limit_group = None
self.parsed_userinput = {}
self.xml = {}
self.outgoing_req_redirects = []
Expand Down Expand Up @@ -84,6 +85,7 @@ def __reduce__(self):
"route": self.route,
"subdomains": self.subdomains,
"user": self.user,
"rate_limit_group": self.rate_limit_group,
"xml": self.xml,
"outgoing_req_redirects": self.outgoing_req_redirects,
"executed_middleware": self.executed_middleware,
Expand Down
2 changes: 2 additions & 0 deletions aikido_zen/context/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_wsgi_context_1():
"route": "/hello",
"subdomains": [],
"user": None,
"rate_limit_group": None,
"remote_address": "198.51.100.23",
"parsed_userinput": {},
"xml": {},
Expand Down Expand Up @@ -95,6 +96,7 @@ def test_wsgi_context_2():
"route": "/hello",
"subdomains": [],
"user": None,
"rate_limit_group": None,
"remote_address": "198.51.100.23",
"parsed_userinput": {},
"xml": {},
Expand Down
4 changes: 4 additions & 0 deletions aikido_zen/middleware/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from aikido_zen.context import current_context, Context, get_current_context
from aikido_zen.thread.thread_cache import ThreadCache, get_cache
from . import should_block_request
from .. import set_rate_limit_group


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -36,6 +37,7 @@ def set_context(user=None, executed_middleware=False):
"source": "flask",
"route": "/posts/:id",
"user": user,
"rate_limit_group": None,
"executed_middleware": executed_middleware,
}
).set_as_current_context()
Expand Down Expand Up @@ -75,6 +77,7 @@ def test_with_context_with_cache():

def test_cache_comms_with_endpoints():
set_context(user={"id": "456"})
set_rate_limit_group("my_group")
thread_cache = get_cache()
thread_cache.config.blocked_uids = ["123"]
thread_cache.config.endpoints = [
Expand Down Expand Up @@ -145,6 +148,7 @@ def test_cache_comms_with_endpoints():
"url": "http://localhost:4000",
},
"user": {"id": "456"},
"group": "my_group",
"remote_address": "::1",
},
receive=True,
Expand Down
31 changes: 31 additions & 0 deletions aikido_zen/middleware/set_rate_limit_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Union

from aikido_zen.context import get_current_context
from aikido_zen.helpers.logging import logger


def set_rate_limit_group(group_id: Union[str, int]):
if not group_id:
logger.warning("Group ID cannot be empty.")
return

# Check if it's string of number, ensure string.
if not isinstance(group_id, str) and not isinstance(group_id, int):
logger.warning("Group ID must be a string or a number")
return
group_id = str(group_id)

context = get_current_context()
if not context:
logger.warning(
"set_rate_limit_group(...) was called without a context. Make sure to call set_rate_limit_group(...) within an HTTP request."
)
return

if context.executed_middleware:
logger.warning(
"set_rate_limit_group(...) must be called before the Zen middleware is executed."
)

context.rate_limit_group = group_id
context.set_as_current_context()
104 changes: 104 additions & 0 deletions aikido_zen/middleware/set_rate_limit_group_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
from aikido_zen.context import get_current_context, Context
from aikido_zen.thread.thread_cache import get_cache
from .set_rate_limit_group import set_rate_limit_group


@pytest.fixture(autouse=True)
def run_around_tests():
get_cache().reset()
yield
# Reset context and cache after every test
from aikido_zen.context import current_context

current_context.set(None)
get_cache().reset()


def set_context_and_lifecycle():
wsgi_request = {
"REQUEST_METHOD": "GET",
"HTTP_HEADER_1": "header 1 value",
"HTTP_HEADER_2": "Header 2 value",
"RANDOM_VALUE": "Random value",
"HTTP_COOKIE": "sessionId=abc123xyz456;",
"wsgi.url_scheme": "http",
"HTTP_HOST": "localhost:8080",
"PATH_INFO": "/hello",
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
context = Context(
req=wsgi_request,
body=None,
source="flask",
)
context.set_as_current_context()
return context


def test_set_rate_limit_group_valid_group_id(caplog):
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group("group1")
assert context1.rate_limit_group == "group1"
assert "Group ID cannot be empty." not in caplog.text
assert "was called without a context" not in caplog.text
assert "must be called before the Zen middleware is executed" not in caplog.text


def test_set_rate_limit_group_empty_group_id(caplog):
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group("")
assert context1.rate_limit_group is None
assert "Group ID cannot be empty." in caplog.text


def test_set_rate_limit_group_none_group_id(caplog):
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group(None)
assert context1.rate_limit_group is None
assert "Group ID cannot be empty." in caplog.text


def test_set_rate_limit_group_no_context(caplog):
from aikido_zen.context import current_context

current_context.set(None)
set_rate_limit_group("group1")
assert "was called without a context" in caplog.text


def test_set_rate_limit_group_middleware_already_executed(caplog):
context1 = set_context_and_lifecycle()
context1.executed_middleware = True
set_rate_limit_group("group1")
assert "must be called before the Zen middleware is executed" in caplog.text
assert context1.rate_limit_group is "group1"


def test_set_rate_limit_group_non_string_group_id(caplog):
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group(123)
assert context1.rate_limit_group == "123"


def test_set_rate_limit_group_non_string_group_id_non_number(caplog):
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group({"a": "b"})
assert context1.rate_limit_group is None
assert "Group ID must be a string or a number" in caplog.text


def test_set_rate_limit_group_overwrite_existing_group():
context1 = set_context_and_lifecycle()
assert context1.rate_limit_group is None
set_rate_limit_group("group1")
assert context1.rate_limit_group == "group1"
set_rate_limit_group("group2")
assert context1.rate_limit_group == "group2"
21 changes: 14 additions & 7 deletions aikido_zen/middleware/should_block_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,28 @@
def should_block_request():
"""
Checks for rate-limiting and checks if the current user is blocked.
Returns: {
block: boolean,
data: {
type: "blocked" | "ratelimited",
trigger: "ip" | "user" | "group",
ip?: string
}
}
"""
try:
context = get_current_context()
cache = get_cache()
if not context or not cache:
return {"block": False}

context.executed_middleware = (
True # Update context with middleware execution set to true
)
context.set_as_current_context()

# Make sure we set middleware installed to true (reports back to core) :
# These indicators allow us to check in core whether the middleware is installed correctly,
# and to display a warning when a user or group is set after this has run.
cache.middleware_installed = True
context.executed_middleware = True
context.set_as_current_context()

# Blocked users:
# User blocking allows customers to easily take action when attacks are coming from specific accounts
if context.user and cache.is_user_blocked(context.user["id"]):
return {"block": True, "type": "blocked", "trigger": "user"}

Expand All @@ -46,6 +52,7 @@ def should_block_request():
obj={
"route_metadata": route_metadata,
"user": context.user,
"group": context.rate_limit_group,
"remote_address": context.remote_address,
},
receive=True,
Expand Down
Loading
Loading