# Define the context variables that will hold our trace information.
# Providing a default value is important so that they can be accessed
# even when the context has not been explicitly set.
trace_id_var = contextvars.ContextVar('trace_id', default=None)
origin_var = contextvars.ContextVar('origin', default=None)


class trace_context:
    """
    A context manager and decorator to set the trace context for non-web operations.

    Parameters:
        origin: Optional label identifying where the work started (e.g.
            'dispatcher'). ``origin_var`` is only touched when this is truthy.
        trace_id: Optional trace ID to propagate. Accepts a ``uuid.UUID`` or any
            string ``uuid.UUID()`` can parse; it is stored in canonical
            lower-case hyphenated form. Anything else is discarded and a new
            UUID4 is generated instead.

    As a context manager the instance may be entered repeatedly (sequentially
    or nested); every exit restores exactly what the matching enter replaced.

    As a decorator, each call of the wrapped function runs in its own trace
    context. A valid ``trace_id`` passed to the decorator is reused for every
    call; otherwise every call gets a freshly generated ID.
    """

    def __init__(self, origin=None, trace_id=None):
        self.origin = origin
        # Stack of token groups: one list of ContextVar tokens per active
        # __enter__. A ContextVar token can only be reset once, so tokens must
        # never outlive the __exit__ that consumes them.
        self.tokens = []

        # Remember whether the caller supplied a usable ID so the decorator
        # form knows whether to reuse it or generate a new one per call.
        self._explicit_trace_id = self._canonical_uuid(trace_id)
        self.trace_id = self._explicit_trace_id or str(uuid.uuid4())

    @staticmethod
    def _canonical_uuid(candidate):
        """Return ``candidate`` as a canonical UUID string, or None if it is not a UUID."""
        if not candidate:
            return None
        if isinstance(candidate, uuid.UUID):
            return str(candidate)
        try:
            # uuid.UUID() tolerates braces, 'urn:uuid:' prefixes and upper case;
            # re-serialising guarantees only the canonical form is propagated.
            return str(uuid.UUID(candidate))
        except (ValueError, TypeError, AttributeError):
            # AttributeError/TypeError: non-string input such as int or bytes.
            return None

    def __enter__(self):
        entered = [trace_id_var.set(self.trace_id)]

        # Set the origin (e.g., 'dispatcher')
        if self.origin:
            entered.append(origin_var.set(self.origin))

        self.tokens.append(entered)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Undo only the most recent __enter__, newest token first, so the
        # variables return to exactly the state they had before it.
        if self.tokens:
            for token in reversed(self.tokens.pop()):
                token.var.reset(token)

    def __call__(self, func):
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A fresh manager per call keeps token bookkeeping private to the
            # call, so repeated and concurrent calls cannot trip over each other.
            with trace_context(origin=self.origin, trace_id=self._explicit_trace_id):
                return func(*args, **kwargs)

        return wrapper
class ObservabilityMiddleware:
    """
    One middleware that stands in for the whole observability stack.

    It wires together trace context, request profiling and SQL profiling so
    that settings only need to list this class rather than all three, and so
    the three layers can never be ordered incorrectly.
    """

    def __init__(self, get_response):
        # Layers are listed innermost first; wrapping them in this order makes
        # a request travel _TraceContextMiddleware -> _ProfileRequestMiddleware
        # -> _SQLProfilingMiddleware before it reaches get_response.
        innermost_first = (
            _SQLProfilingMiddleware,
            _ProfileRequestMiddleware,
            _TraceContextMiddleware,
        )
        chain = get_response
        for layer in innermost_first:
            chain = layer(chain)
        self.handler = chain

    def __call__(self, request):
        return self.handler(request)
+* `X-API-Time`: The total time taken to process the request, in seconds. +* `X-API-Node`: The cluster host ID of the node that served the request. + +### cProfile Support + +When the `ANSIBLE_BASE_CPROFILE_REQUESTS` setting is enabled, the middleware will also perform a cProfile analysis for each request. The resulting `.prof` file is saved to a temporary directory on the node that served the request, and its path is returned in the `X-API-CProfile-File` response header. The filename will include the request's `X-Request-ID`. + +To enable cProfile support, set the following in your Django settings: + +```python +# settings.py +ANSIBLE_BASE_CPROFILE_REQUESTS = True +``` + +> **Note:** Enabling cProfile has significant performance implications and is intended for temporary, live debugging sessions, not for permanent use in production environments. + +### SQL Profiling Support + +When the `ANSIBLE_BASE_SQL_PROFILING` setting is enabled, the middleware provides insights into the database queries executed during a request. It adds the following headers to the response: + +* `X-API-Query-Count`: The total number of database queries executed during the request. +* `X-API-Query-Time`: The total time spent on database queries, in seconds. + +It also injects contextual information as a comment into each SQL query, which is invaluable for debugging and tracing. Each value is single-quoted, and the `route` and `origin` values are URL-encoded and truncated before being embedded. For example: +`/* trace_id='b71696ed-c483-408d-9740-2e7935b4f2d9', route='api/v2/users/', origin='request' */ SELECT ...` + +To enable SQL profiling, set the following in your Django settings: + +```python +# settings.py +ANSIBLE_BASE_SQL_PROFILING = True +``` + +> **Note:** This feature is most effective when used in combination with your database's slow query logging capabilities. For high-traffic environments, consider configuring your database to log only a percentage of queries to manage logging overhead. 
+ +## `DABProfiler` + +For profiling non-HTTP contexts, such as background tasks or gRPC services, the `DABProfiler` class can be used directly. + +The profiler's cProfile functionality is controlled by the `ANSIBLE_BASE_CPROFILE_REQUESTS` setting. + +- When the setting is `True`, `profiler.stop()` returns a tuple of `(elapsed_time, cprofile_filename)`. +- When the setting is `False`, `profiler.stop()` returns `(elapsed_time, None)`. + +### Example Usage + +```python +from ansible_base.lib.middleware.profiling.profile_request import DABProfiler + +def my_background_task(): + profiler = DABProfiler() + profiler.start() + + # Your code here + + elapsed, cprofile_filename = profiler.stop() + + if cprofile_filename: + print(f"cProfile data saved to: {cprofile_filename}") + + print(f"Task took {elapsed:.3f}s to complete.") +``` + +## `trace_context` for Background Tasks + +For adding observability to non-HTTP contexts without the overhead of the `DABProfiler`, the `trace_context` context manager is the ideal tool. It ensures that background tasks can be traced with a unique request ID, just like the `ObservabilityMiddleware` does for web requests. + +This is particularly useful for background tasks, such as those initiated by the controller's dispatcher, where you want to correlate all log messages for a specific operation. + +### Example Usage + +Here's how you might use the `trace_context` manager in the controller's dispatcher to ensure that all work related to a specific job has a consistent trace ID. + +```python +# In a hypothetical controller dispatcher task +from ansible_base.lib.logging.context import trace_context + +def run_job(job_id, parent_trace_id=None): + """ + A background task that runs a job. + """ + # Use the parent_trace_id if it exists; otherwise, a new one will be generated. + # The origin is a string that identifies the source of the trace. 
+ with trace_context(origin='controller_dispatcher', trace_id=parent_trace_id): + # All logging within this block will now have the same trace_id. + # logger.info(f"Starting job {job_id}") + # ... do work ... + # logger.info(f"Finished job {job_id}") + pass +``` + +## Visualizing Profile Data + +The `.prof` files generated by the cProfile support can be analyzed with a variety of tools. + +### SnakeViz + +[SnakeViz](https://jiffyclub.github.io/snakeviz/) is a browser-based graphical viewer for the output of Python profilers. + +You can install it with pip: +```bash +pip install snakeviz +``` + +To visualize a profile file, run: +```bash +snakeviz /path/to/your/profile.prof +``` + +### pstats + +The standard library `pstats` module can also be used to read and manipulate profile data. + +```python +import pstats + +p = pstats.Stats('/path/to/your/profile.prof') +p.sort_stats('cumulative').print_stats(10) +``` + diff --git a/ansible_base/lib/middleware/profiling/profile_request.py b/ansible_base/lib/middleware/profiling/profile_request.py new file mode 100644 index 000000000..a59525c17 --- /dev/null +++ b/ansible_base/lib/middleware/profiling/profile_request.py @@ -0,0 +1,163 @@ +import cProfile +import logging +import os +import tempfile +import threading +import time +import uuid +from typing import Optional, Union +from urllib.parse import quote + +from django.conf import settings +from django.db import connection +from django.utils.translation import gettext_lazy as _ + +from ansible_base.lib.logging.context import origin_var, trace_id_var +from ansible_base.lib.utils.settings import get_function_from_setting, get_setting + +logger = logging.getLogger(__name__) + + +class DABProfiler: + def __init__(self, *args, **kwargs): + self.cprofiling = bool(get_setting('ANSIBLE_BASE_CPROFILE_REQUESTS', False)) + self.prof = None + self.start_time = None + + def start(self): + self.start_time = time.time() + if self.cprofiling: + self.prof = cProfile.Profile() + 
# Define the maximum length for a value in a SQL comment
SQL_COMMENT_MAX_LENGTH = 256


def _sanitize_for_sql_comment(value: str) -> str:
    """
    Sanitizes a string for safe inclusion in a SQL comment.

    - URL-encodes the value to handle special characters (so ``*/`` can never
      terminate the comment early).
    - Escapes the '%' character to prevent conflicts with database placeholders.
    - Truncates the result to at most SQL_COMMENT_MAX_LENGTH characters without
      ever cutting an escape sequence in half.

    Returns the sanitized string; non-string input is converted with ``str()``.
    """
    # After both steps every escape is exactly '%%' followed by two hex digits.
    escaped = quote(str(value)).replace('%', '%%')
    if len(escaped) <= SQL_COMMENT_MAX_LENGTH:
        return escaped

    truncated = escaped[:SQL_COMMENT_MAX_LENGTH]

    # The cut may land inside an escape. A lone trailing '%' is the dangerous
    # case: the database driver would read it as the start of a placeholder
    # and fail the whole query.
    if truncated.endswith('%') and not truncated.endswith('%%'):
        return truncated[:-1]

    # Otherwise drop a trailing escape that lost some of its hex digits.
    last_escape = truncated.rfind('%%')
    if last_escape != -1 and len(truncated) - last_escape < 4:
        return truncated[:last_escape]

    return truncated


class SQLQueryMetrics:
    """
    Database execute-wrapper that counts and times queries and tags each one
    with a leading SQL comment describing the current trace context.

    Attributes:
        request: The request being served, or None outside of a request.
        query_count: Number of queries executed through this wrapper.
        query_time: Accumulated time spent in those queries, in seconds.
    """

    def __init__(self, request=None):
        self.request = request
        self.query_count = 0
        self.query_time = 0.0

    def __call__(self, execute, sql, params, many, context):
        # Build the context comment
        context_items = []

        # The middleware only stores validated UUIDs, but the context variable
        # can be set directly by any caller, so sanitize it like every other
        # value. A real UUID passes through unchanged.
        if trace_id := trace_id_var.get():
            context_items.append(f"trace_id='{_sanitize_for_sql_comment(trace_id)}'")

        # The route is only available after the URL resolver has run
        if self.request and getattr(self.request, 'resolver_match', None):
            if route := self.request.resolver_match.route:
                context_items.append(f"route='{_sanitize_for_sql_comment(route)}'")

        if origin := origin_var.get():
            context_items.append(f"origin='{_sanitize_for_sql_comment(origin)}'")

        if context_items:
            comment = f"/* {', '.join(context_items)} */"
            sql = f"{comment} {sql}"

        # A monotonic clock keeps durations correct across wall-clock changes.
        started = time.perf_counter()
        try:
            return execute(sql, params, many, context)
        finally:
            # Count and time the query even when it raises.
            self.query_count += 1
            self.query_time += time.perf_counter() - started
class _TraceContextMiddleware:
    """
    Middleware that establishes the trace context for a web request.

    It marks the origin as 'request', takes the trace ID from the incoming
    ``X-Request-ID`` header when that header holds a UUID (generating a new
    UUID4 otherwise), exposes both through the logging context variables while
    the request is handled, and returns the ID in the ``X-Request-ID``
    response header.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    @staticmethod
    def _resolve_trace_id(header_value):
        """
        Return the trace ID to use for a request.

        ``header_value`` is the raw ``X-Request-ID`` header (or None). A value
        that parses as a UUID is returned in canonical lower-case hyphenated
        form; anything else is discarded and a new UUID4 string is returned.
        """
        if header_value:
            try:
                # uuid.UUID() also accepts braces, 'urn:uuid:' prefixes and
                # upper case. Re-serialising ensures only the canonical form is
                # ever logged, embedded in SQL comments or echoed to the client.
                return str(uuid.UUID(header_value))
            except ValueError:
                # Not a UUID: fall through and generate a fresh one.
                pass
        return str(uuid.uuid4())

    def __call__(self, request):
        trace_id = self._resolve_trace_id(request.headers.get('x-request-id'))

        # Set the context for the request and keep the tokens for the reset
        origin_token = origin_var.set('request')
        trace_id_token = trace_id_var.set(trace_id)

        try:
            response = self.get_response(request)
            response['X-Request-ID'] = trace_id
        finally:
            # Reset the context variables to their previous state, even if the
            # downstream handler raised.
            trace_id_var.reset(trace_id_token)
            origin_var.reset(origin_token)

        return response
class TestTraceContextThreadSafety(unittest.TestCase):
    """
    Checks that context variables stay isolated between threads.
    """

    def test_trace_id_is_thread_safe(self):
        """
        Each thread must read back the trace_id it set itself, never a value set by another thread.
        """
        outcomes = []

        def worker(thread_id):
            expected = f"trace-id-{thread_id}"
            # Tag this thread with its own trace ID
            trace_id_var.set(expected)
            # A short random pause gives the other threads a chance to interleave
            time.sleep(random.uniform(0.01, 0.05))
            # Record whether the value survived untouched; the main thread checks it
            outcomes.append(trace_id_var.get() == expected)

        workers = [threading.Thread(target=worker, args=(index,)) for index in range(10)]
        for started in workers:
            started.start()
        for finished in workers:
            finished.join()

        # Every thread must have reported, and every report must be a success
        self.assertEqual(len(outcomes), 10, "Not all threads completed successfully.")
        self.assertTrue(all(outcomes), "Context leaked between threads.")
+ """ + provided_id = str(uuid.uuid4()) + assert trace_id_var.get() is None + with trace_context(origin='test_origin', trace_id=provided_id): + assert trace_id_var.get() == provided_id + assert trace_id_var.get() is None + + def test_handles_invalid_trace_id(self): + """ + Test that the context manager generates a new trace_id if the provided one is invalid. + """ + invalid_id = 'not-a-uuid' + assert trace_id_var.get() is None + with trace_context(origin='test_origin', trace_id=invalid_id): + generated_id = trace_id_var.get() + assert generated_id is not None + assert generated_id != invalid_id + assert isinstance(uuid.UUID(generated_id), uuid.UUID) + assert trace_id_var.get() is None + + def test_resets_context_on_exception(self): + """ + Test that context variables are reset even if an exception is raised. + """ + assert trace_id_var.get() is None + with pytest.raises(ValueError): + with trace_context(origin='test_exception'): + raise ValueError("Test exception") + assert trace_id_var.get() is None + + def test_as_decorator(self): + """ + Test that the trace_context decorator sets and clears context correctly. + """ + + @trace_context(origin='test_decorator') + def my_function(): + assert trace_id_var.get() is not None + assert origin_var.get() == 'test_decorator' + + assert trace_id_var.get() is None + my_function() + assert trace_id_var.get() is None + + def test_decorator_with_provided_id(self): + """ + Test that the trace_context decorator uses a provided trace_id. + """ + provided_id = str(uuid.uuid4()) + + @trace_context(origin='test_decorator_id', trace_id=provided_id) + def my_function(): + assert trace_id_var.get() == provided_id + assert origin_var.get() == 'test_decorator_id' + + assert trace_id_var.get() is None + my_function() + assert trace_id_var.get() is None + + def test_nested_trace_context(self): + """ + Test that nested trace_context managers work correctly, restoring the previous context. 
+ """ + outer_id = str(uuid.uuid4()) + with trace_context(origin='outer', trace_id=outer_id): + assert trace_id_var.get() == outer_id + assert origin_var.get() == 'outer' + + with trace_context(origin='inner'): + inner_id = trace_id_var.get() + assert inner_id is not None + assert inner_id != outer_id + assert origin_var.get() == 'inner' + + assert trace_id_var.get() == outer_id + assert origin_var.get() == 'outer' + + assert trace_id_var.get() is None diff --git a/test_app/tests/lib/middleware/test_profiling_middleware.py b/test_app/tests/lib/middleware/test_profiling_middleware.py new file mode 100644 index 000000000..86e59e62d --- /dev/null +++ b/test_app/tests/lib/middleware/test_profiling_middleware.py @@ -0,0 +1,242 @@ +import os +import tempfile +import uuid +from unittest.mock import patch + +from django.http import HttpResponse +from django.test import TestCase, override_settings +from django.urls import path + +from ansible_base.lib.middleware.observability import ObservabilityMiddleware +from ansible_base.lib.middleware.profiling.profile_request import ( + SQLQueryMetrics, + _ProfileRequestMiddleware, + _SQLProfilingMiddleware, +) +from test_app.models import Organization, User + + +# A simple view for testing middleware +def simple_view(request): + return HttpResponse("OK") + + +# A view that performs a database query +def db_view(request): + # Get or create an organization to guarantee at least one query is executed. + Organization.objects.get_or_create(name=f"test-org-{uuid.uuid4()}") + return HttpResponse("OK") + + +# Define URL patterns for the test +urlpatterns = [ + path('test/', simple_view), + path('test-db/', db_view), +] + + +@override_settings(ROOT_URLCONF=__name__) +class _ProfileRequestMiddlewareTest(TestCase): + @override_settings(CLUSTER_HOST_ID='test-node') + def test_profile_request_middleware_headers(self): + """ + Test that the _ProfileRequestMiddleware adds sensible headers. 
+ """ + middleware = _ProfileRequestMiddleware(simple_view) + response = middleware(self.client.get('/test/').wsgi_request) + + # Test X-API-Time + self.assertIn('X-API-Time', response) + self.assertTrue(response['X-API-Time'].endswith('s')) + try: + float(response['X-API-Time'][:-1]) + except ValueError: + self.fail("X-API-Time value is not a valid float") + + # Test X-API-Node + self.assertIn('X-API-Node', response) + self.assertEqual(response['X-API-Node'], 'test-node') + + @override_settings(ANSIBLE_BASE_CPROFILE_REQUESTS=True) + def test_profile_request_middleware_cprofile_enabled(self): + """ + Test that the _ProfileRequestMiddleware adds the X-API-CProfile-File + header and creates a profile file when enabled. + """ + with tempfile.TemporaryDirectory() as tmpdir: + with patch('tempfile.gettempdir', return_value=tmpdir): + middleware = _ProfileRequestMiddleware(simple_view) + response = middleware(self.client.get('/test/').wsgi_request) + self.assertIn('X-API-CProfile-File', response) + profile_file = response['X-API-CProfile-File'] + self.assertTrue(profile_file.endswith('.prof')) + self.assertTrue(os.path.exists(profile_file)) + + @override_settings(ANSIBLE_BASE_CPROFILE_REQUESTS=False) + def test_profile_request_middleware_cprofile_disabled(self): + """ + Test that the _ProfileRequestMiddleware does not add the + X-API-CProfile-File header when disabled. 
+ """ + middleware = _ProfileRequestMiddleware(simple_view) + response = middleware(self.client.get('/test/').wsgi_request) + self.assertNotIn('X-API-CProfile-File', response) + + +@override_settings( + ROOT_URLCONF=__name__, + MIDDLEWARE=[ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'ansible_base.lib.middleware.request_context._TraceContextMiddleware', + 'ansible_base.lib.middleware.profiling.profile_request._SQLProfilingMiddleware', + ], +) +class _SQLProfilingMiddlewareTest(TestCase): + def setUp(self): + self.user = User.objects.create_user(username='testuser', password='password') + self.client.force_login(self.user) + + @override_settings(ANSIBLE_BASE_SQL_PROFILING=False) + def test_sql_profiling_disabled_by_default(self): + response = self.client.get('/test-db/') + self.assertNotIn('X-API-Query-Count', response) + self.assertNotIn('X-API-Query-Time', response) + + @override_settings(ANSIBLE_BASE_SQL_PROFILING=True) + def test_sql_profiling_enabled_with_new_setting(self): + response = self.client.get('/test-db/') + self.assertIn('X-API-Query-Count', response) + self.assertGreaterEqual(int(response['X-API-Query-Count']), 1) + self.assertIn('X-API-Query-Time', response) + self.assertTrue(response['X-API-Query-Time'].endswith('s')) + + +@override_settings( + ROOT_URLCONF=__name__, + MIDDLEWARE=[ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'ansible_base.lib.middleware.profiling.profile_request._SQLProfilingMiddleware', + ], + ANSIBLE_BASE_SQL_PROFILING=True, +) +class _SQLProfilingMiddlewareMissingContextTest(TestCase): + def setUp(self): + self.user = User.objects.create_user(username='testuser', password='password') + self.client.force_login(self.user) + + @patch('ansible_base.lib.middleware.profiling.profile_request.logger') + def test_logs_warning_if_context_middleware_is_missing(self, mock_logger): + 
class SQLQueryMetricsTest(TestCase):
    def test_sql_comment_injection(self):
        """
        Calling SQLQueryMetrics directly must prepend a context comment to the SQL it executes.
        """
        from django.test.client import RequestFactory

        from ansible_base.lib.logging.context import origin_var, trace_id_var

        # Put known values in the context; the tokens let us undo this afterwards.
        saved_trace = trace_id_var.set("test-trace-id")
        saved_origin = origin_var.set("test-origin")

        # A request with a hand-made resolver_match, so no URL resolution is needed.
        fake_request = RequestFactory().get('/test-db/')
        fake_request.resolver_match = type('ResolverMatch', (), {'route': 'test/route'})

        try:
            base_sql = "SELECT 1"
            executed = []

            def recording_execute(sql, params, many, context):
                # Stand-in for the database: just remember what it was handed.
                executed.append(sql)
                return None

            wrapper = SQLQueryMetrics(fake_request)
            wrapper(recording_execute, base_sql, [], False, {})

            final_sql = executed[-1] if executed else ""

            # The comment delimiters, every context item and the original
            # statement must all be present in what reached the executor.
            expected_fragments = (
                "/*",
                "trace_id='test-trace-id'",
                "route='test/route'",
                "origin='test-origin'",
                "*/",
                base_sql,
            )
            for fragment in expected_fragments:
                self.assertIn(fragment, final_sql)
        finally:
            # Restore the context variables to whatever they held before the test.
            trace_id_var.reset(saved_trace)
            origin_var.reset(saved_origin)
class SQLCommentSanitizationTest(TestCase):
    """
    Unit tests for the helper that makes values safe to embed in a SQL comment.
    """

    def test_sanitization_escapes_disallowed_chars(self):
        """A comment-terminating payload comes back URL-encoded with doubled '%'."""
        from ansible_base.lib.middleware.profiling.profile_request import _sanitize_for_sql_comment

        payload = "*/; DROP TABLE users; --"
        expected = "%%2A/%%3B%%20DROP%%20TABLE%%20users%%3B%%20--"
        self.assertEqual(_sanitize_for_sql_comment(payload), expected)

    def test_sanitization_allows_safe_chars(self):
        """Characters that need no encoding pass through unchanged."""
        from ansible_base.lib.middleware.profiling.profile_request import _sanitize_for_sql_comment

        harmless = "a-b_c.d/e123"
        self.assertEqual(_sanitize_for_sql_comment(harmless), "a-b_c.d/e123")

    def test_sanitization_truncates_long_strings(self):
        """Over-long input is cut down to SQL_COMMENT_MAX_LENGTH characters."""
        from ansible_base.lib.middleware.profiling.profile_request import SQL_COMMENT_MAX_LENGTH, _sanitize_for_sql_comment

        oversized = "a" * (SQL_COMMENT_MAX_LENGTH + 100)
        result = _sanitize_for_sql_comment(oversized)
        self.assertEqual(len(result), SQL_COMMENT_MAX_LENGTH)
+) +class TraceContextMiddlewareTest(TestCase): + def test_uses_x_request_id_header(self): + """ + Test that the middleware uses the X-Request-ID from the request header + if it is provided. + """ + self.assertIsNone(trace_id_var.get()) + request_id = str(uuid4()) + + response = self.client.get('/context/', HTTP_X_REQUEST_ID=request_id) + self.assertEqual(response.status_code, 200) + + # The trace_id in the view should match the header + trace_id = json.loads(response.content).get("trace_id") + self.assertEqual(trace_id, request_id) + + # After the request, the context should be reset + self.assertIsNone(trace_id_var.get()) + + def test_generates_id_if_header_is_missing(self): + """ + Test that the middleware generates a new UUID if the X-Request-ID + header is not provided. + """ + self.assertIsNone(trace_id_var.get()) + + response = self.client.get('/context/') + self.assertEqual(response.status_code, 200) + + # The trace_id should be a valid UUID + trace_id = json.loads(response.content).get("trace_id") + self.assertIsNotNone(trace_id) + try: + UUID(trace_id, version=4) + except ValueError: + self.fail("trace_id is not a valid UUID4") + + # After the request, the context should be reset + self.assertIsNone(trace_id_var.get()) + + def test_context_does_not_bleed_between_requests(self): + """ + Test that the trace_id from one request does not bleed into the next. + """ + # First request has an X-Request-ID + request_id = str(uuid4()) + self.client.get('/context/', HTTP_X_REQUEST_ID=request_id) + + # Second request does not have the header + response = self.client.get('/context/') + self.assertEqual(response.status_code, 200) + + # The trace_id of the second request should be a new, generated UUID, + # not the one from the first request's header. 
+ trace_id_2 = json.loads(response.content).get("trace_id") + self.assertIsNotNone(trace_id_2) + self.assertNotEqual(trace_id_2, request_id) + try: + UUID(trace_id_2, version=4) + except ValueError: + self.fail("trace_id_2 is not a valid UUID4") + + def test_discards_invalid_uuid_in_header(self): + """ + Test that the middleware discards an invalid UUID in the X-Request-ID + header and generates a new, valid one. + """ + malicious_id = "not-a-uuid' --" + response = self.client.get('/context/', HTTP_X_REQUEST_ID=malicious_id) + self.assertEqual(response.status_code, 200) + + # The trace_id in the response should be a new, valid UUID, not the malicious one. + new_trace_id = response.headers.get("X-Request-ID") + self.assertIsNotNone(new_trace_id) + self.assertNotEqual(new_trace_id, malicious_id) + try: + UUID(new_trace_id, version=4) + except ValueError: + self.fail("The new trace_id is not a valid UUID4")