Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aikido_zen/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None):
self.headers: Headers = Headers()
self.cookies = dict()
self.query = dict()
self.protection_forced_off = None

# Parse WSGI/ASGI/... request :
self.method = self.remote_address = self.url = None
Expand Down Expand Up @@ -137,3 +138,6 @@ def get_route_metadata(self):

def get_user_agent(self):
return self.headers.get_header("USER_AGENT")

def set_force_protection_off(self, value: bool):
self.protection_forced_off = value
12 changes: 12 additions & 0 deletions aikido_zen/context/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def test_wsgi_context_1():
"outgoing_req_redirects": [],
"executed_middleware": False,
"route_params": [],
"protection_forced_off": None,
}
assert context.get_user_agent() is None

Expand Down Expand Up @@ -103,6 +104,7 @@ def test_wsgi_context_2():
"outgoing_req_redirects": [],
"executed_middleware": False,
"route_params": [],
"protection_forced_off": None,
}
assert context.get_user_agent() == "Mozilla/5.0"

Expand Down Expand Up @@ -284,3 +286,13 @@ def test_set_valid_json_with_special_characters_bytes():
context = Context(req=basic_wsgi_req, body=None, source="flask")
context.set_body(b'{"key": "value with special characters !@#$%^&*()"}')
assert context.body == {"key": "value with special characters !@#$%^&*()"}


def test_set_protection_forced_off():
context = Context(req=basic_wsgi_req, body=None, source="flask")
context.set_force_protection_off(True)
assert context.protection_forced_off is True
context.set_force_protection_off(False)
assert context.protection_forced_off is False
context.set_force_protection_off(None)
assert context.protection_forced_off is None
30 changes: 30 additions & 0 deletions aikido_zen/helpers/is_protection_forced_off_cached.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from aikido_zen.thread.thread_cache import get_cache
from aikido_zen.helpers.protection_forced_off import protection_forced_off
from aikido_zen.context import Context


def is_protection_forced_off_cached(context: Context) -> bool:
"""
Check if protection is forced off using cached endpoints.
This function assumes that the thread cache has already been retrieved
and uses it to determine if protection is forced off for the given context.
"""
if not context:
return False

if context.protection_forced_off is not None:
# Retrieving from cache, we don't want to constantly go through
# all the endpoints for every single vulnerability check.
return context.protection_forced_off

thread_cache = get_cache()
if not thread_cache:
return False

is_forced_off = protection_forced_off(
context.get_route_metadata(), thread_cache.get_endpoints()
)
context.set_force_protection_off(is_forced_off)
context.set_as_current_context()

return is_forced_off
1 change: 1 addition & 0 deletions aikido_zen/sinks/tests/clickhouse_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, body):
self.source = "express"
self.route = "/"
self.parsed_userinput = {}
self.protection_forced_off = False


@pytest.fixture(autouse=True)
Expand Down
13 changes: 7 additions & 6 deletions aikido_zen/vulnerabilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from aikido_zen.helpers.logging import logger
from aikido_zen.helpers.get_clean_stacktrace import get_clean_stacktrace
from aikido_zen.helpers.blocking_enabled import is_blocking_enabled
from aikido_zen.helpers.protection_forced_off import protection_forced_off
from aikido_zen.helpers.is_protection_forced_off_cached import (
is_protection_forced_off_cached,
)
from aikido_zen.thread.thread_cache import get_cache
from .sql_injection.context_contains_sql_injection import context_contains_sql_injection
from .nosql_injection.check_context import check_context_for_nosql_injection
Expand All @@ -35,6 +37,10 @@ def run_vulnerability_scan(kind, op, args):
raises error if blocking is enabled, communicates it with connection_manager
"""
context = get_current_context()

if is_protection_forced_off_cached(context):
return

comms = comm.get_comms()
thread_cache = get_cache()
if not context and kind != "ssrf":
Expand All @@ -47,11 +53,6 @@ def run_vulnerability_scan(kind, op, args):
# This is because some scans/tests for SSRF do not require a thread cache to be set.
return
if thread_cache and context:
if protection_forced_off(
context.get_route_metadata(), thread_cache.get_endpoints()
):
# The client turned protection off for this route, not scanning
return
if thread_cache.is_bypassed_ip(context.remote_address):
# This IP is on the bypass list, not scanning
return
Expand Down
33 changes: 33 additions & 0 deletions aikido_zen/vulnerabilities/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,36 @@ def test_ssrf_vulnerability_scan_bypassed_ip(get_context):

# Verify that hostnames.add was not called due to bypassed IP
assert get_cache().hostnames.as_array() == []


def test_ssrf_vulnerability_scan_protection_gets_forced_off(get_context):
get_context.set_as_current_context()
get_cache().config.bypassed_ips = IPMatcher(["198.51.100.23"])

dns_results = MagicMock()
hostname = "example.com"
port = 80
assert get_context.protection_forced_off is None
run_vulnerability_scan(kind="ssrf", op="test", args=(dns_results, hostname, port))
assert get_context.protection_forced_off is False


def test_sql_injection_with_protection_forced_off(caplog, get_context, monkeypatch):
get_context.set_as_current_context()
monkeypatch.setenv("AIKIDO_BLOCK", "1")
with patch("aikido_zen.background_process.comms.get_comms") as mock_get_comms:
# Create a mock comms object
mock_comms = MagicMock()
mock_get_comms.return_value = mock_comms # Set the return value of get_comms
with pytest.raises(AikidoSQLInjection):
run_vulnerability_scan(
kind="sql_injection",
op="test_op",
args=("INSERT * INTO VALUES ('doggoss2', TRUE);", "mysql"),
)
get_context.set_force_protection_off(True)
run_vulnerability_scan(
kind="sql_injection",
op="test_op",
args=("INSERT * INTO VALUES ('doggoss2', TRUE);", "mysql"),
)
8 changes: 8 additions & 0 deletions benchmarks/wrk_benchmark/flask_mysql_uwsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@
"a non empty route which makes a simulated request to a database",
percentage_limit=40
)


run_benchmark(
"http://localhost:8088/benchmark_io",
"http://localhost:8089/benchmark_io",
"a route that makes multiple I/O calls",
percentage_limit=35
)
1 change: 1 addition & 0 deletions sample-apps/flask-mysql-uwsgi/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
benchmark_temp.txt
4 changes: 2 additions & 2 deletions sample-apps/flask-mysql-uwsgi/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ runBenchmark: install
AIKIDO_DEBUG=false AIKIDO_BLOCK=true AIKIDO_TOKEN="AIK_secret_token" \
AIKIDO_REALTIME_ENDPOINT="http://localhost:5000/" \
AIKIDO_ENDPOINT="http://localhost:5000/" AIKIDO_DISABLE=0 \
poetry run uwsgi --ini uwsgi.ini
poetry run uwsgi --single-interpreter --ini uwsgi.ini

.PHONY: runZenDisabled
runZenDisabled: install
@echo "Running sample app flask-mysql-uwsgi without Zen on port 8089"
AIKIDO_DISABLE=1 \
poetry run uwsgi --ini uwsgi2.ini
poetry run uwsgi --single-interpreter --ini uwsgi2.ini
9 changes: 9 additions & 0 deletions sample-apps/flask-mysql-uwsgi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,14 @@ def benchmark():
return "OK"


@app.route("/benchmark_io", methods=['GET'])
def benchmark_io():
for i in range(50):
with open("benchmark_temp.txt", "w") as f:
f.write("This is a benchmark file.")
with open("benchmark_temp.txt", "r") as f:
content = f.read()
return "OK"

if __name__ == '__main__':
app.run()
Loading