diff --git a/.github/workflows/qa-tests.yml b/.github/workflows/qa-tests.yml index 277c7de03..21318a4f2 100644 --- a/.github/workflows/qa-tests.yml +++ b/.github/workflows/qa-tests.yml @@ -50,4 +50,4 @@ jobs: dockerfile_path: ./zen-demo-python/Dockerfile app_port: 8080 sleep_before_test: 10 - skip_tests: test_bypassed_ip_for_geo_blocking,test_demo_apps_generic_tests,test_path_traversal,test_wave_attack + skip_tests: test_bypassed_ip_for_geo_blocking,test_demo_apps_generic_tests,test_path_traversal diff --git a/aikido_zen/background_process/commands/sync_data_test.py b/aikido_zen/background_process/commands/sync_data_test.py index 32dad72dd..e5e787640 100644 --- a/aikido_zen/background_process/commands/sync_data_test.py +++ b/aikido_zen/background_process/commands/sync_data_test.py @@ -93,6 +93,7 @@ def test_process_sync_data_initialization(setup_connection_manager): assert connection_manager.statistics.get_record()["requests"] == { "aborted": 0, "attacksDetected": {"blocked": 0, "total": 5}, + "attackWaves": {"total": 0, "blocked": 0}, "total": 10, "rateLimited": 0, } @@ -168,6 +169,7 @@ def test_process_sync_data_with_last_updated_at_below_zero(setup_connection_mana assert connection_manager.statistics.get_record()["requests"] == { "aborted": 0, "attacksDetected": {"blocked": 0, "total": 5}, + "attackWaves": {"total": 0, "blocked": 0}, "total": 10, "rateLimited": 0, } @@ -255,6 +257,7 @@ def test_process_sync_data_existing_route_and_hostnames(setup_connection_manager assert connection_manager.statistics.get_record()["requests"] == { "aborted": 0, "attacksDetected": {"blocked": 0, "total": 10}, + "attackWaves": {"total": 0, "blocked": 0}, "total": 20, "rateLimited": 0, } diff --git a/aikido_zen/helpers/create_attack_wave_event.py b/aikido_zen/helpers/create_attack_wave_event.py new file mode 100644 index 000000000..ad506588b --- /dev/null +++ b/aikido_zen/helpers/create_attack_wave_event.py @@ -0,0 +1,27 @@ +from aikido_zen.helpers.limit_length_metadata import 
limit_length_metadata +from aikido_zen.helpers.logging import logger + + +def create_attack_wave_event(context, metadata): + try: + return { + "type": "detected_attack_wave", + "attack": { + "user": getattr(context, "user", None), + "metadata": limit_length_metadata(metadata, 4096), + }, + "request": extract_request_if_possible(context), + } + except Exception as e: + logger.error("Failed to create detected_attack_wave API event: %s", str(e)) + return None + + +def extract_request_if_possible(context): + if not context: + return None + return { + "ipAddress": getattr(context, "remote_address", None), + "source": getattr(context, "source", None), + "userAgent": context.get_user_agent(), + } diff --git a/aikido_zen/helpers/create_attack_wave_event_test.py b/aikido_zen/helpers/create_attack_wave_event_test.py new file mode 100644 index 000000000..64e9a02cd --- /dev/null +++ b/aikido_zen/helpers/create_attack_wave_event_test.py @@ -0,0 +1,173 @@ +import pytest +from unittest.mock import MagicMock +from .create_attack_wave_event import ( + create_attack_wave_event, + extract_request_if_possible, +) +import aikido_zen.test_utils as test_utils + + +def test_create_attack_wave_event_success(): + """Test successful creation of attack wave event with basic data""" + metadata = {"test": "value"} + context = test_utils.generate_context() + + event = create_attack_wave_event(context, metadata) + + assert event is not None + assert event["type"] == "detected_attack_wave" + assert event["attack"]["user"] is None + assert event["attack"]["metadata"] == metadata + assert event["request"] is not None + + +def test_create_attack_wave_event_with_user(): + """Test attack wave event creation with user information""" + metadata = {"test": "value"} + context = test_utils.generate_context(user="test_user") + + event = create_attack_wave_event(context, metadata) + + assert event["attack"]["user"] == "test_user" + assert event["attack"]["metadata"] == metadata + + +def 
test_create_attack_wave_event_with_long_metadata(): + """Test that metadata longer than 4096 characters is truncated""" + long_metadata = "x" * 5000 # Create metadata longer than 4096 characters + metadata = {"test": long_metadata} + context = test_utils.generate_context() + + event = create_attack_wave_event(context, metadata) + + assert len(event["attack"]["metadata"]["test"]) == 4096 + assert event["attack"]["metadata"]["test"] == long_metadata[:4096] + + +def test_create_attack_wave_event_with_multiple_long_metadata_fields(): + """Test that multiple metadata fields longer than 4096 characters are truncated""" + long_value1 = "a" * 5000 + long_value2 = "b" * 6000 + metadata = { + "field1": long_value1, + "field2": long_value2, + } + context = test_utils.generate_context() + + event = create_attack_wave_event(context, metadata) + + assert len(event["attack"]["metadata"]["field1"]) == 4096 + assert len(event["attack"]["metadata"]["field2"]) == 4096 + assert event["attack"]["metadata"]["field1"] == long_value1[:4096] + assert event["attack"]["metadata"]["field2"] == long_value2[:4096] + + +def test_create_attack_wave_event_request_data(): + """Test that request data is correctly extracted from context""" + metadata = {"test": "value"} + context = test_utils.generate_context( + ip="198.51.100.23", + route="/test-route", + headers={"user-agent": "Mozilla/5.0"}, + ) + + event = create_attack_wave_event(context, metadata) + + request_data = event["request"] + assert request_data["ipAddress"] == "198.51.100.23" + assert request_data["source"] == "flask" + assert request_data["userAgent"] == "Mozilla/5.0" + + +def test_create_attack_wave_event_no_context(): + """Test attack wave event creation with None context""" + metadata = {"test": "value"} + + event = create_attack_wave_event(None, metadata) + + assert event["attack"]["user"] is None + assert event["attack"]["metadata"] == metadata + assert event["request"] is None + + +def 
test_create_attack_wave_event_exception_handling(): + """Test that exceptions during event creation are handled gracefully""" + # Create a context that will raise an exception when accessed + context = MagicMock() + context.user = "test_user" + context.remote_address = "1.1.1.1" + context.source = "test_source" + # Make get_user_agent raise an exception + context.get_user_agent.side_effect = Exception("Test exception") + + metadata = {"test": "value"} + + # This should not raise an exception, but return None + event = create_attack_wave_event(context, metadata) + + # Since we're mocking and causing an exception, the function should handle it + # and return None based on the exception handling in the function + assert event is None + + +def test_extract_request_if_possible_with_valid_context(): + """Test request extraction with valid context""" + context = test_utils.generate_context( + ip="198.51.100.23", + route="/test-route", + headers={"user-agent": "Mozilla/5.0"}, + ) + + request = extract_request_if_possible(context) + + assert request is not None + assert request["ipAddress"] == "198.51.100.23" + assert request["source"] == "flask" + assert request["userAgent"] == "Mozilla/5.0" + + +def test_extract_request_if_possible_with_none_context(): + """Test request extraction with None context""" + request = extract_request_if_possible(None) + assert request is None + + +def test_extract_request_if_possible_with_minimal_context(): + """Test request extraction with minimal context data""" + context = test_utils.generate_context() + + request = extract_request_if_possible(context) + + assert request is not None + assert request["ipAddress"] == "1.1.1.1" + assert request["source"] == "flask" + assert request["userAgent"] is None + + +def test_create_attack_wave_event_empty_metadata(): + """Test attack wave event creation with empty metadata""" + metadata = {} + context = test_utils.generate_context() + + event = create_attack_wave_event(context, metadata) + + assert 
event is not None + assert event["attack"]["metadata"] == {} + assert event["request"] is not None + + +def test_create_attack_wave_event_complex_metadata(): + """Test attack wave event creation with complex nested metadata""" + metadata = { + "nested": {"key1": "value1", "key2": "value2"}, + "simple": "simple_value", + "json_string": "[1, 2, 3]", + "number_string": "42", + } + context = test_utils.generate_context() + + event = create_attack_wave_event(context, metadata) + + assert event["attack"]["metadata"] == metadata + assert event["attack"]["metadata"]["nested"]["key1"] == "value1" + assert event["attack"]["metadata"]["json_string"] == "[1, 2, 3]" diff --git a/aikido_zen/ratelimiting/lru_cache.py b/aikido_zen/ratelimiting/lru_cache.py index a4c5ec651..289aaa119 100644 --- a/aikido_zen/ratelimiting/lru_cache.py +++ b/aikido_zen/ratelimiting/lru_cache.py @@ -3,7 +3,7 @@ """ from collections import OrderedDict -from aikido_zen.helpers.get_current_unixtime_ms import get_unixtime_ms +import aikido_zen.helpers.get_current_unixtime_ms as internal_time class LRUCache: @@ -24,7 +24,8 @@ def get(self, key): if key in self.cache: # Check if the item is still valid based on TTL if ( - get_unixtime_ms(monotonic=True) - self.cache[key]["startTime"] + internal_time.get_unixtime_ms(monotonic=True) + - self.cache[key]["startTime"] < self.time_to_live_in_ms ): return self.cache[key]["value"] # Return the actual value @@ -39,7 +40,7 @@ def set(self, key, value): self.cache.popitem(last=False) # Remove the oldest item self.cache[key] = { "value": value, - "startTime": get_unixtime_ms(monotonic=True), + "startTime": internal_time.get_unixtime_ms(monotonic=True), } # Store value and timestamp def clear(self): diff --git a/aikido_zen/sources/functions/request_handler.py b/aikido_zen/sources/functions/request_handler.py index 31ecc0101..201e0eada 100644 --- a/aikido_zen/sources/functions/request_handler.py +++ b/aikido_zen/sources/functions/request_handler.py @@ -4,9 +4,14 @@ from 
aikido_zen.api_discovery.update_route_info import update_route_info_from_context from aikido_zen.helpers.is_useful_route import is_useful_route from aikido_zen.helpers.logging import logger +from aikido_zen.helpers.create_attack_wave_event import create_attack_wave_event from aikido_zen.thread.thread_cache import get_cache from .ip_allowed_to_access_route import ip_allowed_to_access_route import aikido_zen.background_process.comms as c +from ...background_process.commands import PutEventCommand +from ...helpers.ipc.send_payload import send_payload +from ...helpers.serialize_to_json import serialize_to_json +from ...storage.attack_wave_detector_store import attack_wave_detector_store def request_handler(stage, status_code=0): @@ -79,25 +84,38 @@ def pre_response(): if block_type == "bot-blocking": msg = "You are not allowed to access this resource because you have been identified as a bot." return msg, 403 + return None def post_response(status_code): - """Checks if the current route is useful""" + """Checks if the current route is useful and performs attack wave detection""" context = ctx.get_current_context() if not context: return route_metadata = context.get_route_metadata() + cache = get_cache() + if not cache: + return + + attack_wave = attack_wave_detector_store.is_attack_wave(context.remote_address) + if attack_wave: + cache.stats.on_detected_attack_wave(blocked=False) + + event = create_attack_wave_event(context, metadata={}) + logger.debug("Attack wave: %s", serialize_to_json(event)[:5000]) + + # Report in background to core (send event over IPC) + if c.get_comms() and event: + send_payload(c.get_comms(), PutEventCommand.generate(event)) + + # Check if the current route is useful for API discovery is_curr_route_useful = is_useful_route( status_code, context.route, context.method, ) - if not is_curr_route_useful: - return - - cache = get_cache() - if cache: + if is_curr_route_useful: cache.routes.increment_route(route_metadata) # api spec generation diff 
--git a/aikido_zen/sources/functions/request_handler_test.py b/aikido_zen/sources/functions/request_handler_test.py index 9af37f31f..3dad511f4 100644 --- a/aikido_zen/sources/functions/request_handler_test.py +++ b/aikido_zen/sources/functions/request_handler_test.py @@ -1,3 +1,5 @@ +import inspect + import pytest from unittest.mock import patch, MagicMock from aikido_zen.thread.thread_cache import get_cache, ThreadCache @@ -7,6 +9,9 @@ from ...context import Context, current_context from ...helpers.headers import Headers from ...storage.firewall_lists import FirewallLists +from ...vulnerabilities.attack_wave_detection.attack_wave_detector import ( + AttackWaveDetector, +) @pytest.fixture @@ -38,6 +43,8 @@ def __init__(self): self.firewall_lists = FirewallLists() self.conn_manager = MagicMock() self.conn_manager.firewall_lists = self.firewall_lists + self.conn_manager.attack_wave_detector = AttackWaveDetector() + self.attacks = [] def send_data_to_bg_process(self, action, obj, receive=False, timeout_in_sec=0.1): if action != "CHECK_FIREWALL_LISTS": @@ -125,7 +132,7 @@ def test_post_response_no_context(mock_get_comms): # Test firewall lists -def set_context(remote_address, user_agent=""): +def set_context(remote_address, user_agent="", route="/posts/:number"): headers = Headers() headers.store_header("USER_AGENT", user_agent) Context( @@ -138,9 +145,10 @@ def set_context(remote_address, user_agent=""): "body": None, "cookies": {}, "source": "flask", - "route": "/posts/:number", + "route": route, "user": None, "executed_middleware": False, + "parsed_userinput": {}, } ).set_as_current_context() @@ -170,7 +178,13 @@ def wrapper(*args, **kwargs): comms = MyMockComms() mock_comms.return_value = comms - return func(*args, firewall_lists=comms.firewall_lists, **kwargs) + sig = inspect.signature(func) + if "attacks" in sig.parameters: + kwargs["attacks"] = comms.attacks + if "firewall_lists" in sig.parameters: + kwargs["firewall_lists"] = comms.firewall_lists + + return 
func(*args, **kwargs) return wrapper @@ -579,3 +593,255 @@ def test_multiple_blocked(firewall_lists): assert request_handler("pre_response") is None set_context("fd00:1234:5678:9abc::2") assert request_handler("pre_response") is None + + +@patch_firewall_lists +def test_attack_wave_detection_in_post_response(firewall_lists): + """Test attack wave detection happens in post_response stage""" + set_context("1.1.1.1", route="/.env") + create_service_config() + + # Reset attack wave detector store for clean test + from aikido_zen.storage.attack_wave_detector_store import attack_wave_detector_store + + detector = attack_wave_detector_store._get_detector() + detector.suspicious_requests_map.clear() + detector.sent_events_map.clear() + + # Reset stats + get_cache().stats.clear() + + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 0, + "blocked": 0, + } + + # Call request_handler 15 times in post_response to trigger attack wave detection + for i in range(15): + request_handler("post_response", status_code=200) + + # The attack wave should be detected + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 1, + "blocked": 0, + } + + # now try again (should not be possible due to cooldown window) + for i in range(15): + request_handler("post_response", status_code=200) + + # Should still be 1 because of cooldown + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 1, + "blocked": 0, + } + + # now try with another IP + set_context("4.4.4.4", route="/.htaccess") + for i in range(15): + request_handler("post_response", status_code=200) + + # Should now be 2 (one for each IP) + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 2, + "blocked": 0, + } + + +@patch_firewall_lists +def test_attack_wave_detection_threshold_not_reached(firewall_lists): + """Test attack wave detection when threshold is not reached""" + set_context("2.2.2.2", route="/test") + 
create_service_config() + + # Reset attack wave detector store for clean test + from aikido_zen.storage.attack_wave_detector_store import attack_wave_detector_store + + detector = attack_wave_detector_store._get_detector() + detector.suspicious_requests_map.clear() + detector.sent_events_map.clear() + + # Reset stats + get_cache().stats.clear() + + # Call request_handler 14 times (below threshold of 15) + for i in range(14): + request_handler("post_response", status_code=200) + + # No attack wave should be detected + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 0, + "blocked": 0, + } + + +@patch_firewall_lists +def test_attack_wave_detection_with_cooldown(firewall_lists): + """Test attack wave detection respects cooldown period""" + set_context("3.3.3.3", route="/admin") + create_service_config() + + # Reset attack wave detector store for clean test + from aikido_zen.storage.attack_wave_detector_store import attack_wave_detector_store + + detector = attack_wave_detector_store._get_detector() + detector.suspicious_requests_map.clear() + detector.sent_events_map.clear() + + # Reset stats + get_cache().stats.clear() + + # Trigger first attack wave + for i in range(15): + request_handler("post_response", status_code=200) + + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 1, + "blocked": 0, + } + + # Try to trigger another attack wave immediately (should be blocked by cooldown) + for i in range(15): + request_handler("post_response", status_code=200) + + # Should still be 1 due to cooldown + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 1, + "blocked": 0, + } + + +@patch_firewall_lists +def test_attack_wave_detection_no_context(firewall_lists): + """Test attack wave detection when there is no context""" + create_service_config() + + # Reset stats + get_cache().stats.clear() + + # Call request_handler without context + with 
patch("aikido_zen.context.get_current_context", return_value=None): + request_handler("post_response", status_code=200) + + # No attack wave should be detected + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 0, + "blocked": 0, + } + + +@patch_firewall_lists +def test_attack_wave_detection_with_different_routes(firewall_lists): + """Test attack wave detection with different routes""" + create_service_config() + + # Reset attack wave detector store for clean test + from aikido_zen.storage.attack_wave_detector_store import attack_wave_detector_store + + detector = attack_wave_detector_store._get_detector() + detector.suspicious_requests_map.clear() + detector.sent_events_map.clear() + + # Reset stats + get_cache().stats.clear() + + # Test with different suspicious routes + suspicious_routes = ["/.env", "/.git/config", "/wp-admin/", "/admin/", "/.htaccess"] + + for route in suspicious_routes: + set_context("5.5.5.5", route=route) + for i in range(3): # 3 requests per route + request_handler("post_response", status_code=200) + + # Should have 15 total requests (3 * 5 routes) which should trigger attack wave + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 1, + "blocked": 0, + } + + +@patch_firewall_lists +def test_attack_wave_detection_with_null_ip(firewall_lists): + """Test attack wave detection with null/empty IP""" + set_context("", route="/test") # Empty IP + create_service_config() + + # Reset stats + get_cache().stats.clear() + + # Call request_handler multiple times with empty IP + for i in range(20): + request_handler("post_response", status_code=200) + + # No attack wave should be detected for null IP + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 0, + "blocked": 0, + } + + +@patch_firewall_lists +def test_attack_wave_detection_post_response_only(firewall_lists): + """Test that attack wave detection only happens in post_response stage""" + 
set_context("6.6.6.6", route="/.env") + create_service_config() + + # Reset attack wave detector store for clean test + from aikido_zen.storage.attack_wave_detector_store import attack_wave_detector_store + + detector = attack_wave_detector_store._get_detector() + detector.suspicious_requests_map.clear() + detector.sent_events_map.clear() + + # Reset stats + get_cache().stats.clear() + + # Call pre_response multiple times - should not trigger attack wave detection + for i in range(20): + request_handler("pre_response") + + # No attack wave should be detected in pre_response + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 0, + "blocked": 0, + } + + # Now call post_response to actually trigger detection + for i in range(15): + request_handler("post_response", status_code=200) + + # Attack wave should now be detected + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 1, + "blocked": 0, + } + + +@patch_firewall_lists +def test_attack_wave_detection_multiple_ips(firewall_lists): + """Test attack wave detection with multiple different IPs""" + create_service_config() + + # Reset attack wave detector store for clean test + from aikido_zen.storage.attack_wave_detector_store import attack_wave_detector_store + + detector = attack_wave_detector_store._get_detector() + detector.suspicious_requests_map.clear() + detector.sent_events_map.clear() + + # Reset stats + get_cache().stats.clear() + + # Test with multiple IPs, each making some requests + ips = ["8.8.8.8", "9.9.9.9", "10.10.10.10"] + + for ip in ips: + set_context(ip, route="/test") + for i in range(15): # 15 requests per IP (threshold) + request_handler("post_response", status_code=200) + + # Should have 45 total requests (15 * 3 IPs) which should trigger attack wave for each IP + assert get_cache().stats.get_record()["requests"]["attackWaves"] == { + "total": 3, # One attack wave per IP + "blocked": 0, + } diff --git 
a/aikido_zen/storage/attack_wave_detector_store.py b/aikido_zen/storage/attack_wave_detector_store.py new file mode 100644 index 000000000..47defeb11 --- /dev/null +++ b/aikido_zen/storage/attack_wave_detector_store.py @@ -0,0 +1,21 @@ +import threading +from aikido_zen.vulnerabilities.attack_wave_detection.attack_wave_detector import ( + AttackWaveDetector, +) + + +class AttackWaveDetectorStore: + def __init__(self): + self._detector = AttackWaveDetector() + self._lock = threading.RLock() # Reentrant lock for thread safety + + def is_attack_wave(self, ip: str) -> bool: + with self._lock: + return self._detector.is_attack_wave(ip) + + def _get_detector(self): + """Used in testing (internal)""" + return self._detector + + +attack_wave_detector_store = AttackWaveDetectorStore() diff --git a/aikido_zen/storage/attack_wave_detector_store_test.py b/aikido_zen/storage/attack_wave_detector_store_test.py new file mode 100644 index 000000000..dccf3effa --- /dev/null +++ b/aikido_zen/storage/attack_wave_detector_store_test.py @@ -0,0 +1,242 @@ +""" +Test cases for AttackWaveDetectorStore +""" + +import pytest +import threading +import time +from unittest.mock import patch +from .attack_wave_detector_store import ( + AttackWaveDetectorStore, + attack_wave_detector_store, +) + + +def test_attack_wave_detector_store_initialization(): + """Test that the store initializes correctly""" + store = AttackWaveDetectorStore() + assert store is not None + assert store._get_detector() is not None + + +def test_attack_wave_detector_store_singleton(): + """Test that the global singleton instance works""" + assert attack_wave_detector_store is not None + assert attack_wave_detector_store._get_detector() is not None + + +def test_is_attack_wave_basic_functionality(): + """Test basic attack wave detection functionality""" + store = AttackWaveDetectorStore() + + # Should return False for first few calls + assert not store.is_attack_wave("1.1.1.1") + assert not store.is_attack_wave("1.1.1.1") + 
+ # Call 12 more times to get to 14 total (still below threshold) + for _ in range(12): + result = store.is_attack_wave("1.1.1.1") + assert not result + + # The 15th call should trigger attack wave detection and return True + assert store.is_attack_wave("1.1.1.1") + + +def test_is_attack_wave_different_ips(): + """Test that different IPs are tracked separately""" + store = AttackWaveDetectorStore() + + # Call multiple times for different IPs + for _ in range(10): + store.is_attack_wave("1.1.1.1") + store.is_attack_wave("2.2.2.2") + + # Neither should trigger attack wave yet + assert not store.is_attack_wave("1.1.1.1") + assert not store.is_attack_wave("2.2.2.2") + + +def test_is_attack_wave_none_ip(): + """Test handling of None IP address""" + store = AttackWaveDetectorStore() + assert not store.is_attack_wave(None) + + +def test_is_attack_wave_empty_ip(): + """Test handling of empty IP address""" + store = AttackWaveDetectorStore() + assert not store.is_attack_wave("") + + +def test_thread_safety_multiple_threads(): + """Test thread safety with multiple threads accessing the store""" + store = AttackWaveDetectorStore() + + results = [] + threads = [] + + def worker(ip_suffix, result_list): + """Worker function that calls is_attack_wave multiple times""" + ip = f"192.168.1.{ip_suffix}" + for _ in range(5): + result = store.is_attack_wave(ip) + result_list.append((ip, result)) + time.sleep(0.001) # Small delay to simulate real usage + + # Create and start multiple threads + for i in range(5): + thread = threading.Thread(target=worker, args=(i, results)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify that we got results from all threads + assert len(results) == 25 # 5 threads * 5 calls each + + # Verify that no exceptions were raised (thread safety) + assert all(isinstance(result, tuple) for result in results) + + +def test_thread_safety_same_ip(): + """Test thread safety when 
multiple threads access the same IP""" + store = AttackWaveDetectorStore() + + results = [] + threads = [] + lock = threading.Lock() + + def worker(result_list): + """Worker function that calls is_attack_wave for the same IP""" + for _ in range(10): + result = store.is_attack_wave("10.0.0.1") + with lock: + result_list.append(result) + time.sleep(0.001) + + # Create and start multiple threads + for _ in range(3): + thread = threading.Thread(target=worker, args=(results,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify that we got results from all threads + assert len(results) == 30 # 3 threads * 10 calls each + + # Verify that no exceptions were raised + assert all(isinstance(result, bool) for result in results) + + +def test_attack_wave_cooldown(): + """Test that attack wave detection respects the cooldown period""" + store = AttackWaveDetectorStore() + + # Call 14 times to get close to threshold + for _ in range(14): + store.is_attack_wave("1.1.1.1") + + # The 15th call should trigger attack wave detection and return True + assert store.is_attack_wave("1.1.1.1") + + # Subsequent calls should return False due to cooldown + assert not store.is_attack_wave("1.1.1.1") + + +def test_attack_wave_time_frame(): + """Test that attack wave detection respects the time frame""" + store = AttackWaveDetectorStore() + + # Make some calls + for _ in range(5): + store.is_attack_wave("1.1.1.1") + + # Should not trigger attack wave yet + assert not store.is_attack_wave("1.1.1.1") + + # Wait for the time frame to expire (60 seconds) + # We can't actually wait 60 seconds in a test, but we can verify the behavior + # by checking that the detector is tracking the requests correctly + detector = store._get_detector() + assert detector.suspicious_requests_map.get("1.1.1.1") == 6 + + +def test__get_detector_returns_same_instance(): + """Test that _get_detector returns the same instance""" + store = 
AttackWaveDetectorStore() + detector1 = store._get_detector() + detector2 = store._get_detector() + assert detector1 is detector2 + + +def test_global_singleton_consistency(): + """Test that the global singleton is consistent""" + detector1 = attack_wave_detector_store._get_detector() + detector2 = attack_wave_detector_store._get_detector() + assert detector1 is detector2 + + +def test_attack_wave_detector_store_with_custom_parameters(): + """Test that custom parameters can be set via the detector""" + store = AttackWaveDetectorStore() + detector = store._get_detector() + + # Verify default parameters + assert detector.attack_wave_threshold == 15 + assert detector.attack_wave_time_frame == 60 * 1000 + assert detector.min_time_between_events == 20 * 60 * 1000 + + +def test_stress_test_high_concurrency(): + """Stress test with high concurrency""" + store = AttackWaveDetectorStore() + + results = [] + threads = [] + + def worker(worker_id): + """Worker function for stress test""" + try: + for i in range(10): + ip = f"192.168.{worker_id}.{i}" + result = store.is_attack_wave(ip) + results.append((worker_id, ip, result)) + except Exception as e: + results.append((worker_id, "error", str(e))) + + # Create many threads + for i in range(10): + thread = threading.Thread(target=worker, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify that all threads completed without errors + error_results = [r for r in results if r[1] == "error"] + assert len(error_results) == 0 + + # Verify we got expected number of results + assert len(results) == 100 # 10 threads * 10 IPs each + + +@patch("aikido_zen.storage.attack_wave_detector_store.AttackWaveDetector") +def test_mock_detector_integration(mock_detector_class): + """Test integration with mocked AttackWaveDetector""" + # Setup mock + mock_detector = mock_detector_class.return_value + mock_detector.is_attack_wave.return_value = True + + store 
= AttackWaveDetectorStore() + + # Should use the mocked detector + result = store.is_attack_wave("1.1.1.1") + assert result is True + mock_detector.is_attack_wave.assert_called_once_with("1.1.1.1") diff --git a/aikido_zen/storage/statistics/__init__.py b/aikido_zen/storage/statistics/__init__.py index 47e7dec3e..4529f8681 100644 --- a/aikido_zen/storage/statistics/__init__.py +++ b/aikido_zen/storage/statistics/__init__.py @@ -10,16 +10,26 @@ class Statistics: def __init__(self): self.total_hits = 0 + self.attacks_detected = 0 self.attacks_blocked = 0 + + self.attack_waves_detected = 0 + self.attack_waves_blocked = 0 + self.rate_limited_hits = 0 self.started_at = t.get_unixtime_ms() self.operations = Operations() def clear(self): self.total_hits = 0 + self.attacks_detected = 0 self.attacks_blocked = 0 + + self.attack_waves_detected = 0 + self.attack_waves_blocked = 0 + self.rate_limited_hits = 0 self.started_at = t.get_unixtime_ms() self.operations.clear() @@ -36,6 +46,11 @@ def on_detected_attack(self, blocked, operation): def on_rate_limit(self): self.rate_limited_hits += 1 + def on_detected_attack_wave(self, blocked: bool): + self.attack_waves_detected += 1 + if blocked: + self.attack_waves_blocked += 1 + def get_record(self): current_time = t.get_unixtime_ms() return { @@ -49,6 +64,10 @@ def get_record(self): "total": self.attacks_detected, "blocked": self.attacks_blocked, }, + "attackWaves": { + "total": self.attack_waves_detected, + "blocked": self.attack_waves_blocked, + }, }, "operations": dict(self.operations), } @@ -66,6 +85,8 @@ def empty(self): return False if self.attacks_detected > 0: return False + if self.attack_waves_detected > 0: + return False if len(self.operations) > 0: return False return True diff --git a/aikido_zen/storage/statistics/init_test.py b/aikido_zen/storage/statistics/init_test.py index 6ed93b271..9e7ed0f1c 100644 --- a/aikido_zen/storage/statistics/init_test.py +++ b/aikido_zen/storage/statistics/init_test.py @@ -15,6 +15,8 @@ def 
test_initialization(monkeypatch): assert stats.attacks_detected == 0 assert stats.attacks_blocked == 0 assert stats.started_at == 1234567890000 + assert stats.attack_waves_detected == 0 + assert stats.attack_waves_blocked == 0 assert isinstance(stats.operations, Operations) @@ -23,6 +25,7 @@ def test_clear(monkeypatch): stats.total_hits = 10 stats.attacks_detected = 5 stats.attacks_blocked = 3 + stats.on_detected_attack_wave(blocked=True) stats.operations.register_call("test", "sql_op") with test_utils.patch_time(time_s=1234567890): stats.clear() @@ -31,6 +34,8 @@ def test_clear(monkeypatch): assert stats.attacks_detected == 0 assert stats.attacks_blocked == 0 assert stats.started_at == 1234567890000 + assert stats.attack_waves_blocked == 0 + assert stats.attack_waves_detected == 0 assert stats.operations == {} @@ -50,6 +55,16 @@ def test_on_detected_attack(stats): assert stats.attacks_blocked == 1 +def test_on_detected_attack_wave(stats): + stats.on_detected_attack_wave(blocked=True) + assert stats.get_record()["requests"]["attackWaves"]["total"] == 1 + assert stats.get_record()["requests"]["attackWaves"]["blocked"] == 1 + + stats.on_detected_attack_wave(blocked=False) + assert stats.get_record()["requests"]["attackWaves"]["total"] == 2 + assert stats.get_record()["requests"]["attackWaves"]["blocked"] == 1 + + def test_get_record(monkeypatch): with test_utils.patch_time(time_s=1234567890): stats = Statistics() @@ -60,6 +75,8 @@ def test_get_record(monkeypatch): stats.on_detected_attack(blocked=True, operation="test.test") stats.attacks_detected = 5 stats.attacks_blocked = 3 + stats.on_detected_attack_wave(False) + stats.on_detected_attack_wave(False) with test_utils.patch_time(time_s=9999999999): record = stats.get_record() @@ -70,6 +87,8 @@ def test_get_record(monkeypatch): assert record["requests"]["aborted"] == 0 assert record["requests"]["attacksDetected"]["total"] == 5 assert record["requests"]["attacksDetected"]["blocked"] == 3 + assert 
record["requests"]["attackWaves"]["total"] == 2 + assert record["requests"]["attackWaves"]["blocked"] == 0 assert record["operations"] == { "test.test": { "attacksDetected": {"blocked": 1, "total": 1}, @@ -140,6 +159,12 @@ def test_empty(stats): assert stats.empty() == False +def test_empty_with_attack_waves(stats): + assert stats.empty() + stats.on_detected_attack_wave(blocked=False) + assert not stats.empty() + + def test_multiple_imports(stats): record1 = { "requests": { diff --git a/aikido_zen/thread/thread_cache_test.py b/aikido_zen/thread/thread_cache_test.py index 83b484643..78e69cecb 100644 --- a/aikido_zen/thread/thread_cache_test.py +++ b/aikido_zen/thread/thread_cache_test.py @@ -39,6 +39,7 @@ def test_initialization(thread_cache: ThreadCache): "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, + "attackWaves": {"total": 0, "blocked": 0}, } @@ -75,6 +76,7 @@ def test_reset(thread_cache: ThreadCache): "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, + "attackWaves": {"total": 0, "blocked": 0}, } @@ -99,6 +101,7 @@ def test_renew_with_no_comms(thread_cache: ThreadCache): "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, + "attackWaves": {"total": 0, "blocked": 0}, } @@ -280,6 +283,7 @@ def test_renew_called_with_correct_args(mock_get_comms, thread_cache: ThreadCach "rateLimited": 0, "aborted": 0, "attacksDetected": {"blocked": 1, "total": 3}, + "attackWaves": {"total": 0, "blocked": 0}, }, "operations": { "op1": { @@ -360,6 +364,7 @@ def test_sync_data_for_users(mock_get_comms, thread_cache: ThreadCache): "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, + "attackWaves": {"total": 0, "blocked": 0}, }, "operations": {}, }, @@ -410,6 +415,7 @@ def test_renew_called_with_empty_routes(mock_get_comms, thread_cache: ThreadCach "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, + "attackWaves": {"total": 0, "blocked": 0}, 
}, "operations": {}, }, @@ -448,6 +454,7 @@ def test_renew_called_with_no_requests(mock_get_comms, thread_cache: ThreadCache "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, + "attackWaves": {"total": 0, "blocked": 0}, }, "operations": {}, }, diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/__init__.py b/aikido_zen/vulnerabilities/attack_wave_detection/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/attack_wave_detector.py b/aikido_zen/vulnerabilities/attack_wave_detection/attack_wave_detector.py new file mode 100644 index 000000000..ece4b373e --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/attack_wave_detector.py @@ -0,0 +1,47 @@ +import aikido_zen.helpers.get_current_unixtime_ms as internal_time +from aikido_zen.ratelimiting.lru_cache import LRUCache + + +class AttackWaveDetector: + def __init__( + self, + attack_wave_threshold: int = 15, + attack_wave_time_frame: int = 60 * 1000, # 1 minute in ms + min_time_between_events: int = 20 * 60 * 1000, # 20 minutes in ms + max_lru_entries: int = 10_000, + ): + self.attack_wave_threshold = attack_wave_threshold + self.attack_wave_time_frame = attack_wave_time_frame + self.min_time_between_events = min_time_between_events + self.max_lru_entries = max_lru_entries + + self.suspicious_requests_map = LRUCache( + max_items=self.max_lru_entries, + time_to_live_in_ms=self.attack_wave_time_frame, + ) + self.sent_events_map = LRUCache( + max_items=self.max_lru_entries, + time_to_live_in_ms=self.min_time_between_events, + ) + + def is_attack_wave(self, ip: str) -> bool: + """ + Function gets called with IP if there is an attack wave request. 
+ """ + if not ip: + return False + + # Check if an event was sent recently + if self.sent_events_map.get(ip) is not None: + return False + + # Increment suspicious requests count -> there is a new or first suspicious request + suspicious_requests = (self.suspicious_requests_map.get(ip) or 0) + 1 + self.suspicious_requests_map.set(ip, suspicious_requests) + + if suspicious_requests < self.attack_wave_threshold: + return False + + # Mark event as sent + self.sent_events_map.set(ip, internal_time.get_unixtime_ms(monotonic=True)) + return True diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/attack_wave_detector_test.py b/aikido_zen/vulnerabilities/attack_wave_detection/attack_wave_detector_test.py new file mode 100644 index 000000000..a4edf0cbf --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/attack_wave_detector_test.py @@ -0,0 +1,138 @@ +import pytest +from unittest.mock import patch +from .attack_wave_detector import AttackWaveDetector + + +def new_attack_wave_detector(): + return AttackWaveDetector( + attack_wave_threshold=6, + attack_wave_time_frame=60 * 1000, + min_time_between_events=60 * 60 * 1000, + max_lru_entries=10_000, + ) + + +# Mock for get_unixtime_ms +def mock_get_unixtime_ms(monotonic=True, mock_time=0): + return mock_time + + +def test_no_ip_address(): + detector = new_attack_wave_detector() + assert not detector.is_attack_wave(None) + + +def test_a_web_scanner(): + detector = new_attack_wave_detector() + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + # Is true because the threshold is 6 + assert detector.is_attack_wave("::1") + # False again because event should have been sent last time + assert not detector.is_attack_wave("::1") + + +def test_a_web_scanner_with_delays(): + with patch( + 
"aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", + side_effect=lambda **kw: mock_get_unixtime_ms(**kw, mock_time=0), + ): + detector = new_attack_wave_detector() + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", + side_effect=lambda **kw: mock_get_unixtime_ms(**kw, mock_time=30 * 1000), + ): + assert not detector.is_attack_wave("::1") + # Is true because the threshold is 6 + assert detector.is_attack_wave("::1") + # False again because event should have been sent last time + assert not detector.is_attack_wave("::1") + + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", + side_effect=lambda **kw: mock_get_unixtime_ms(**kw, mock_time=60 * 60 * 1000), + ): + # Still false because minimum time between events is 1 hour + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", + side_effect=lambda **kw: mock_get_unixtime_ms(**kw, mock_time=92 * 60 * 1000), + ): + # Should resend event after 1 hour + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert detector.is_attack_wave("::1") + + +def test_a_slow_web_scanner_that_triggers_in_the_second_interval(): + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", + side_effect=lambda **kw: mock_get_unixtime_ms(**kw, mock_time=0), + ): + detector = new_attack_wave_detector() + assert not 
detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", + side_effect=lambda **kw: mock_get_unixtime_ms(**kw, mock_time=62 * 1000), + ): + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert detector.is_attack_wave("::1") + + +def test_a_slow_web_scanner_that_triggers_in_the_third_interval(): + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", + side_effect=lambda **kw: mock_get_unixtime_ms(**kw, mock_time=0), + ): + detector = new_attack_wave_detector() + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", + side_effect=lambda **kw: mock_get_unixtime_ms(**kw, mock_time=62 * 1000), + ): + # Still false: the count from the first interval has expired (1 minute time frame) and only 4 new requests were made + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", + side_effect=lambda **kw: mock_get_unixtime_ms(**kw, mock_time=124 * 1000), + ): + # The count starts over in the third interval, so the 6th request here triggers the first event + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert not detector.is_attack_wave("::1") + assert detector.is_attack_wave("::1") diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_method.py
b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_method.py new file mode 100644 index 000000000..9ffec2b97 --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_method.py @@ -0,0 +1,11 @@ +web_scan_methods = { + "BADMETHOD", + "BADHTTPMETHOD", + "BADDATA", + "BADMTHD", + "BDMTHD", +} + + +def is_web_scan_method(method: str) -> bool: + return method.upper() in web_scan_methods diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_method_test.py b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_method_test.py new file mode 100644 index 000000000..91479b6a8 --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_method_test.py @@ -0,0 +1,24 @@ +from aikido_zen.vulnerabilities.attack_wave_detection.is_web_scan_method import ( + is_web_scan_method, +) + + +def test_is_web_scan_method(): + assert is_web_scan_method("BADMETHOD") + assert is_web_scan_method("BADHTTPMETHOD") + assert is_web_scan_method("BADDATA") + assert is_web_scan_method("BADMTHD") + assert is_web_scan_method("BDMTHD") + + +def test_is_not_web_scan_method(): + assert not is_web_scan_method("GET") + assert not is_web_scan_method("POST") + assert not is_web_scan_method("PUT") + assert not is_web_scan_method("DELETE") + assert not is_web_scan_method("PATCH") + assert not is_web_scan_method("OPTIONS") + assert not is_web_scan_method("HEAD") + assert not is_web_scan_method("TRACE") + assert not is_web_scan_method("CONNECT") + assert not is_web_scan_method("PURGE") diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_path.py b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_path.py new file mode 100644 index 000000000..eacf23454 --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_path.py @@ -0,0 +1,47 @@ +from aikido_zen.vulnerabilities.attack_wave_detection.paths import ( + file_names, + directory_names, +) + +file_extensions = { + "env", + "bak", +
"sql", + "sqlite", + "sqlite3", + "db", + "old", + "save", + "orig", + "sqlitedb", + "sqlite3db", +} +filenames = {name.lower() for name in file_names} +directories = {name.lower() for name in directory_names} + + +def is_web_scan_path(path: str) -> bool: + """ + is_web_scan_path gets the current route and wants to determine whether it's a test by some web scanner. + Checks filename if it exists (list of suspicious filenames & list of suspicious extensions) + Checks all other segments for suspicious directories + """ + normalized = path.lower() + segments = normalized.split("/") + if not segments: + return False + + filename = segments[-1] + if filename: + if filename in filenames: + return True + + if "." in filename: + ext = filename.split(".")[-1] + if ext in file_extensions: + return True + + for directory in segments: + if directory in directories: + return True + return False diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_path_test.py b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_path_test.py new file mode 100644 index 000000000..ee9d13fa1 --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scan_path_test.py @@ -0,0 +1,46 @@ +from aikido_zen.vulnerabilities.attack_wave_detection.is_web_scan_path import ( + is_web_scan_path, +) +from aikido_zen.vulnerabilities.attack_wave_detection.paths import ( + file_names, + directory_names, +) + + +def test_is_web_scan_path(): + assert is_web_scan_path("/.env") + assert is_web_scan_path("/test/.env") + assert is_web_scan_path("/test/.env.bak") + assert is_web_scan_path("/.git/config") + assert is_web_scan_path("/.aws/config") + assert is_web_scan_path("/some/path/.git/test") + assert is_web_scan_path("/some/path/.gitlab-ci.yml") + assert is_web_scan_path("/some/path/.github/workflows/test.yml") + assert is_web_scan_path("/.travis.yml") + assert is_web_scan_path("/../example/") + assert is_web_scan_path("/./test") + assert is_web_scan_path("/Cargo.lock") +
assert is_web_scan_path("/System32/test") + + +def test_is_not_web_scan_path(): + assert not is_web_scan_path("/test/file.txt") + assert not is_web_scan_path("/some/route/to/file.txt") + assert not is_web_scan_path("/some/route/to/file.json") + assert not is_web_scan_path("/en") + assert not is_web_scan_path("/") + assert not is_web_scan_path("/test/route") + assert not is_web_scan_path("/static/file.css") + assert not is_web_scan_path("/static/file.a461f56e.js") + + +def test_no_duplicates_in_file_names(): + unique_file_names = set(file_names) + assert len(unique_file_names) == len(file_names), "File names should be unique" + + +def test_no_duplicates_in_directory_names(): + unique_directory_names = set(directory_names) + assert len(unique_directory_names) == len( + directory_names + ), "Directory names should be unique" diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scanner.py b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scanner.py new file mode 100644 index 000000000..e1638f89e --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scanner.py @@ -0,0 +1,20 @@ +from aikido_zen.context import Context +from aikido_zen.vulnerabilities.attack_wave_detection.is_web_scan_method import ( + is_web_scan_method, +) +from aikido_zen.vulnerabilities.attack_wave_detection.is_web_scan_path import ( + is_web_scan_path, +) +from aikido_zen.vulnerabilities.attack_wave_detection.query_params_contain_dangerous_strings import ( + query_params_contain_dangerous_strings, +) + + +def is_web_scanner(context: Context) -> bool: + if context.method and is_web_scan_method(context.method): + return True + if context.route and is_web_scan_path(context.route): + return True + if query_params_contain_dangerous_strings(context): + return True + return False diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scanner_benchmark_test.py b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scanner_benchmark_test.py new 
file mode 100644 index 000000000..9da02a694 --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scanner_benchmark_test.py @@ -0,0 +1,46 @@ +import time + +import pytest + +from aikido_zen.vulnerabilities.attack_wave_detection.is_web_scanner import ( + is_web_scanner, +) + + +class Context: + def __init__(self, route, method, query): + self.remote_address = "::1" + self.method = method + self.url = "http://example.com" + self.query = query + self.headers = {} + self.body = {} + self.cookies = {} + self.route_params = {} + self.source = "flask" + self.route = route + self.parsed_userinput = {} + + +def get_test_context(path="/", method="GET", query=None): + return Context(path, method, query) + + +# the CI/CD results here are very unreliable, locally this test passes consistently. +@pytest.mark.skip(reason="Skipping this test in CI/CD") +def test_performance(): + iterations = 25_000 + start = time.perf_counter_ns() + for _ in range(iterations): + is_web_scanner(get_test_context("/wp-config.php", "GET", {"test": "1"})) + is_web_scanner( + get_test_context("/vulnerable", "GET", {"test": "1'; DROP TABLE users; --"}) + ) + is_web_scanner(get_test_context("/", "GET", {"test": "1"})) + end = time.perf_counter_ns() + + total_time_ms = (end - start) / 1_000_000 + time_per_check_ms = total_time_ms / iterations / 3 + assert ( + time_per_check_ms < 0.006 + ), f"Took {time_per_check_ms:.6f}ms per check (max allowed: 0.006ms)" diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scanner_test.py b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scanner_test.py new file mode 100644 index 000000000..eb03a44f4 --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/is_web_scanner_test.py @@ -0,0 +1,43 @@ +import pytest +from .is_web_scanner import is_web_scanner + + +class Context: + def __init__(self, route, method, query): + self.remote_address = "::1" + self.method = method + self.url = "http://example.com" + self.query 
= query + self.headers = {} + self.body = {} + self.cookies = {} + self.route_params = {} + self.source = "flask" + self.route = route + self.parsed_userinput = {} + + +def get_test_context(path="/", method="GET", query=None): + return Context(path, method, query) + + +def test_is_web_scanner(): + assert is_web_scanner(get_test_context("/wp-config.php", "GET")) + assert is_web_scanner(get_test_context("/.env", "GET")) + assert is_web_scanner(get_test_context("/test/.env.bak", "GET")) + assert is_web_scanner(get_test_context("/.git/config", "GET")) + assert is_web_scanner(get_test_context("/.aws/config", "GET")) + assert is_web_scanner(get_test_context("/../secret", "GET")) + assert is_web_scanner(get_test_context("/", "BADMETHOD")) + assert is_web_scanner(get_test_context("/", "GET", {"test": "SELECT * FROM admin"})) + assert is_web_scanner(get_test_context("/", "GET", {"test": "../etc/passwd"})) + + +def test_is_not_web_scanner(): + assert not is_web_scanner(get_test_context("graphql", "POST")) + assert not is_web_scanner(get_test_context("/api/v1/users", "GET")) + assert not is_web_scanner(get_test_context("/public/index.html", "GET")) + assert not is_web_scanner(get_test_context("/static/js/app.js", "GET")) + assert not is_web_scanner(get_test_context("/uploads/image.png", "GET")) + assert not is_web_scanner(get_test_context("/", "GET", {"test": "1'"})) + assert not is_web_scanner(get_test_context("/", "GET", {"test": "abcd"})) diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/paths.py b/aikido_zen/vulnerabilities/attack_wave_detection/paths.py new file mode 100644 index 000000000..9bbd1089a --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/paths.py @@ -0,0 +1,354 @@ +# Sourced from AikidoSec/firewall-node : library/vulnerabilities/attack-wave-detection/paths/fileNames.ts +file_names = { + ".addressbook", + ".atom", + ".bashrc", + ".boto", + ".config", + ".config.json", + ".config.xml", + ".config.yaml", + ".config.yml", + 
".envrc", + ".eslintignore", + ".fbcindex", + ".forward", + ".gitattributes", + ".gitconfig", + ".gitignore", + ".gitkeep", + ".gitlab-ci.yaml", + ".gitlab-ci.yml", + ".gitmodules", + ".google_authenticator", + ".hgignore", + ".htaccess", + ".htpasswd", + ".htdigest", + ".ksh_history", + ".lesshst", + ".lhistory", + ".lighttpdpassword", + ".lldb-history", + ".lynx_cookies", + ".my.cnf", + ".mysql_history", + ".nano_history", + ".netrc", + ".node_repl_history", + ".npmrc", + ".nsconfig", + ".nsr", + ".password-store", + ".pearrc", + ".pgpass", + ".php_history", + ".pinerc", + ".proclog", + ".procmailrc", + ".profile", + ".psql_history", + ".python_history", + ".rediscli_history", + ".rhosts", + ".selected_editor", + ".sh_history", + ".sqlite_history", + ".svnignore", + ".tcshrc", + ".tmux.conf", + ".travis.yaml", + ".travis.yml", + ".viminfo", + ".vimrc", + ".www_acl", + ".wwwacl", + ".xauthority", + ".yarnrc", + ".zhistory", + ".zsh_history", + ".zshenv", + ".zshrc", + "Dockerfile", + "aws-key.yaml", + "aws-key.yml", + "aws.yaml", + "aws.yml", + "docker-compose.yaml", + "docker-compose.yml", + "npm-shrinkwrap.json", + "package-lock.json", + "package.json", + "phpinfo.php", + "wp-config.php", + "wp-config.php3", + "wp-config.php4", + "wp-config.php5", + "wp-config.phtml", + "composer.json", + "composer.lock", + "composer.phar", + "yarn.lock", + ".env.local", + ".env.development", + ".env.test", + ".env.production", + ".env.prod", + ".env.dev", + ".env.example", + "php.ini", + "wp-settings.php", + "config.asp", + "config_dev.asp", + "config-dev.asp", + "config.dev.asp", + "config_prod.asp", + "config-prod.asp", + "config.prod.asp", + "config.sample.asp", + "config-sample.asp", + "config_sample.asp", + "config_test.asp", + "config-test.asp", + "config.test.asp", + "config.ini", + "config_dev.ini", + "config-dev.ini", + "config.dev.ini", + "config_prod.ini", + "config-prod.ini", + "config.prod.ini", + "config.sample.ini", + "config-sample.ini", + "config_sample.ini", + 
"config_test.ini", + "config-test.ini", + "config.test.ini", + "config.json", + "config_dev.json", + "config-dev.json", + "config.dev.json", + "config_prod.json", + "config-prod.json", + "config.prod.json", + "config.sample.json", + "config-sample.json", + "config_sample.json", + "config_test.json", + "config-test.json", + "config.test.json", + "config.php", + "config_dev.php", + "config-dev.php", + "config.dev.php", + "config_prod.php", + "config-prod.php", + "config.prod.php", + "config.sample.php", + "config-sample.php", + "config_sample.php", + "config_test.php", + "config-test.php", + "config.test.php", + "config.pl", + "config_dev.pl", + "config-dev.pl", + "config.dev.pl", + "config_prod.pl", + "config-prod.pl", + "config.prod.pl", + "config.sample.pl", + "config-sample.pl", + "config_sample.pl", + "config_test.pl", + "config-test.pl", + "config.test.pl", + "config.py", + "config_dev.py", + "config-dev.py", + "config.dev.py", + "config_prod.py", + "config-prod.py", + "config.prod.py", + "config.sample.py", + "config-sample.py", + "config_sample.py", + "config_test.py", + "config-test.py", + "config.test.py", + "config.rb", + "config_dev.rb", + "config-dev.rb", + "config.dev.rb", + "config_prod.rb", + "config-prod.rb", + "config.prod.rb", + "config.sample.rb", + "config-sample.rb", + "config_sample.rb", + "config_test.rb", + "config-test.rb", + "config.test.rb", + "config.toml", + "config_dev.toml", + "config-dev.toml", + "config.dev.toml", + "config_prod.toml", + "config-prod.toml", + "config.prod.toml", + "config.sample.toml", + "config-sample.toml", + "config_sample.toml", + "config_test.toml", + "config-test.toml", + "config.test.toml", + "config.txt", + "config_dev.txt", + "config-dev.txt", + "config.dev.txt", + "config_prod.txt", + "config-prod.txt", + "config.prod.txt", + "config.sample.txt", + "config-sample.txt", + "config_sample.txt", + "config_test.txt", + "config-test.txt", + "config.test.txt", + "config.xml", + "config_dev.xml", + 
"config-dev.xml", + "config.dev.xml", + "config_prod.xml", + "config-prod.xml", + "config.prod.xml", + "config.sample.xml", + "config-sample.xml", + "config_sample.xml", + "config_test.xml", + "config-test.xml", + "config.test.xml", + "config.yaml", + "config_dev.yaml", + "config-dev.yaml", + "config.dev.yaml", + "config_prod.yaml", + "config-prod.yaml", + "config.prod.yaml", + "config.sample.yaml", + "config-sample.yaml", + "config_sample.yaml", + "config_test.yaml", + "config-test.yaml", + "config.test.yaml", + "config.yml", + "config_dev.yml", + "config-dev.yml", + "config.dev.yml", + "config_prod.yml", + "config-prod.yml", + "config.prod.yml", + "config.sample.yml", + "config-sample.yml", + "config_sample.yml", + "config_test.yml", + "config-test.yml", + "config.test.yml", + "boot.ini", + "gruntfile.js", + "localsettings.php", + "my.ini", + "npm-debug.log", + "parameters.yml", + "parameters.yaml", + "services.yml", + "services.yaml", + "web.config", + "webpack.config.js", + "config.old", + "config.inc.php", + "error.log", + "access.log", + ".DS_Store", + "passwd", + "win.ini", + "cmd.exe", + "my.cnf", + ".bash_history", + "docker-compose-dev.yml", + "docker-compose.override.yml", + "docker-compose.dev.yml", + "Cargo.lock", + "secrets.yml", + "secrets.yaml", + "docker-compose.staging.yml", + "docker-compose.production.yml", + "yaws-key.pem", + "mysql_config.ini", + "firewall.log", + "log4j.properties", + "serviceAccountCredentials.json", + "haproxy.cfg", + "service-account-credentials.json", + "vpn.log", + "system.log", + "webuser-auth.xml", + "fastcgi.conf", + "smb.conf", + "iis.log", + "pom.xml", + "openapi.json", + "vim_settings.xml", + "winscp.ini", + "ws_ftp.ini", +} + +# Sourced from AikidoSec/firewall-node : library/vulnerabilities/attack-wave-detection/paths/directoryNames.ts +directory_names = { + ".", + "..", + ".anydesk", + ".aptitude", + ".aws", + ".azure", + ".cache", + ".circleci", + ".config", + ".dbus", + ".docker", + ".drush", + ".gem", + 
".git", + ".github", + ".gnupg", + ".gsutil", + ".hg", + ".idea", + ".java", + ".kube", + ".lftp", + ".minikube", + ".npm", + ".nvm", + ".pki", + ".snap", + ".ssh", + ".subversion", + ".svn", + ".tconn", + ".thunderbird", + ".tor", + ".vagrant.d", + ".vidalia", + ".vim", + ".vmware", + ".vscode", + "apache", + "apache2", + "grub", + "System32", + "tmp", + "xampp", + "cgi-bin", + "%systemroot%", +} diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/query_params_contain_dangerous_strings.py b/aikido_zen/vulnerabilities/attack_wave_detection/query_params_contain_dangerous_strings.py new file mode 100644 index 000000000..172c11a59 --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/query_params_contain_dangerous_strings.py @@ -0,0 +1,43 @@ +from aikido_zen.context import Context +from aikido_zen.helpers.extract_strings_from_user_input import ( + extract_strings_from_user_input_cached, +) + +keywords = { + "SELECT (CASE WHEN", + "SELECT COUNT(", + "SLEEP(", + "WAITFOR DELAY", + "SELECT LIKE(CHAR(", + "INFORMATION_SCHEMA.COLUMNS", + "INFORMATION_SCHEMA.TABLES", + "MD5(", + "DBMS_PIPE.RECEIVE_MESSAGE", + "SYSIBM.SYSTABLES", + "RANDOMBLOB(", + "SELECT * FROM", + "1'='1", + "PG_SLEEP(", + "UNION ALL SELECT", + "../", +} + + +def query_params_contain_dangerous_strings(context: Context) -> bool: + """ + Check the query for some common SQL or path traversal patterns. + """ + if not context.query: + return False + + for s in extract_strings_from_user_input_cached(context.query, "query"): + # skipping strings that don't match the length, we chose to start with 5 since the + # smaller inputs like `../` and `MD5(` are usually followed with more data. 
+ if len(s) < 5 or len(s) > 1000: + continue + + s_upper = s.upper() + for keyword in keywords: + if keyword.upper() in s_upper: + return True + return False diff --git a/aikido_zen/vulnerabilities/attack_wave_detection/query_params_contain_dangerous_strings_test.py b/aikido_zen/vulnerabilities/attack_wave_detection/query_params_contain_dangerous_strings_test.py new file mode 100644 index 000000000..621df85fa --- /dev/null +++ b/aikido_zen/vulnerabilities/attack_wave_detection/query_params_contain_dangerous_strings_test.py @@ -0,0 +1,70 @@ +import pytest +from .query_params_contain_dangerous_strings import ( + query_params_contain_dangerous_strings, +) + + +class Context: + def __init__(self, query=None, body=None): + self.remote_address = "::1" + self.method = "GET" + self.url = "http://localhost:4000/test" + self.query = query or { + "test": "", + "utmSource": "newsletter", + "utmMedium": "electronicmail", + "utmCampaign": "test", + "utmTerm": "sql_injection", + } + self.headers = {"content-type": "application/json"} + self.body = body or {} + self.cookies = {} + self.route_params = {} + self.source = "express" + self.route = "/test" + + +def get_test_context(query): + return Context( + query={ + "test": query, + **{ + "utmSource": "newsletter", + "utmMedium": "electronicmail", + "utmCampaign": "test", + "utmTerm": "sql_injection", + }, + } + ) + + +def test_detects_injection_patterns(): + test_strings = [ + "' or '1'='1", + "1: SELECT * FROM users WHERE '1'='1'", + "', information_schema.tables", + "1' sleep(5)", + "WAITFOR DELAY 1", + "../etc/passwd", + ] + for s in test_strings: + ctx = get_test_context(s) + assert query_params_contain_dangerous_strings( + ctx + ), f"Expected '{s}' to match patterns" + + +def test_does_not_detect(): + non_matching = ["google.de", "some-string", "1", ""] + for s in non_matching: + ctx = get_test_context(s) + assert not query_params_contain_dangerous_strings( + ctx + ), f"Expected '{s}' to NOT match patterns" + + +def 
test_handles_empty_query_object(): + ctx = Context(query={}) + assert not query_params_contain_dangerous_strings( + ctx + ), "Expected empty query to NOT match injection patterns" diff --git a/end2end/django_mysql_test.py b/end2end/django_mysql_test.py index ed791dddd..06161b875 100644 --- a/end2end/django_mysql_test.py +++ b/end2end/django_mysql_test.py @@ -110,3 +110,4 @@ def test_initial_heartbeat(): assert req_stats["aborted"] == 0 assert req_stats["rateLimited"] == 0 assert req_stats["attacksDetected"] == {"blocked": 2, "total": 2} + assert req_stats["attackWaves"] == {"total": 0, "blocked": 0}