Skip to content

Commit 848235a

Browse files
committed
Merge branch 'main' into qa-tests
2 parents 8f04c1c + acf1166 commit 848235a

17 files changed

+463
-586
lines changed

aikido_zen/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Re-export functions :
1010
from aikido_zen.context.users import set_user
1111
from aikido_zen.helpers.check_gevent import check_gevent
12+
from aikido_zen.helpers.python_version_not_supported import python_version_not_supported
1213
from aikido_zen.middleware import should_block_request
1314
from aikido_zen.middleware.set_rate_limit_group import set_rate_limit_group
1415

@@ -34,6 +35,8 @@ def protect(mode="daemon", token=""):
3435
if aikido_disabled_flag_active():
3536
# Do not run any aikido code when the disabled flag is on
3637
return
38+
if python_version_not_supported():
39+
return
3740
if not test_uds_file_access():
3841
return # Unable to start background process
3942
if check_gevent():
@@ -71,6 +74,7 @@ def protect(mode="daemon", token=""):
7174

7275
import aikido_zen.sinks.builtins
7376
import aikido_zen.sinks.os
77+
import aikido_zen.sinks.pathlib
7478
import aikido_zen.sinks.shutil
7579
import aikido_zen.sinks.io
7680
import aikido_zen.sinks.http_client
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import sys
2+
from aikido_zen.helpers.logging import logger
3+
4+
5+
def python_version_not_supported() -> bool:
6+
major = sys.version_info.major
7+
minor = sys.version_info.minor
8+
if major != 3:
9+
logger.error("This version of Zen only supports Python 3")
10+
return True
11+
if minor > 13:
12+
logger.error("This version of Zen doesn't support versions above Python 3.13")
13+
return True
14+
return False

aikido_zen/middleware/init_test.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from unittest.mock import patch, MagicMock
22

33
import pytest
4-
from aikido_zen.context import current_context, Context, get_current_context
5-
from aikido_zen.thread.thread_cache import ThreadCache, get_cache
4+
import aikido_zen.test_utils as test_utils
5+
from aikido_zen.context import current_context, get_current_context
6+
from aikido_zen.thread.thread_cache import get_cache
67
from . import should_block_request
78
from .. import set_rate_limit_group
89

@@ -22,35 +23,14 @@ def test_without_context():
2223
assert should_block_request() == {"block": False}
2324

2425

25-
def set_context(user=None, executed_middleware=False):
26-
Context(
27-
context_obj={
28-
"remote_address": "::1",
29-
"method": "POST",
30-
"url": "http://localhost:4000",
31-
"query": {
32-
"abc": "def",
33-
},
34-
"headers": {},
35-
"body": None,
36-
"cookies": {},
37-
"source": "flask",
38-
"route": "/posts/:id",
39-
"user": user,
40-
"rate_limit_group": None,
41-
"executed_middleware": executed_middleware,
42-
}
43-
).set_as_current_context()
44-
45-
4626
def test_with_context_without_cache():
47-
set_context()
27+
test_utils.generate_and_set_context()
4828
get_cache().cache = None
4929
assert should_block_request() == {"block": False}
5030

5131

5232
def test_with_context_with_cache():
53-
set_context(user={"id": "123"})
33+
test_utils.generate_and_set_context(user={"id": "123"})
5434
thread_cache = get_cache()
5535

5636
thread_cache.config.blocked_uids = ["123"]
@@ -76,7 +56,7 @@ def test_with_context_with_cache():
7656

7757

7858
def test_cache_comms_with_endpoints():
79-
set_context(user={"id": "456"})
59+
test_utils.generate_and_set_context(user={"id": "456"}, route="/posts/:id")
8060
set_rate_limit_group("my_group")
8161
thread_cache = get_cache()
8262
thread_cache.config.blocked_uids = ["123"]
@@ -145,11 +125,11 @@ def test_cache_comms_with_endpoints():
145125
"route_metadata": {
146126
"method": "POST",
147127
"route": "/posts/:id",
148-
"url": "http://localhost:4000",
128+
"url": "http://localhost:8080/",
149129
},
150130
"user": {"id": "456"},
151131
"group": "my_group",
152-
"remote_address": "::1",
132+
"remote_address": "1.1.1.1",
153133
},
154134
receive=True,
155135
timeout_in_sec=0.01,
@@ -168,7 +148,7 @@ def test_cache_comms_with_endpoints():
168148
assert thread_cache.stats.rate_limited_hits == 0
169149
assert should_block_request() == {
170150
"block": True,
171-
"ip": "::1",
151+
"ip": "1.1.1.1",
172152
"type": "ratelimited",
173153
"trigger": "my_trigger",
174154
}

aikido_zen/middleware/set_rate_limit_group_test.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
2-
from aikido_zen.context import get_current_context, Context
32
from aikido_zen.thread.thread_cache import get_cache
3+
import aikido_zen.test_utils as test_utils
44
from .set_rate_limit_group import set_rate_limit_group
55

66

@@ -15,31 +15,8 @@ def run_around_tests():
1515
get_cache().reset()
1616

1717

18-
def set_context_and_lifecycle():
19-
wsgi_request = {
20-
"REQUEST_METHOD": "GET",
21-
"HTTP_HEADER_1": "header 1 value",
22-
"HTTP_HEADER_2": "Header 2 value",
23-
"RANDOM_VALUE": "Random value",
24-
"HTTP_COOKIE": "sessionId=abc123xyz456;",
25-
"wsgi.url_scheme": "http",
26-
"HTTP_HOST": "localhost:8080",
27-
"PATH_INFO": "/hello",
28-
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
29-
"CONTENT_TYPE": "application/json",
30-
"REMOTE_ADDR": "198.51.100.23",
31-
}
32-
context = Context(
33-
req=wsgi_request,
34-
body=None,
35-
source="flask",
36-
)
37-
context.set_as_current_context()
38-
return context
39-
40-
4118
def test_set_rate_limit_group_valid_group_id(caplog):
42-
context1 = set_context_and_lifecycle()
19+
context1 = test_utils.generate_and_set_context()
4320
assert context1.rate_limit_group is None
4421
set_rate_limit_group("group1")
4522
assert context1.rate_limit_group == "group1"
@@ -49,15 +26,15 @@ def test_set_rate_limit_group_valid_group_id(caplog):
4926

5027

5128
def test_set_rate_limit_group_empty_group_id(caplog):
52-
context1 = set_context_and_lifecycle()
29+
context1 = test_utils.generate_and_set_context()
5330
assert context1.rate_limit_group is None
5431
set_rate_limit_group("")
5532
assert context1.rate_limit_group is None
5633
assert "Group ID cannot be empty." in caplog.text
5734

5835

5936
def test_set_rate_limit_group_none_group_id(caplog):
60-
context1 = set_context_and_lifecycle()
37+
context1 = test_utils.generate_and_set_context()
6138
assert context1.rate_limit_group is None
6239
set_rate_limit_group(None)
6340
assert context1.rate_limit_group is None
@@ -73,30 +50,30 @@ def test_set_rate_limit_group_no_context(caplog):
7350

7451

7552
def test_set_rate_limit_group_middleware_already_executed(caplog):
76-
context1 = set_context_and_lifecycle()
53+
context1 = test_utils.generate_and_set_context()
7754
context1.executed_middleware = True
7855
set_rate_limit_group("group1")
7956
assert "must be called before the Zen middleware is executed" in caplog.text
8057
assert context1.rate_limit_group is "group1"
8158

8259

8360
def test_set_rate_limit_group_non_string_group_id(caplog):
84-
context1 = set_context_and_lifecycle()
61+
context1 = test_utils.generate_and_set_context()
8562
assert context1.rate_limit_group is None
8663
set_rate_limit_group(123)
8764
assert context1.rate_limit_group == "123"
8865

8966

9067
def test_set_rate_limit_group_non_string_group_id_non_number(caplog):
91-
context1 = set_context_and_lifecycle()
68+
context1 = test_utils.generate_and_set_context()
9269
assert context1.rate_limit_group is None
9370
set_rate_limit_group({"a": "b"})
9471
assert context1.rate_limit_group is None
9572
assert "Group ID must be a string or a number" in caplog.text
9673

9774

9875
def test_set_rate_limit_group_overwrite_existing_group():
99-
context1 = set_context_and_lifecycle()
76+
context1 = test_utils.generate_and_set_context()
10077
assert context1.rate_limit_group is None
10178
set_rate_limit_group("group1")
10279
assert context1.rate_limit_group == "group1"

aikido_zen/sinks/builtins_import.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from aikido_zen.sinks import on_import, patch_function, after
2+
import contextvars
13
import importlib.metadata
24
from importlib.metadata import PackageNotFoundError
3-
45
from aikido_zen.background_process.packages import PackagesStore
5-
from aikido_zen.sinks import on_import, patch_function, after
6+
7+
running_import_scan = contextvars.ContextVar("running_import_scan", default=False)
68

79

810
@after
@@ -12,26 +14,34 @@ def _import(func, instance, args, kwargs, return_value):
1214

1315
if not hasattr(return_value, "__package__"):
1416
return
15-
name = getattr(return_value, "__package__")
16-
17-
if not name:
18-
# Make sure the name exists
19-
return
20-
name = name.split(".")[0] # Remove submodules
21-
if name == "importlib" or name == "importlib_metadata":
22-
# Avoid circular dependencies
23-
return
24-
25-
if PackagesStore.get_package(name):
26-
return
2717

28-
version = None
2918
try:
30-
version = importlib.metadata.version(name)
31-
except PackageNotFoundError:
32-
pass
33-
if version:
34-
PackagesStore.add_package(name, version)
19+
if running_import_scan.get():
20+
return
21+
running_import_scan.set(True)
22+
23+
name = getattr(return_value, "__package__")
24+
25+
if not name:
26+
# Make sure the name exists
27+
return
28+
name = name.split(".")[0] # Remove submodules
29+
if name == "importlib" or name == "importlib_metadata":
30+
# Avoid circular dependencies, this is a double safety-check for if contextvar check fails.
31+
return
32+
33+
if PackagesStore.get_package(name):
34+
return
35+
36+
version = None
37+
try:
38+
version = importlib.metadata.version(name)
39+
except PackageNotFoundError:
40+
pass
41+
if version:
42+
PackagesStore.add_package(name, version)
43+
finally:
44+
running_import_scan.set(False)
3545

3646

3747
@on_import("builtins")

aikido_zen/sinks/pathlib.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
Sink module for python's `pathlib`
3+
"""
4+
5+
import aikido_zen.vulnerabilities as vulns
6+
from aikido_zen.helpers.get_argument import get_argument
7+
from aikido_zen.helpers.register_call import register_call
8+
from aikido_zen.sinks import before, patch_function, on_import
9+
10+
11+
@before
12+
def _pathlib_truediv_patch(func, instance, args, kwargs):
13+
path = get_argument(args, kwargs, 0, "key")
14+
op = "pathlib.PurePath.__truediv__"
15+
register_call(op, "fs_op")
16+
17+
vulns.run_vulnerability_scan(kind="path_traversal", op=op, args=(path,))
18+
19+
20+
@on_import("pathlib")
21+
def patch(m):
22+
"""
23+
patching module pathlib
24+
- patches PurePath.__truediv__ : Path() / Path() -> join operation
25+
"""
26+
27+
# PurePath() / "my/path/test.txt"
28+
# This is accomplished by overloading the __truediv__ function on the Path class
29+
patch_function(m, "PurePath.__truediv__", _pathlib_truediv_patch)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
3+
import aikido_zen
4+
5+
aikido_zen.protect()
6+
7+
from aikido_zen.background_process.packages import PackagesStore
8+
9+
10+
@pytest.fixture(autouse=True)
11+
def run_around_tests():
12+
PackagesStore.clear()
13+
14+
15+
def test_flask_import():
16+
import flask
17+
18+
assert PackagesStore.get_package("flask")["version"] == "3.0.3"
19+
20+
21+
def test_django_import():
22+
import django
23+
24+
assert PackagesStore.get_package("django")["version"] == "4.0"
25+
26+
27+
def test_recursive_package_store(monkeypatch):
28+
"""Test that recursive imports during package scanning don't cause max recursion depth errors."""
29+
30+
def recursive_get_package(name):
31+
"""Recursively add package and its dependencies to PackagesStore."""
32+
import flask
33+
34+
PackagesStore.clear()
35+
monkeypatch.setattr(PackagesStore, "get_package", recursive_get_package)
36+
37+
import flask
38+
39+
# Restore the original method after the test
40+
monkeypatch.undo()
41+
42+
43+
def test_recursive_package_store_2(monkeypatch):
44+
"""Test that recursive imports during package scanning don't cause max recursion depth errors."""
45+
46+
def recursive_add_package(name, version):
47+
"""Recursively add package and its dependencies to PackagesStore."""
48+
if name == "django":
49+
import django
50+
51+
PackagesStore.clear()
52+
monkeypatch.setattr(PackagesStore, "add_package", recursive_add_package)
53+
import django
54+
55+
# Restore the original method after the test
56+
monkeypatch.undo()

0 commit comments

Comments
 (0)