Skip to content
38 changes: 9 additions & 29 deletions aikido_zen/middleware/init_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from unittest.mock import patch, MagicMock

import pytest
from aikido_zen.context import current_context, Context, get_current_context
from aikido_zen.thread.thread_cache import ThreadCache, get_cache
import aikido_zen.test_utils as test_utils
from aikido_zen.context import current_context, get_current_context
from aikido_zen.thread.thread_cache import get_cache
from . import should_block_request
from .. import set_rate_limit_group

Expand All @@ -22,35 +23,14 @@ def test_without_context():
assert should_block_request() == {"block": False}


def set_context(user=None, executed_middleware=False):
Context(
context_obj={
"remote_address": "::1",
"method": "POST",
"url": "http://localhost:4000",
"query": {
"abc": "def",
},
"headers": {},
"body": None,
"cookies": {},
"source": "flask",
"route": "/posts/:id",
"user": user,
"rate_limit_group": None,
"executed_middleware": executed_middleware,
}
).set_as_current_context()


def test_with_context_without_cache():
set_context()
test_utils.generate_and_set_context()
get_cache().cache = None
assert should_block_request() == {"block": False}


def test_with_context_with_cache():
set_context(user={"id": "123"})
test_utils.generate_and_set_context(user={"id": "123"})
thread_cache = get_cache()

thread_cache.config.blocked_uids = ["123"]
Expand All @@ -76,7 +56,7 @@ def test_with_context_with_cache():


def test_cache_comms_with_endpoints():
set_context(user={"id": "456"})
test_utils.generate_and_set_context(user={"id": "456"}, route="/posts/:id")
set_rate_limit_group("my_group")
thread_cache = get_cache()
thread_cache.config.blocked_uids = ["123"]
Expand Down Expand Up @@ -145,11 +125,11 @@ def test_cache_comms_with_endpoints():
"route_metadata": {
"method": "POST",
"route": "/posts/:id",
"url": "http://localhost:4000",
"url": "http://localhost:8080/",
},
"user": {"id": "456"},
"group": "my_group",
"remote_address": "::1",
"remote_address": "1.1.1.1",
},
receive=True,
timeout_in_sec=0.01,
Expand All @@ -168,7 +148,7 @@ def test_cache_comms_with_endpoints():
assert thread_cache.stats.rate_limited_hits == 0
assert should_block_request() == {
"block": True,
"ip": "::1",
"ip": "1.1.1.1",
"type": "ratelimited",
"trigger": "my_trigger",
}
Expand Down
39 changes: 8 additions & 31 deletions aikido_zen/middleware/set_rate_limit_group_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from aikido_zen.context import get_current_context, Context
from aikido_zen.thread.thread_cache import get_cache
import aikido_zen.test_utils as test_utils
from .set_rate_limit_group import set_rate_limit_group


Expand All @@ -15,31 +15,8 @@ def run_around_tests():
get_cache().reset()


def set_context_and_lifecycle():
wsgi_request = {
"REQUEST_METHOD": "GET",
"HTTP_HEADER_1": "header 1 value",
"HTTP_HEADER_2": "Header 2 value",
"RANDOM_VALUE": "Random value",
"HTTP_COOKIE": "sessionId=abc123xyz456;",
"wsgi.url_scheme": "http",
"HTTP_HOST": "localhost:8080",
"PATH_INFO": "/hello",
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
context = Context(
req=wsgi_request,
body=None,
source="flask",
)
context.set_as_current_context()
return context


def test_set_rate_limit_group_valid_group_id(caplog):
context1 = set_context_and_lifecycle()
context1 = test_utils.generate_and_set_context()
assert context1.rate_limit_group is None
set_rate_limit_group("group1")
assert context1.rate_limit_group == "group1"
Expand All @@ -49,15 +26,15 @@ def test_set_rate_limit_group_valid_group_id(caplog):


def test_set_rate_limit_group_empty_group_id(caplog):
context1 = set_context_and_lifecycle()
context1 = test_utils.generate_and_set_context()
assert context1.rate_limit_group is None
set_rate_limit_group("")
assert context1.rate_limit_group is None
assert "Group ID cannot be empty." in caplog.text


def test_set_rate_limit_group_none_group_id(caplog):
context1 = set_context_and_lifecycle()
context1 = test_utils.generate_and_set_context()
assert context1.rate_limit_group is None
set_rate_limit_group(None)
assert context1.rate_limit_group is None
Expand All @@ -73,30 +50,30 @@ def test_set_rate_limit_group_no_context(caplog):


def test_set_rate_limit_group_middleware_already_executed(caplog):
context1 = set_context_and_lifecycle()
context1 = test_utils.generate_and_set_context()
context1.executed_middleware = True
set_rate_limit_group("group1")
assert "must be called before the Zen middleware is executed" in caplog.text
assert context1.rate_limit_group is "group1"


def test_set_rate_limit_group_non_string_group_id(caplog):
context1 = set_context_and_lifecycle()
context1 = test_utils.generate_and_set_context()
assert context1.rate_limit_group is None
set_rate_limit_group(123)
assert context1.rate_limit_group == "123"


def test_set_rate_limit_group_non_string_group_id_non_number(caplog):
context1 = set_context_and_lifecycle()
context1 = test_utils.generate_and_set_context()
assert context1.rate_limit_group is None
set_rate_limit_group({"a": "b"})
assert context1.rate_limit_group is None
assert "Group ID must be a string or a number" in caplog.text


def test_set_rate_limit_group_overwrite_existing_group():
context1 = set_context_and_lifecycle()
context1 = test_utils.generate_and_set_context()
assert context1.rate_limit_group is None
set_rate_limit_group("group1")
assert context1.rate_limit_group == "group1"
Expand Down
45 changes: 11 additions & 34 deletions aikido_zen/sinks/tests/clickhouse_driver_test.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,13 @@
import aikido_zen.sinks.clickhouse_driver
import pytest
from aikido_zen.background_process import reset_comms
from aikido_zen.context import Context
from aikido_zen.errors import AikidoSQLInjection


class Context1(Context):
def __init__(self, body):
self.cookies = {}
self.headers = {}
self.remote_address = "1.1.1.1"
self.method = "POST"
self.url = "url"
self.query = {}
self.body = body
self.source = "express"
self.route = "/"
self.parsed_userinput = {}
self.protection_forced_off = False
import aikido_zen.test_utils as test_utils


@pytest.fixture(autouse=True)
def set_blocking_to_true(monkeypatch):
def setup(monkeypatch):
reset_comms()
monkeypatch.setenv("AIKIDO_BLOCK", "1")


Expand All @@ -35,25 +21,22 @@ def client():


def test_client_execute_without_context(client):
reset_comms()
dog_name = "Steve"
sql = "INSERT INTO dogs (dog_name, isAdmin) VALUES ('{}' , 0)".format(dog_name)
client.execute(sql)


def test_client_execute_safe(client):
reset_comms()
dog_name = "Steve"
sql = "INSERT INTO dogs (dog_name, isAdmin) VALUES ('{}' , 0)".format(dog_name)
Context1({"dog_name": dog_name}).set_as_current_context()
test_utils.generate_and_set_context(value=dog_name)
client.execute(sql)


def test_client_execute_unsafe(client, monkeypatch):
reset_comms()
dog_name = "Malicious dog', 1); -- "
sql = "INSERT INTO dogs (dog_name, isAdmin) VALUES ('{}' , 0)".format(dog_name)
Context1({"dog_name": dog_name}).set_as_current_context()
test_utils.generate_and_set_context(value=dog_name)

with pytest.raises(AikidoSQLInjection):
client.execute(sql)
Expand All @@ -66,21 +49,19 @@ def test_cursor_execute_safe():
from clickhouse_driver import connect

conn = connect("clickhouse://localhost:9000")
reset_comms()
dog_name = "Steve"
sql = "INSERT INTO dogs (dog_name, isAdmin) VALUES ('{}' , 0)".format(dog_name)
Context1({"dog_name": dog_name}).set_as_current_context()
test_utils.generate_and_set_context(value=dog_name)
conn.cursor().execute(sql)


def test_cursor_execute_unsafe(monkeypatch):
from clickhouse_driver import connect

conn = connect("clickhouse://localhost:9000")
reset_comms()
dog_name = "Malicious dog', 1); -- "
sql = "INSERT INTO dogs (dog_name, isAdmin) VALUES ('{}' , 0)".format(dog_name)
Context1({"dog_name": dog_name}).set_as_current_context()
test_utils.generate_and_set_context(value=dog_name)

with pytest.raises(AikidoSQLInjection):
conn.cursor().execute(sql)
Expand All @@ -90,18 +71,16 @@ def test_cursor_execute_unsafe(monkeypatch):


def test_client_execute_with_progress_safe(client):
reset_comms()
dog_name = "Steve"
sql = "INSERT INTO dogs (dog_name, isAdmin) VALUES ('{}' , 0)".format(dog_name)
Context1({"dog_name": dog_name}).set_as_current_context()
test_utils.generate_and_set_context(value=dog_name)
client.execute_with_progress(sql)


def test_client_execute_with_progress_unsafe(client, monkeypatch):
reset_comms()
dog_name = "Malicious dog', 1); -- "
sql = "INSERT INTO dogs (dog_name, isAdmin) VALUES ('{}' , 0)".format(dog_name)
Context1({"dog_name": dog_name}).set_as_current_context()
test_utils.generate_and_set_context(value=dog_name)

with pytest.raises(AikidoSQLInjection):
client.execute_with_progress(sql)
Expand All @@ -111,18 +90,16 @@ def test_client_execute_with_progress_unsafe(client, monkeypatch):


def test_client_execute_iter_safe(client):
reset_comms()
dog_name = "Steve"
sql = "INSERT INTO dogs (dog_name, isAdmin) VALUES ('{}' , 0)".format(dog_name)
Context1({"dog_name": dog_name}).set_as_current_context()
test_utils.generate_and_set_context(value=dog_name)
client.execute_iter(sql)


def test_client_execute_iter_unsafe(client, monkeypatch):
reset_comms()
dog_name = "Malicious dog', 1); -- "
sql = "INSERT INTO dogs (dog_name, isAdmin) VALUES ('{}' , 0)".format(dog_name)
Context1({"dog_name": dog_name}).set_as_current_context()
test_utils.generate_and_set_context(value=dog_name)

with pytest.raises(AikidoSQLInjection):
client.execute_iter(sql)
Expand Down
41 changes: 41 additions & 0 deletions aikido_zen/test_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from aikido_zen.context import Context
from aikido_zen.helpers.headers import Headers


def generate_and_set_context(*args, **kwargs) -> Context:
context = generate_context(*args, **kwargs)
context.set_as_current_context()
return context


def generate_context(value=None, query_value=None, user=None, route=None) -> Context:
context = MockTestContext()

if value is not None:
context.body["key1"] = value
if query_value is not None:
context.query["key1"] = query_value
if user is not None:
context.user = user
if route is not None:
context.route = route

return context


class MockTestContext(Context):
def __init__(self):
self.cookies = {}
self.headers = Headers()
self.remote_address = "1.1.1.1"
self.method = "POST"
self.url = "http://localhost:8080/"
self.body = {}
self.query = {}
self.source = "flask"
self.route = "/"
self.parsed_userinput = {}
self.user = None
self.rate_limit_group = None
self.executed_middleware = False
self.protection_forced_off = False
Loading