From 0184335dddb873b51d60a59c6f50c9212a029ca8 Mon Sep 17 00:00:00 2001 From: Elijah DeLee Date: Fri, 8 Aug 2025 14:33:26 -0400 Subject: [PATCH 1/7] feat(profiling): Add request profiling middleware This commit introduces `ProfileRequestMiddleware` to provide performance insights for API requests, migrating and generalizing functionality from AWX's `TimingMiddleware` and `AWXProfiler`. This makes the functionality available to any consumer of `django-ansible-base`. The middleware always adds an `X-API-Time` header to the response, indicating the total time taken to process the request. To help locate profiling data, it also adds an `X-API-Node` header with the cluster host ID if one is not already present in the response. When the `ANSIBLE_BASE_CPROFILE_REQUESTS` setting is enabled, the middleware will also perform a cProfile analysis for each request. The resulting `.prof` file is saved to a temporary directory on the node that served the request, and its path is returned in the `X-API-CProfile-File` response header. For the most accurate and reliable timing, it is recommended to add this middleware to the top of the `MIDDLEWARE` list in your settings. The core profiling logic is encapsulated in the `DABProfiler` class, which can be imported and used directly for profiling non-HTTP contexts, such as background tasks. This commit also includes a `README.md` file with documentation for the new middleware and profiler, including usage examples and tests in the test_app. --- .../lib/middleware/profiling/README.md | 88 +++++++++++++++++++ .../middleware/profiling/profile_request.py | 82 +++++++++++++++++ .../middleware/test_profiling_middleware.py | 67 ++++++++++++++ 3 files changed, 237 insertions(+) create mode 100644 ansible_base/lib/middleware/profiling/README.md create mode 100644 ansible_base/lib/middleware/profiling/profile_request.py create mode 100644 test_app/tests/lib/middleware/test_profiling_middleware.py diff --git a/ansible_base/lib/middleware/profiling/README.md b/ansible_base/lib/middleware/profiling/README.md new file mode 100644 index 000000000..3d01e29d8 --- /dev/null +++ b/ansible_base/lib/middleware/profiling/README.md @@ -0,0 +1,88 @@ +# Request Profiling + +The `ProfileRequestMiddleware` and `DABProfiler` class provide a way to profile requests and other code in your Django application. This functionality is a generalization of the profiling tools found in AWX and can be used by any `django-ansible-base` consumer. + +## `ProfileRequestMiddleware` + +This middleware provides performance insights for API requests. To use it, add it to your `MIDDLEWARE` list in your Django settings. For the most accurate and reliable timing, it is recommended to add this middleware to the top of the `MIDDLEWARE` list. + +```python +# settings.py +MIDDLEWARE = [ + 'ansible_base.lib.middleware.profiling.profile_request.ProfileRequestMiddleware', + ... +] +``` + +The middleware always adds the following headers to the response: + +* `X-API-Time`: The total time taken to process the request, in seconds. +* `X-API-Node`: The cluster host ID of the node that served the request. This header is only added if it is not already present in the response. + +### cProfile Support + +When the `ANSIBLE_BASE_CPROFILE_REQUESTS` setting is enabled, the middleware will also perform a cProfile analysis for each request. The resulting `.prof` file is saved to a temporary directory on the node that served the request, and its path is returned in the `X-API-CProfile-File` response header. + +To enable cProfile support, set the following in your Django settings: + +```python +# settings.py +ANSIBLE_BASE_CPROFILE_REQUESTS = True +``` + +## `DABProfiler` + +The core profiling logic is encapsulated in the `DABProfiler` class. This class can be imported and used directly for profiling non-HTTP contexts, such as background tasks or gRPC services. + +The profiler's cProfile functionality is controlled by the `ANSIBLE_BASE_CPROFILE_REQUESTS` setting. + +- When the setting is `True`, `profiler.stop()` returns a tuple of `(elapsed_time, cprofile_filename)`. +- When the setting is `False`, `profiler.stop()` returns `(elapsed_time, None)`. + +### Example Usage + +```python +from ansible_base.lib.middleware.profiling.profile_request import DABProfiler + +def my_background_task(): + profiler = DABProfiler() + profiler.start() + + # Your code here + + elapsed, cprofile_filename = profiler.stop() + + if cprofile_filename: + print(f"cProfile data saved to: {cprofile_filename}") + + print(f"Task took {elapsed:.3f}s to complete.") +``` + +## Visualizing Profile Data + +The `.prof` files generated by the cProfile support can be analyzed with a variety of tools. + +### SnakeViz + +[SnakeViz](https://jiffyclub.github.io/snakeviz/) is a browser-based graphical viewer for the output of Python profilers. + +You can install it with pip: +```bash +pip install snakeviz +``` + +To visualize a profile file, run: +```bash +snakeviz /path/to/your/profile.prof +``` + +### pstats + +The standard library `pstats` module can also be used to read and manipulate profile data. + +```python +import pstats + +p = pstats.Stats('/path/to/your/profile.prof') +p.sort_stats('cumulative').print_stats(10) +``` diff --git a/ansible_base/lib/middleware/profiling/profile_request.py b/ansible_base/lib/middleware/profiling/profile_request.py new file mode 100644 index 000000000..c20a8c658 --- /dev/null +++ b/ansible_base/lib/middleware/profiling/profile_request.py @@ -0,0 +1,82 @@ +import cProfile +import logging +import os +import tempfile +import threading +import time +import uuid +from typing import Optional, Union + +from django.utils.translation import gettext_lazy as _ + +from ansible_base.lib.utils.settings import get_function_from_setting, get_setting + +logger = logging.getLogger(__name__) + + +class DABProfiler: + def __init__(self, *args, **kwargs): + self.cprofiling = bool(get_setting('ANSIBLE_BASE_CPROFILE_REQUESTS', False)) + self.prof = None + self.start_time = None + + def start(self): + self.start_time = time.time() + if self.cprofiling: + self.prof = cProfile.Profile() + self.prof.enable() + + def stop(self, profile_id: Optional[Union[str, uuid.UUID]] = None): + if self.start_time is None: + logger.debug("Attempting to stop profiling without having started...") + return None, None + + elapsed = time.time() - self.start_time + + if not profile_id: + profile_id = uuid.uuid4() + + cprofile_filename = None + + if self.cprofiling and self.prof: + self.prof.disable() + temp_dir = tempfile.gettempdir() + filename = f"cprofile-{profile_id}.prof" + cprofile_filename = os.path.join(temp_dir, filename) + self.prof.dump_stats(cprofile_filename) + + return elapsed, cprofile_filename + + +class ProfileRequestMiddleware(threading.local): + def __init__(self, get_response=None): + self.get_response = get_response + self.profiler = DABProfiler() + + def __call__(self, request): + # Logic before the view (formerly process_request) + self.profiler.start() + request_id = request.headers.get('X-Request-ID') + + # Call the next middleware or the view + response = self.get_response(request) + + # Logic after the view (formerly process_response) + if getattr(self.profiler, 'start_time', None) is None: + return response + + elapsed, cprofile_filename = self.profiler.stop(profile_id=request_id) + + if elapsed is not None: + response['X-API-Time'] = f'{elapsed:.3f}s' + if 'X-API-Node' not in response: + response['X-API-Node'] = get_setting('CLUSTER_HOST_ID', _('Unknown')) + + if cprofile_filename: + response['X-API-CProfile-File'] = cprofile_filename + logger.debug( + f'request: {request}, cprofile_file: {response["X-API-CProfile-File"]}', + extra=dict(python_objects=dict(request=request, response=response, X_API_CPROFILE_FILE=response["X-API-CProfile-File"])), + ) + + return response \ No newline at end of file diff --git a/test_app/tests/lib/middleware/test_profiling_middleware.py b/test_app/tests/lib/middleware/test_profiling_middleware.py new file mode 100644 index 000000000..11f53c743 --- /dev/null +++ b/test_app/tests/lib/middleware/test_profiling_middleware.py @@ -0,0 +1,67 @@ + +import os +import tempfile +from unittest.mock import patch + +from django.http import HttpResponse +from django.test import TestCase, override_settings +from django.urls import path + +from ansible_base.lib.middleware.profiling.profile_request import ProfileRequestMiddleware + +# A simple view for testing middleware +def simple_view(request): + return HttpResponse("OK") + +# Define URL patterns for the test +urlpatterns = [ + path('test/', simple_view), +] + +@override_settings(ROOT_URLCONF=__name__) +class ProfileRequestMiddlewareTest(TestCase): + @override_settings(CLUSTER_HOST_ID='test-node') + def test_profile_request_middleware_headers(self): + """ + Test that the ProfileRequestMiddleware adds sensible headers. + """ + middleware = ProfileRequestMiddleware(simple_view) + response = middleware(self.client.get('/test/').wsgi_request) + + # Test X-API-Time + self.assertIn('X-API-Time', response) + self.assertTrue(response['X-API-Time'].endswith('s')) + try: + float(response['X-API-Time'][:-1]) + except ValueError: + self.fail("X-API-Time value is not a valid float") + + # Test X-API-Node + self.assertIn('X-API-Node', response) + self.assertEqual(response['X-API-Node'], 'test-node') + + @override_settings(ANSIBLE_BASE_CPROFILE_REQUESTS=True) + def test_profile_request_middleware_cprofile_enabled(self): + """ + Test that the ProfileRequestMiddleware adds the X-API-CProfile-File + header and creates a profile file when enabled. + """ + with tempfile.TemporaryDirectory() as tmpdir: + with patch('tempfile.gettempdir', return_value=tmpdir): + middleware = ProfileRequestMiddleware(simple_view) + response = middleware(self.client.get('/test/').wsgi_request) + self.assertIn('X-API-CProfile-File', response) + profile_file = response['X-API-CProfile-File'] + self.assertTrue(profile_file.endswith('.prof')) + self.assertTrue(os.path.exists(profile_file)) + + def test_profile_request_middleware_cprofile_disabled(self): + """ + Test that the ProfileRequestMiddleware does not add the + X-API-CProfile-File header when disabled. + """ + middleware = ProfileRequestMiddleware(simple_view) + response = middleware(self.client.get('/test/').wsgi_request) + self.assertNotIn('X-API-CProfile-File', response) + + From 2632ad97a5c7820c4f1092d62d03c03de9d83d3c Mon Sep 17 00:00:00 2001 From: Elijah DeLee Date: Mon, 11 Aug 2025 18:15:18 -0400 Subject: [PATCH 2/7] feat(profiling): Add reusable SQL profiling middleware This commit refactors the SQL query profiling logic out of the AWX controller's API view and into a reusable middleware in the django-ansible-base library. This makes the feature available to all AAP components, such as the gateway and EDA. SQLProfilingMiddleware adds `X-API-Query-Count` and `X-API-Query-Time` headers to API responses. Changes include: - A new `SQLProfilingMiddleware` class in the profiling module. - Configuration is controlled by a new, namespaced setting, `ANSIBLE_BASE_SQL_PROFILING`, with a fallback to the legacy `SQL_DEBUG` setting for backward compatibility. - The middleware logs a warning if it is enabled while Django's `DEBUG` setting is `False`, as query logging is disabled in that state. - Unit tests have been added to ensure correct functionality and prevent regressions. - The profiling documentation has been updated to include usage instructions for the new middleware. --- .../lib/middleware/profiling/README.md | 32 +++++++++ .../middleware/profiling/profile_request.py | 27 +++++++- .../middleware/test_profiling_middleware.py | 69 ++++++++++++++++++- 3 files changed, 125 insertions(+), 3 deletions(-) diff --git a/ansible_base/lib/middleware/profiling/README.md b/ansible_base/lib/middleware/profiling/README.md index 3d01e29d8..fd373a34d 100644 --- a/ansible_base/lib/middleware/profiling/README.md +++ b/ansible_base/lib/middleware/profiling/README.md @@ -30,6 +30,38 @@ To enable cProfile support, set the following in your Django settings: ANSIBLE_BASE_CPROFILE_REQUESTS = True ``` +## `SQLProfilingMiddleware` + +This middleware provides insights into the database queries executed during a request. When enabled, it adds the following headers to the response: + +* `X-API-Query-Count`: The total number of database queries executed during the request. +* `X-API-Query-Time`: The total time spent on database queries, in seconds. + +To use it, add it to your `MIDDLEWARE` list in your Django settings: + +```python +# settings.py +MIDDLEWARE = [ + ... + 'ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware', + ... +] +``` + +### Enabling SQL Profiling + +**Important:** This middleware relies on Django's `connection.queries` list, which is only populated when `settings.DEBUG` is set to `True`. Therefore, you must have `DEBUG = True` in your Django settings for this middleware to have any effect. + +The middleware is controlled by the `ANSIBLE_BASE_SQL_PROFILING` setting. For backwards compatibility, it will also be enabled if the standard Django `SQL_DEBUG` setting is `True`. + +To enable SQL profiling, set the following in your Django settings: + +```python +# settings.py +ANSIBLE_BASE_SQL_PROFILING = True +DEBUG = True +``` + ## `DABProfiler` The core profiling logic is encapsulated in the `DABProfiler` class. This class can be imported and used directly for profiling non-HTTP contexts, such as background tasks or gRPC services. diff --git a/ansible_base/lib/middleware/profiling/profile_request.py b/ansible_base/lib/middleware/profiling/profile_request.py index c20a8c658..585ca4396 100644 --- a/ansible_base/lib/middleware/profiling/profile_request.py +++ b/ansible_base/lib/middleware/profiling/profile_request.py @@ -7,10 +7,13 @@ import uuid from typing import Optional, Union +from django.db import connection +from django.conf import settings from django.utils.translation import gettext_lazy as _ from ansible_base.lib.utils.settings import get_function_from_setting, get_setting + logger = logging.getLogger(__name__) @@ -79,4 +82,26 @@ def __call__(self, request): extra=dict(python_objects=dict(request=request, response=response, X_API_CPROFILE_FILE=response["X-API-CProfile-File"])), ) - return response \ No newline at end of file + return response + + +class SQLProfilingMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + sql_profiling_enabled = get_setting('ANSIBLE_BASE_SQL_PROFILING', get_setting('SQL_DEBUG', False)) + if sql_profiling_enabled: + if not settings.DEBUG: + logger.warning("ANSIBLE_BASE_SQL_PROFILING is enabled, but DEBUG is False. No SQL queries will be logged or counted.") + return self.get_response(request) + + queries_before = len(connection.queries) + response = self.get_response(request) + q_times = [float(q['time']) for q in connection.queries[queries_before:]] + response['X-API-Query-Count'] = len(q_times) + response['X-API-Query-Time'] = '%0.3fs' % sum(q_times) + else: + response = self.get_response(request) + + return response diff --git a/test_app/tests/lib/middleware/test_profiling_middleware.py b/test_app/tests/lib/middleware/test_profiling_middleware.py index 11f53c743..5e07a04cd 100644 --- a/test_app/tests/lib/middleware/test_profiling_middleware.py +++ b/test_app/tests/lib/middleware/test_profiling_middleware.py @@ -1,4 +1,4 @@ - +import uuid import os import tempfile from unittest.mock import patch @@ -7,15 +7,24 @@ from django.test import TestCase, override_settings from django.urls import path -from ansible_base.lib.middleware.profiling.profile_request import ProfileRequestMiddleware +from ansible_base.lib.middleware.profiling.profile_request import ProfileRequestMiddleware, SQLProfilingMiddleware +from ansible_base.lib.utils.settings import get_setting +from test_app.models import User # A simple view for testing middleware def simple_view(request): return HttpResponse("OK") +# A view that performs a database query +def db_view(request): + # Create a user with a unique username to guarantee a query. + User.objects.create(username=f"test-{uuid.uuid4()}") + return HttpResponse("OK") + # Define URL patterns for the test urlpatterns = [ path('test/', simple_view), + path('test-db/', db_view), ] @override_settings(ROOT_URLCONF=__name__) @@ -65,3 +74,59 @@ def test_profile_request_middleware_cprofile_disabled(self): self.assertNotIn('X-API-CProfile-File', response) +@override_settings( + ROOT_URLCONF=__name__, + MIDDLEWARE=['ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware'] +) +class SQLProfilingMiddlewareTest(TestCase): + def test_sql_profiling_disabled_by_default(self): + """ + Test that the SQLProfilingMiddleware does not add headers when disabled. + """ + response = self.client.get('/test-db/') + self.assertNotIn('X-API-Query-Count', response) + self.assertNotIn('X-API-Query-Time', response) + + @override_settings(ANSIBLE_BASE_SQL_PROFILING=True, DEBUG=True) + def test_sql_profiling_enabled_with_new_setting(self): + """ + Test that the SQLProfilingMiddleware adds headers when ANSIBLE_BASE_SQL_PROFILING is True. + """ + response = self.client.get('/test-db/') + self.assertIn('X-API-Query-Count', response) + self.assertGreaterEqual(int(response['X-API-Query-Count']), 1) + self.assertIn('X-API-Query-Time', response) + self.assertTrue(response['X-API-Query-Time'].endswith('s')) + try: + float(response['X-API-Query-Time'][:-1]) + except ValueError: + self.fail("X-API-Query-Time value is not a valid float") + + @override_settings(SQL_DEBUG=True, DEBUG=True) + def test_sql_profiling_enabled_with_fallback_setting(self): + """ + Test that the SQLProfilingMiddleware adds headers when SQL_DEBUG is True as a fallback. + """ + response = self.client.get('/test-db/') + self.assertIn('X-API-Query-Count', response) + self.assertGreaterEqual(int(response['X-API-Query-Count']), 1) + self.assertIn('X-API-Query-Time', response) + self.assertTrue(response['X-API-Query-Time'].endswith('s')) + try: + float(response['X-API-Query-Time'][:-1]) + except ValueError: + self.fail("X-API-Query-Time value is not a valid float") + + @override_settings(ANSIBLE_BASE_SQL_PROFILING=True, DEBUG=False) + def test_sql_profiling_logs_warning_if_debug_is_false(self): + """ + Test that the SQLProfilingMiddleware logs a warning and does not add headers + if profiling is enabled but DEBUG is False. + """ + with self.assertLogs('ansible_base.lib.middleware.profiling.profile_request', level='WARNING') as cm: + response = self.client.get('/test-db/') + self.assertIn("ANSIBLE_BASE_SQL_PROFILING is enabled, but DEBUG is False", cm.output[0]) + self.assertNotIn('X-API-Query-Count', response) + self.assertNotIn('X-API-Query-Time', response) + + From 1e1157288d0ff8d2157ec7642052ffb78df4b92c Mon Sep 17 00:00:00 2001 From: Elijah DeLee Date: Tue, 12 Aug 2025 09:08:00 -0400 Subject: [PATCH 3/7] implement db cursor wrapper so don't need DEBUG=True --- .../lib/middleware/profiling/README.md | 5 +- .../middleware/profiling/profile_request.py | 34 ++++++++------ .../middleware/test_profiling_middleware.py | 46 ++++--------------- 3 files changed, 31 insertions(+), 54 deletions(-) diff --git a/ansible_base/lib/middleware/profiling/README.md b/ansible_base/lib/middleware/profiling/README.md index fd373a34d..6bcf09ac4 100644 --- a/ansible_base/lib/middleware/profiling/README.md +++ b/ansible_base/lib/middleware/profiling/README.md @@ -50,16 +50,13 @@ MIDDLEWARE = [ ### Enabling SQL Profiling -**Important:** This middleware relies on Django's `connection.queries` list, which is only populated when `settings.DEBUG` is set to `True`. Therefore, you must have `DEBUG = True` in your Django settings for this middleware to have any effect. - -The middleware is controlled by the `ANSIBLE_BASE_SQL_PROFILING` setting. For backwards compatibility, it will also be enabled if the standard Django `SQL_DEBUG` setting is `True`. +The middleware is controlled by the `ANSIBLE_BASE_SQL_PROFILING` setting. To enable SQL profiling, set the following in your Django settings: ```python # settings.py ANSIBLE_BASE_SQL_PROFILING = True -DEBUG = True ``` ## `DABProfiler` diff --git a/ansible_base/lib/middleware/profiling/profile_request.py b/ansible_base/lib/middleware/profiling/profile_request.py index 585ca4396..c6a4e7984 100644 --- a/ansible_base/lib/middleware/profiling/profile_request.py +++ b/ansible_base/lib/middleware/profiling/profile_request.py @@ -7,13 +7,12 @@ import uuid from typing import Optional, Union -from django.db import connection from django.conf import settings +from django.db import connection from django.utils.translation import gettext_lazy as _ from ansible_base.lib.utils.settings import get_function_from_setting, get_setting - logger = logging.getLogger(__name__) @@ -85,23 +84,32 @@ def __call__(self, request): return response +class SQLQueryMetrics: + def __init__(self): + self.query_count = 0 + self.query_time = 0.0 + + def __call__(self, execute, sql, params, many, context): + start_time = time.time() + try: + return execute(sql, params, many, context) + finally: + self.query_count += 1 + self.query_time += time.time() - start_time + + class SQLProfilingMiddleware: def __init__(self, get_response): self.get_response = get_response def __call__(self, request): - sql_profiling_enabled = get_setting('ANSIBLE_BASE_SQL_PROFILING', get_setting('SQL_DEBUG', False)) - if sql_profiling_enabled: - if not settings.DEBUG: - logger.warning("ANSIBLE_BASE_SQL_PROFILING is enabled, but DEBUG is False. No SQL queries will be logged or counted.") - return self.get_response(request) + if not get_setting('ANSIBLE_BASE_SQL_PROFILING', False): + return self.get_response(request) - queries_before = len(connection.queries) - response = self.get_response(request) - q_times = [float(q['time']) for q in connection.queries[queries_before:]] - response['X-API-Query-Count'] = len(q_times) - response['X-API-Query-Time'] = '%0.3fs' % sum(q_times) - else: + metrics = SQLQueryMetrics() + with connection.execute_wrapper(metrics): response = self.get_response(request) + response['X-API-Query-Count'] = metrics.query_count + response['X-API-Query-Time'] = f'{metrics.query_time:.3f}s' return response diff --git a/test_app/tests/lib/middleware/test_profiling_middleware.py b/test_app/tests/lib/middleware/test_profiling_middleware.py index 5e07a04cd..a28d8e64d 100644 --- a/test_app/tests/lib/middleware/test_profiling_middleware.py +++ b/test_app/tests/lib/middleware/test_profiling_middleware.py @@ -1,6 +1,6 @@ -import uuid import os import tempfile +import uuid from unittest.mock import patch from django.http import HttpResponse @@ -11,22 +11,26 @@ from ansible_base.lib.utils.settings import get_setting from test_app.models import User + # A simple view for testing middleware def simple_view(request): return HttpResponse("OK") + # A view that performs a database query def db_view(request): # Create a user with a unique username to guarantee a query. User.objects.create(username=f"test-{uuid.uuid4()}") return HttpResponse("OK") + # Define URL patterns for the test urlpatterns = [ path('test/', simple_view), path('test-db/', db_view), ] + @override_settings(ROOT_URLCONF=__name__) class ProfileRequestMiddlewareTest(TestCase): @override_settings(CLUSTER_HOST_ID='test-node') @@ -36,7 +40,7 @@ def test_profile_request_middleware_headers(self): """ middleware = ProfileRequestMiddleware(simple_view) response = middleware(self.client.get('/test/').wsgi_request) - + # Test X-API-Time self.assertIn('X-API-Time', response) self.assertTrue(response['X-API-Time'].endswith('s')) @@ -74,10 +78,7 @@ def test_profile_request_middleware_cprofile_disabled(self): self.assertNotIn('X-API-CProfile-File', response) -@override_settings( - ROOT_URLCONF=__name__, - MIDDLEWARE=['ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware'] -) +@override_settings(ROOT_URLCONF=__name__, MIDDLEWARE=['ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware']) class SQLProfilingMiddlewareTest(TestCase): def test_sql_profiling_disabled_by_default(self): """ @@ -87,25 +88,10 @@ def test_sql_profiling_disabled_by_default(self): self.assertNotIn('X-API-Query-Count', response) self.assertNotIn('X-API-Query-Time', response) - @override_settings(ANSIBLE_BASE_SQL_PROFILING=True, DEBUG=True) + @override_settings(ANSIBLE_BASE_SQL_PROFILING=True) def test_sql_profiling_enabled_with_new_setting(self): """ - Test that the SQLProfilingMiddleware adds headers when ANSIBLE_BASE_SQL_PROFILING is True. - """ - response = self.client.get('/test-db/') - self.assertIn('X-API-Query-Count', response) - self.assertGreaterEqual(int(response['X-API-Query-Count']), 1) - self.assertIn('X-API-Query-Time', response) - self.assertTrue(response['X-API-Query-Time'].endswith('s')) - try: - float(response['X-API-Query-Time'][:-1]) - except ValueError: - self.fail("X-API-Query-Time value is not a valid float") - - @override_settings(SQL_DEBUG=True, DEBUG=True) - def test_sql_profiling_enabled_with_fallback_setting(self): - """ - Test that the SQLProfilingMiddleware adds headers when SQL_DEBUG is True as a fallback. + Test that the SQLProfilingMiddleware adds headers when ANSIBLE_BASE_SQL_PROFILING is True """ response = self.client.get('/test-db/') self.assertIn('X-API-Query-Count', response) @@ -116,17 +102,3 @@ def test_sql_profiling_enabled_with_fallback_setting(self): float(response['X-API-Query-Time'][:-1]) except ValueError: self.fail("X-API-Query-Time value is not a valid float") - - @override_settings(ANSIBLE_BASE_SQL_PROFILING=True, DEBUG=False) - def test_sql_profiling_logs_warning_if_debug_is_false(self): - """ - Test that the SQLProfilingMiddleware logs a warning and does not add headers - if profiling is enabled but DEBUG is False. - """ - with self.assertLogs('ansible_base.lib.middleware.profiling.profile_request', level='WARNING') as cm: - response = self.client.get('/test-db/') - self.assertIn("ANSIBLE_BASE_SQL_PROFILING is enabled, but DEBUG is False", cm.output[0]) - self.assertNotIn('X-API-Query-Count', response) - self.assertNotIn('X-API-Query-Time', response) - - From 44e6171a0ab93c71191d93f75b3dc7432cfc3b7a Mon Sep 17 00:00:00 2001 From: Elijah DeLee Date: Tue, 12 Aug 2025 13:48:54 -0400 Subject: [PATCH 4/7] feat(middleware): Enhance SQL profiling with request tracing This enhances the `SQLProfilingMiddleware` to provide deeper observability by integrating it with a new request tracing context system. It introduces `TraceContextMiddleware`, which establishes a unique `trace_id` for each request using `contextvars`. This trace ID is then leveraged by the `SQLProfilingMiddleware` to inject a detailed comment, including the trace ID, route, and origin, into every SQL query. This allows for direct correlation between a web request and its database activity. Key changes include: - A new `TraceContextMiddleware` that sets and safely resets the request context. - The `SQLProfilingMiddleware` is updated to inject contextual SQL comments. - A new, isolated test suite for the `TraceContextMiddleware` to ensure it correctly handles the `X-Request-ID` header and prevents context bleeding between requests. - Existing tests were made more robust, and a state pollution issue was fixed by ensuring proper context cleanup in the test suite. --- ansible_base/lib/logging/context.py | 45 ++++++++ .../lib/middleware/profiling/README.md | 6 +- .../middleware/profiling/profile_request.py | 20 ++++ .../lib/middleware/request_context.py | 29 +++++ test_app/tests/lib/logging/test_context.py | 40 +++++++ .../middleware/test_profiling_middleware.py | 102 ++++++++++++++++-- .../lib/middleware/test_request_context.py | 88 +++++++++++++++ 7 files changed, 323 insertions(+), 7 deletions(-) create mode 100644 ansible_base/lib/logging/context.py create mode 100644 ansible_base/lib/middleware/request_context.py create mode 100644 test_app/tests/lib/logging/test_context.py create mode 100644 test_app/tests/lib/middleware/test_request_context.py diff --git a/ansible_base/lib/logging/context.py b/ansible_base/lib/logging/context.py new file mode 100644 index 000000000..89a93fcb2 --- /dev/null +++ b/ansible_base/lib/logging/context.py @@ -0,0 +1,45 @@ +import contextvars +import uuid + +# Define the context variables that will hold our trace information. +# Providing a default value is important so that they can be accessed +# even when the context has not been explicitly set. +trace_id_var = contextvars.ContextVar('trace_id', default=None) +route_var = contextvars.ContextVar('route', default=None) +origin_var = contextvars.ContextVar('origin', default=None) + + +class trace_context: + """ + A context manager and decorator to set the trace context for non-web operations. + """ + + def __init__(self, origin=None, **kwargs): + self.origin = origin + self.kwargs = kwargs + self.tokens = [] + + def __enter__(self): + # Set a new trace ID for this context + self.tokens.append(trace_id_var.set(str(uuid.uuid4()))) + + # Set the origin (e.g., 'dispatcher') + if self.origin: + self.tokens.append(origin_var.set(self.origin)) + + for key, value in self.kwargs.items(): + var = contextvars.ContextVar(key) + self.tokens.append(var.set(value)) + + def __exit__(self, exc_type, exc_value, traceback): + # Reset the context variables to their previous state + for token in self.tokens: + var = token.var + var.reset(token) + + def __call__(self, func): + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper diff --git a/ansible_base/lib/middleware/profiling/README.md b/ansible_base/lib/middleware/profiling/README.md index 6bcf09ac4..68d568f4f 100644 --- a/ansible_base/lib/middleware/profiling/README.md +++ b/ansible_base/lib/middleware/profiling/README.md @@ -37,12 +37,16 @@ This middleware provides insights into the database queries executed during a re * `X-API-Query-Count`: The total number of database queries executed during the request. * `X-API-Query-Time`: The total time spent on database queries, in seconds. -To use it, add it to your `MIDDLEWARE` list in your Django settings: +It also injects contextual information as a comment into each SQL query, which is invaluable for debugging and tracing. For example: +`/* trace_id=b71696ed-c483-408d-9740-2e7935b4f2d9, route=api/v2/users/{pk}/, origin=request */ SELECT ...` + +To use it, add both the `TraceContextMiddleware` and the `SQLProfilingMiddleware` to your `MIDDLEWARE` list in your Django settings. The `TraceContextMiddleware` should come before the `SQLProfilingMiddleware`. ```python # settings.py MIDDLEWARE = [ ... + 'ansible_base.lib.middleware.request_context.TraceContextMiddleware', 'ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware', ... ] diff --git a/ansible_base/lib/middleware/profiling/profile_request.py b/ansible_base/lib/middleware/profiling/profile_request.py index c6a4e7984..5ea946fad 100644 --- a/ansible_base/lib/middleware/profiling/profile_request.py +++ b/ansible_base/lib/middleware/profiling/profile_request.py @@ -11,6 +11,7 @@ from django.db import connection from django.utils.translation import gettext_lazy as _ +from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var from ansible_base.lib.utils.settings import get_function_from_setting, get_setting logger = logging.getLogger(__name__) @@ -90,6 +91,18 @@ def __init__(self): self.query_time = 0.0 def __call__(self, execute, sql, params, many, context): + # Build the context comment + context_items = [] + if trace_id := trace_id_var.get(): + context_items.append(f"trace_id={trace_id}") + if route := route_var.get(): + context_items.append(f"route={route}") + if origin := origin_var.get(): + context_items.append(f"origin={origin}") + + if context_items: + sql = f"/* {', '.join(context_items)} */ {sql}" + start_time = time.time() try: return execute(sql, params, many, context) @@ -106,6 +119,13 @@ def __call__(self, request): if not get_setting('ANSIBLE_BASE_SQL_PROFILING', False): return self.get_response(request) + # Check if the trace context is available. If not, log a warning. + if trace_id_var.get() is None: + logger.warning( + "ANSIBLE_BASE_SQL_PROFILING is enabled, but the trace context is not set. " + "Please ensure that TraceContextMiddleware is included in your MIDDLEWARE settings before this middleware." + ) + metrics = SQLQueryMetrics() with connection.execute_wrapper(metrics): response = self.get_response(request) diff --git a/ansible_base/lib/middleware/request_context.py b/ansible_base/lib/middleware/request_context.py new file mode 100644 index 000000000..f2105b97c --- /dev/null +++ b/ansible_base/lib/middleware/request_context.py @@ -0,0 +1,29 @@ +import uuid + +from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var + + +class TraceContextMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + # Set the context for the request and store the tokens + origin_token = origin_var.set('request') + trace_id = request.headers.get('X-Request-ID', str(uuid.uuid4())) + trace_id_token = trace_id_var.set(trace_id) + + route_token = None + if request.resolver_match: + route_token = route_var.set(request.resolver_match.route) + + try: + response = self.get_response(request) + finally: + # Reset the context variables to their previous state + origin_var.reset(origin_token) + trace_id_var.reset(trace_id_token) + if route_token: + route_var.reset(route_token) + + return response diff --git a/test_app/tests/lib/logging/test_context.py b/test_app/tests/lib/logging/test_context.py new file mode 100644 index 000000000..4fa68f015 --- /dev/null +++ b/test_app/tests/lib/logging/test_context.py @@ -0,0 +1,40 @@ +import random +import threading +import time +import unittest + +from ansible_base.lib.logging.context import trace_id_var + + +class TestContextSafety(unittest.TestCase): + def test_trace_id_is_thread_safe(self): + """ + Verify that the trace_id context variable is thread-safe. + """ + results = [] + + def target_function(thread_id): + # Set a unique trace ID for this thread + trace_id_var.set(f"trace-id-{thread_id}") + + # Sleep for a random, short duration to encourage thread interleaving + time.sleep(random.uniform(0.01, 0.05)) + + # Get the trace ID and verify it has not been changed by another thread + retrieved_id = trace_id_var.get() + + # Store the result of the check for the main thread to verify + results.append(retrieved_id == f"trace-id-{thread_id}") + + threads = [] + for i in range(10): + thread = threading.Thread(target=target_function, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # Verify that all threads successfully retrieved their own context + self.assertEqual(len(results), 10, "Not all threads completed successfully.") + self.assertTrue(all(results), "Context leaked between threads.") diff --git a/test_app/tests/lib/middleware/test_profiling_middleware.py b/test_app/tests/lib/middleware/test_profiling_middleware.py index a28d8e64d..a85fc1ed2 100644 --- a/test_app/tests/lib/middleware/test_profiling_middleware.py +++ b/test_app/tests/lib/middleware/test_profiling_middleware.py @@ -1,15 +1,16 @@ import os import tempfile import uuid -from unittest.mock import patch +from unittest.mock import MagicMock, patch +from django.db import connection from django.http import HttpResponse from django.test import TestCase, override_settings from django.urls import path from ansible_base.lib.middleware.profiling.profile_request import ProfileRequestMiddleware, SQLProfilingMiddleware from ansible_base.lib.utils.settings import get_setting -from test_app.models import User +from test_app.models import Organization, User # A simple view for testing middleware @@ -19,8 +20,8 @@ def simple_view(request): # A view that performs a database query def db_view(request): - # Create a user with a unique username to guarantee a query. - User.objects.create(username=f"test-{uuid.uuid4()}") + # Get or create an organization to guarantee at least one query is executed. + Organization.objects.get_or_create(name=f"test-org-{uuid.uuid4()}") return HttpResponse("OK") @@ -78,8 +79,23 @@ def test_profile_request_middleware_cprofile_disabled(self): self.assertNotIn('X-API-CProfile-File', response) -@override_settings(ROOT_URLCONF=__name__, MIDDLEWARE=['ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware']) +@override_settings( + ROOT_URLCONF=__name__, + MIDDLEWARE=[ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'ansible_base.lib.middleware.request_context.TraceContextMiddleware', + 'ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware', + ], +) class SQLProfilingMiddlewareTest(TestCase): + def setUp(self): + # Create a user and log them in. This is necessary to avoid the bug in the + # test_app models that causes a TypeError when get_system_user is called. + self.user = User.objects.create_user(username='testuser', password='password') + self.client.force_login(self.user) + + @override_settings(ANSIBLE_BASE_SQL_PROFILING=False) def test_sql_profiling_disabled_by_default(self): """ Test that the SQLProfilingMiddleware does not add headers when disabled. @@ -91,7 +107,7 @@ def test_sql_profiling_disabled_by_default(self): @override_settings(ANSIBLE_BASE_SQL_PROFILING=True) def test_sql_profiling_enabled_with_new_setting(self): """ - Test that the SQLProfilingMiddleware adds headers when ANSIBLE_BASE_SQL_PROFILING is True + Test that the SQLProfilingMiddleware adds headers when ANSIBLE_BASE_SQL_PROFILING is True. """ response = self.client.get('/test-db/') self.assertIn('X-API-Query-Count', response) @@ -102,3 +118,77 @@ def test_sql_profiling_enabled_with_new_setting(self): float(response['X-API-Query-Time'][:-1]) except ValueError: self.fail("X-API-Query-Time value is not a valid float") + + +@override_settings( + ROOT_URLCONF=__name__, + MIDDLEWARE=[ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware', + ], + ANSIBLE_BASE_SQL_PROFILING=True, +) +class SQLProfilingMiddlewareMissingContextTest(TestCase): + def setUp(self): + self.user = User.objects.create_user(username='testuser', password='password') + self.client.force_login(self.user) + + @patch('ansible_base.lib.middleware.profiling.profile_request.logger') + def test_logs_warning_if_context_middleware_is_missing(self, mock_logger): + """ + Test that the SQLProfilingMiddleware logs a warning if the TraceContextMiddleware + is not present and the context is missing, even when a query is made. + """ + # We need to use a real view that makes a query + response = self.client.get('/test-db/') + self.assertEqual(response.status_code, 200) + + mock_logger.warning.assert_called_with( + "ANSIBLE_BASE_SQL_PROFILING is enabled, but the trace context is not set. " + "Please ensure that TraceContextMiddleware is included in your MIDDLEWARE settings before this middleware." + ) + + +class SQLQueryMetricsTest(TestCase): + def test_sql_comment_injection(self): + """ + Test that the SQLQueryMetrics wrapper correctly injects context + into the SQL query as a comment. + """ + from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var + from ansible_base.lib.middleware.profiling.profile_request import SQLQueryMetrics + + # 1. Manually set the context, saving the tokens to reset it later. + trace_id_token = trace_id_var.set("test-trace-id") + route_token = route_var.set("test/route") + origin_token = origin_var.set("test-origin") + + try: + # 2. Instantiate our metrics class and call it directly. + metrics = SQLQueryMetrics() + original_sql = "SELECT 1" + + # 3. We don't need a real execute function, so we'll just use a lambda. + # The key is that we can inspect the SQL that was passed to it. + modified_sql = "" + + def mock_execute(sql, params, many, context): + nonlocal modified_sql + modified_sql = sql + return None + + metrics(mock_execute, original_sql, [], False, {}) + + # 4. Assert that the SQL passed to our mock was correctly modified. + self.assertIn("/*", modified_sql) + self.assertIn("trace_id=test-trace-id", modified_sql) + self.assertIn("route=test/route", modified_sql) + self.assertIn("origin=test-origin", modified_sql) + self.assertIn("*/", modified_sql) + self.assertIn(original_sql, modified_sql) + finally: + # 5. Reset the context variables to their previous state. + trace_id_var.reset(trace_id_token) + route_var.reset(route_token) + origin_var.reset(origin_token) diff --git a/test_app/tests/lib/middleware/test_request_context.py b/test_app/tests/lib/middleware/test_request_context.py new file mode 100644 index 000000000..d33831030 --- /dev/null +++ b/test_app/tests/lib/middleware/test_request_context.py @@ -0,0 +1,88 @@ +import json +from uuid import UUID, uuid4 + +from django.http import JsonResponse +from django.test import TestCase, override_settings +from django.urls import path + +from ansible_base.lib.logging.context import trace_id_var + + +# A simple view for testing middleware +def context_view(request): + return JsonResponse({"trace_id": trace_id_var.get()}) + + +# Define URL patterns for the test +urlpatterns = [ + path('context/', context_view), +] + + +@override_settings( + ROOT_URLCONF=__name__, + MIDDLEWARE=[ + 'ansible_base.lib.middleware.request_context.TraceContextMiddleware', + ], +) +class TraceContextMiddlewareTest(TestCase): + def test_uses_x_request_id_header(self): + """ + Test that the middleware uses the X-Request-ID from the request header + if it is provided. + """ + self.assertIsNone(trace_id_var.get()) + request_id = str(uuid4()) + + response = self.client.get('/context/', HTTP_X_REQUEST_ID=request_id) + self.assertEqual(response.status_code, 200) + + # The trace_id in the view should match the header + trace_id = json.loads(response.content).get("trace_id") + self.assertEqual(trace_id, request_id) + + # After the request, the context should be reset + self.assertIsNone(trace_id_var.get()) + + def test_generates_id_if_header_is_missing(self): + """ + Test that the middleware generates a new UUID if the X-Request-ID + header is not provided. + """ + self.assertIsNone(trace_id_var.get()) + + response = self.client.get('/context/') + self.assertEqual(response.status_code, 200) + + # The trace_id should be a valid UUID + trace_id = json.loads(response.content).get("trace_id") + self.assertIsNotNone(trace_id) + try: + UUID(trace_id, version=4) + except ValueError: + self.fail("trace_id is not a valid UUID4") + + # After the request, the context should be reset + self.assertIsNone(trace_id_var.get()) + + def test_context_does_not_bleed_between_requests(self): + """ + Test that the trace_id from one request does not bleed into the next. + """ + # First request has an X-Request-ID + request_id = str(uuid4()) + self.client.get('/context/', HTTP_X_REQUEST_ID=request_id) + + # Second request does not have the header + response = self.client.get('/context/') + self.assertEqual(response.status_code, 200) + + # The trace_id of the second request should be a new, generated UUID, + # not the one from the first request's header. + trace_id_2 = json.loads(response.content).get("trace_id") + self.assertIsNotNone(trace_id_2) + self.assertNotEqual(trace_id_2, request_id) + try: + UUID(trace_id_2, version=4) + except ValueError: + self.fail("trace_id_2 is not a valid UUID4") From 18051d98c99c14529db98efe5c03a5a2b3f5c344 Mon Sep 17 00:00:00 2001 From: Elijah DeLee Date: Tue, 12 Aug 2025 16:11:18 -0400 Subject: [PATCH 5/7] add a single middleware for apps to include that orchestrates things the right way --- ansible_base/lib/middleware/observability.py | 27 +++++ .../lib/middleware/profiling/README.md | 38 ++---- .../middleware/profiling/profile_request.py | 8 +- .../lib/middleware/request_context.py | 6 +- .../middleware/test_profiling_middleware.py | 114 ++++++++++-------- .../lib/middleware/test_request_context.py | 2 +- 6 files changed, 113 insertions(+), 82 deletions(-) create mode 100644 ansible_base/lib/middleware/observability.py diff --git a/ansible_base/lib/middleware/observability.py b/ansible_base/lib/middleware/observability.py new file mode 100644 index 000000000..8c607db9c --- /dev/null +++ b/ansible_base/lib/middleware/observability.py @@ -0,0 +1,27 @@ +""" +A single middleware to provide a unified observability layer, ensuring that context, +profiling, and SQL metrics are captured in the correct order. +""" + +from .profiling.profile_request import _ProfileRequestMiddleware, _SQLProfilingMiddleware +from .request_context import _TraceContextMiddleware + + +class ObservabilityMiddleware: + """ + A single entry point for observability middleware. + + This middleware composes the trace context, request profiling, and SQL + profiling middleware in the correct order. Instead of listing all three + in your settings, you can now just add this one. + """ + + def __init__(self, get_response): + # Chain the middleware in the desired order. The request will flow + # from _TraceContextMiddleware -> _ProfileRequestMiddleware -> _SQLProfilingMiddleware. + handler = _SQLProfilingMiddleware(get_response) + handler = _ProfileRequestMiddleware(handler) + self.handler = _TraceContextMiddleware(handler) + + def __call__(self, request): + return self.handler(request) diff --git a/ansible_base/lib/middleware/profiling/README.md b/ansible_base/lib/middleware/profiling/README.md index 68d568f4f..9e250c12d 100644 --- a/ansible_base/lib/middleware/profiling/README.md +++ b/ansible_base/lib/middleware/profiling/README.md @@ -1,27 +1,28 @@ -# Request Profiling +# Request Profiling and Observability -The `ProfileRequestMiddleware` and `DABProfiler` class provide a way to profile requests and other code in your Django application. This functionality is a generalization of the profiling tools found in AWX and can be used by any `django-ansible-base` consumer. +The `ObservabilityMiddleware` provides a simple way to gain performance and debugging insights into your Django application. It acts as a single entry point for several underlying middleware components, ensuring they are always used in the correct order. -## `ProfileRequestMiddleware` +## `ObservabilityMiddleware` -This middleware provides performance insights for API requests. To use it, add it to your `MIDDLEWARE` list in your Django settings. For the most accurate and reliable timing, it is recommended to add this middleware to the top of the `MIDDLEWARE` list. +This single middleware bundles tracing, request profiling, and SQL query analysis. To use it, add it to the top of your `MIDDLEWARE` list in your Django settings. ```python # settings.py MIDDLEWARE = [ - 'ansible_base.lib.middleware.profiling.profile_request.ProfileRequestMiddleware', + 'ansible_base.lib.middleware.observability.ObservabilityMiddleware', ... ] ``` The middleware always adds the following headers to the response: +* `X-Request-ID`: A unique identifier for the request. If the incoming request includes an `X-Request-ID` header, that value will be used; otherwise, a new UUID will be generated. * `X-API-Time`: The total time taken to process the request, in seconds. -* `X-API-Node`: The cluster host ID of the node that served the request. This header is only added if it is not already present in the response. +* `X-API-Node`: The cluster host ID of the node that served the request. ### cProfile Support -When the `ANSIBLE_BASE_CPROFILE_REQUESTS` setting is enabled, the middleware will also perform a cProfile analysis for each request. The resulting `.prof` file is saved to a temporary directory on the node that served the request, and its path is returned in the `X-API-CProfile-File` response header. +When the `ANSIBLE_BASE_CPROFILE_REQUESTS` setting is enabled, the middleware will also perform a cProfile analysis for each request. The resulting `.prof` file is saved to a temporary directory on the node that served the request, and its path is returned in the `X-API-CProfile-File` response header. The filename will include the request's `X-Request-ID`. To enable cProfile support, set the following in your Django settings: @@ -30,9 +31,9 @@ To enable cProfile support, set the following in your Django settings: ANSIBLE_BASE_CPROFILE_REQUESTS = True ``` -## `SQLProfilingMiddleware` +### SQL Profiling Support -This middleware provides insights into the database queries executed during a request. When enabled, it adds the following headers to the response: +When the `ANSIBLE_BASE_SQL_PROFILING` setting is enabled, the middleware provides insights into the database queries executed during a request. It adds the following headers to the response: * `X-API-Query-Count`: The total number of database queries executed during the request. * `X-API-Query-Time`: The total time spent on database queries, in seconds. @@ -40,22 +41,6 @@ This middleware provides insights into the database queries executed during a re It also injects contextual information as a comment into each SQL query, which is invaluable for debugging and tracing. For example: `/* trace_id=b71696ed-c483-408d-9740-2e7935b4f2d9, route=api/v2/users/{pk}/, origin=request */ SELECT ...` -To use it, add both the `TraceContextMiddleware` and the `SQLProfilingMiddleware` to your `MIDDLEWARE` list in your Django settings. The `TraceContextMiddleware` should come before the `SQLProfilingMiddleware`. - -```python -# settings.py -MIDDLEWARE = [ - ... - 'ansible_base.lib.middleware.request_context.TraceContextMiddleware', - 'ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware', - ... -] -``` - -### Enabling SQL Profiling - -The middleware is controlled by the `ANSIBLE_BASE_SQL_PROFILING` setting. - To enable SQL profiling, set the following in your Django settings: ```python @@ -65,7 +50,7 @@ ANSIBLE_BASE_SQL_PROFILING = True ## `DABProfiler` -The core profiling logic is encapsulated in the `DABProfiler` class. This class can be imported and used directly for profiling non-HTTP contexts, such as background tasks or gRPC services. +For profiling non-HTTP contexts, such as background tasks or gRPC services, the `DABProfiler` class can be used directly. The profiler's cProfile functionality is controlled by the `ANSIBLE_BASE_CPROFILE_REQUESTS` setting. @@ -119,3 +104,4 @@ import pstats p = pstats.Stats('/path/to/your/profile.prof') p.sort_stats('cumulative').print_stats(10) ``` + diff --git a/ansible_base/lib/middleware/profiling/profile_request.py b/ansible_base/lib/middleware/profiling/profile_request.py index 5ea946fad..1844c8ed4 100644 --- a/ansible_base/lib/middleware/profiling/profile_request.py +++ b/ansible_base/lib/middleware/profiling/profile_request.py @@ -51,7 +51,7 @@ def stop(self, profile_id: Optional[Union[str, uuid.UUID]] = None): return elapsed, cprofile_filename -class ProfileRequestMiddleware(threading.local): +class _ProfileRequestMiddleware(threading.local): def __init__(self, get_response=None): self.get_response = get_response self.profiler = DABProfiler() @@ -59,7 +59,7 @@ def __init__(self, get_response=None): def __call__(self, request): # Logic before the view (formerly process_request) self.profiler.start() - request_id = request.headers.get('X-Request-ID') + request_id = trace_id_var.get() # Call the next middleware or the view response = self.get_response(request) @@ -111,7 +111,7 @@ def __call__(self, execute, sql, params, many, context): self.query_time += time.time() - start_time -class SQLProfilingMiddleware: +class _SQLProfilingMiddleware: def __init__(self, get_response): self.get_response = get_response @@ -123,7 +123,7 @@ def __call__(self, request): if trace_id_var.get() is None: logger.warning( "ANSIBLE_BASE_SQL_PROFILING is enabled, but the trace context is not set. " - "Please ensure that TraceContextMiddleware is included in your MIDDLEWARE settings before this middleware." + "Please use the ObservabilityMiddleware instead of including profiling middleware individually." ) metrics = SQLQueryMetrics() diff --git a/ansible_base/lib/middleware/request_context.py b/ansible_base/lib/middleware/request_context.py index f2105b97c..824859b6a 100644 --- a/ansible_base/lib/middleware/request_context.py +++ b/ansible_base/lib/middleware/request_context.py @@ -3,14 +3,15 @@ from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var -class TraceContextMiddleware: +class _TraceContextMiddleware: def __init__(self, get_response): self.get_response = get_response def __call__(self, request): # Set the context for the request and store the tokens origin_token = origin_var.set('request') - trace_id = request.headers.get('X-Request-ID', str(uuid.uuid4())) + # .get is case-insensitive, but we'll use lowercase for consistency + trace_id = request.headers.get('x-request-id', str(uuid.uuid4())) trace_id_token = trace_id_var.set(trace_id) route_token = None @@ -19,6 +20,7 @@ def __call__(self, request): try: response = self.get_response(request) + response['X-Request-ID'] = trace_id finally: # Reset the context variables to their previous state origin_var.reset(origin_token) diff --git a/test_app/tests/lib/middleware/test_profiling_middleware.py b/test_app/tests/lib/middleware/test_profiling_middleware.py index a85fc1ed2..a05f552c3 100644 --- a/test_app/tests/lib/middleware/test_profiling_middleware.py +++ b/test_app/tests/lib/middleware/test_profiling_middleware.py @@ -1,15 +1,18 @@ import os import tempfile import uuid -from unittest.mock import MagicMock, patch +from unittest.mock import patch -from django.db import connection from django.http import HttpResponse from django.test import TestCase, override_settings from django.urls import path -from ansible_base.lib.middleware.profiling.profile_request import ProfileRequestMiddleware, SQLProfilingMiddleware -from ansible_base.lib.utils.settings import get_setting +from ansible_base.lib.middleware.observability import ObservabilityMiddleware +from ansible_base.lib.middleware.profiling.profile_request import ( + SQLQueryMetrics, + _ProfileRequestMiddleware, + _SQLProfilingMiddleware, +) from test_app.models import Organization, User @@ -33,13 +36,13 @@ def db_view(request): @override_settings(ROOT_URLCONF=__name__) -class ProfileRequestMiddlewareTest(TestCase): +class _ProfileRequestMiddlewareTest(TestCase): @override_settings(CLUSTER_HOST_ID='test-node') def test_profile_request_middleware_headers(self): """ - Test that the ProfileRequestMiddleware adds sensible headers. + Test that the _ProfileRequestMiddleware adds sensible headers. """ - middleware = ProfileRequestMiddleware(simple_view) + middleware = _ProfileRequestMiddleware(simple_view) response = middleware(self.client.get('/test/').wsgi_request) # Test X-API-Time @@ -57,24 +60,25 @@ def test_profile_request_middleware_headers(self): @override_settings(ANSIBLE_BASE_CPROFILE_REQUESTS=True) def test_profile_request_middleware_cprofile_enabled(self): """ - Test that the ProfileRequestMiddleware adds the X-API-CProfile-File + Test that the _ProfileRequestMiddleware adds the X-API-CProfile-File header and creates a profile file when enabled. """ with tempfile.TemporaryDirectory() as tmpdir: with patch('tempfile.gettempdir', return_value=tmpdir): - middleware = ProfileRequestMiddleware(simple_view) + middleware = _ProfileRequestMiddleware(simple_view) response = middleware(self.client.get('/test/').wsgi_request) self.assertIn('X-API-CProfile-File', response) profile_file = response['X-API-CProfile-File'] self.assertTrue(profile_file.endswith('.prof')) self.assertTrue(os.path.exists(profile_file)) + @override_settings(ANSIBLE_BASE_CPROFILE_REQUESTS=False) def test_profile_request_middleware_cprofile_disabled(self): """ - Test that the ProfileRequestMiddleware does not add the + Test that the _ProfileRequestMiddleware does not add the X-API-CProfile-File header when disabled. """ - middleware = ProfileRequestMiddleware(simple_view) + middleware = _ProfileRequestMiddleware(simple_view) response = middleware(self.client.get('/test/').wsgi_request) self.assertNotIn('X-API-CProfile-File', response) @@ -84,40 +88,28 @@ def test_profile_request_middleware_cprofile_disabled(self): MIDDLEWARE=[ 'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'ansible_base.lib.middleware.request_context.TraceContextMiddleware', - 'ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware', + 'ansible_base.lib.middleware.request_context._TraceContextMiddleware', + 'ansible_base.lib.middleware.profiling.profile_request._SQLProfilingMiddleware', ], ) -class SQLProfilingMiddlewareTest(TestCase): +class _SQLProfilingMiddlewareTest(TestCase): def setUp(self): - # Create a user and log them in. This is necessary to avoid the bug in the - # test_app models that causes a TypeError when get_system_user is called. self.user = User.objects.create_user(username='testuser', password='password') self.client.force_login(self.user) @override_settings(ANSIBLE_BASE_SQL_PROFILING=False) def test_sql_profiling_disabled_by_default(self): - """ - Test that the SQLProfilingMiddleware does not add headers when disabled. - """ response = self.client.get('/test-db/') self.assertNotIn('X-API-Query-Count', response) self.assertNotIn('X-API-Query-Time', response) @override_settings(ANSIBLE_BASE_SQL_PROFILING=True) def test_sql_profiling_enabled_with_new_setting(self): - """ - Test that the SQLProfilingMiddleware adds headers when ANSIBLE_BASE_SQL_PROFILING is True. - """ response = self.client.get('/test-db/') self.assertIn('X-API-Query-Count', response) self.assertGreaterEqual(int(response['X-API-Query-Count']), 1) self.assertIn('X-API-Query-Time', response) self.assertTrue(response['X-API-Query-Time'].endswith('s')) - try: - float(response['X-API-Query-Time'][:-1]) - except ValueError: - self.fail("X-API-Query-Time value is not a valid float") @override_settings( @@ -125,52 +117,35 @@ def test_sql_profiling_enabled_with_new_setting(self): MIDDLEWARE=[ 'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'ansible_base.lib.middleware.profiling.profile_request.SQLProfilingMiddleware', + 'ansible_base.lib.middleware.profiling.profile_request._SQLProfilingMiddleware', ], ANSIBLE_BASE_SQL_PROFILING=True, ) -class SQLProfilingMiddlewareMissingContextTest(TestCase): +class _SQLProfilingMiddlewareMissingContextTest(TestCase): def setUp(self): self.user = User.objects.create_user(username='testuser', password='password') self.client.force_login(self.user) @patch('ansible_base.lib.middleware.profiling.profile_request.logger') def test_logs_warning_if_context_middleware_is_missing(self, mock_logger): - """ - Test that the SQLProfilingMiddleware logs a warning if the TraceContextMiddleware - is not present and the context is missing, even when a query is made. - """ - # We need to use a real view that makes a query - response = self.client.get('/test-db/') - self.assertEqual(response.status_code, 200) - + self.client.get('/test-db/') mock_logger.warning.assert_called_with( "ANSIBLE_BASE_SQL_PROFILING is enabled, but the trace context is not set. " - "Please ensure that TraceContextMiddleware is included in your MIDDLEWARE settings before this middleware." + "Please use the ObservabilityMiddleware instead of including profiling middleware individually." ) class SQLQueryMetricsTest(TestCase): def test_sql_comment_injection(self): - """ - Test that the SQLQueryMetrics wrapper correctly injects context - into the SQL query as a comment. - """ from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var - from ansible_base.lib.middleware.profiling.profile_request import SQLQueryMetrics - # 1. Manually set the context, saving the tokens to reset it later. trace_id_token = trace_id_var.set("test-trace-id") route_token = route_var.set("test/route") origin_token = origin_var.set("test-origin") try: - # 2. Instantiate our metrics class and call it directly. metrics = SQLQueryMetrics() original_sql = "SELECT 1" - - # 3. We don't need a real execute function, so we'll just use a lambda. - # The key is that we can inspect the SQL that was passed to it. modified_sql = "" def mock_execute(sql, params, many, context): @@ -180,7 +155,6 @@ def mock_execute(sql, params, many, context): metrics(mock_execute, original_sql, [], False, {}) - # 4. Assert that the SQL passed to our mock was correctly modified. self.assertIn("/*", modified_sql) self.assertIn("trace_id=test-trace-id", modified_sql) self.assertIn("route=test/route", modified_sql) @@ -188,7 +162,49 @@ def mock_execute(sql, params, many, context): self.assertIn("*/", modified_sql) self.assertIn(original_sql, modified_sql) finally: - # 5. Reset the context variables to their previous state. trace_id_var.reset(trace_id_token) route_var.reset(route_token) origin_var.reset(origin_token) + + +@override_settings( + ROOT_URLCONF=__name__, + MIDDLEWARE=[ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'ansible_base.lib.middleware.observability.ObservabilityMiddleware', + ], + ANSIBLE_BASE_SQL_PROFILING=True, + ANSIBLE_BASE_CPROFILE_REQUESTS=True, + CLUSTER_HOST_ID='test-node-obs', +) +class ObservabilityMiddlewareTest(TestCase): + def setUp(self): + self.user = User.objects.create_user(username='testuser', password='password') + self.client.force_login(self.user) + + def test_observability_middleware_all_headers(self): + """ + An integration test to ensure the facade middleware adds all expected + headers and uses the request ID consistently. + """ + request_id = str(uuid.uuid4()) + with tempfile.TemporaryDirectory() as tmpdir: + with patch('tempfile.gettempdir', return_value=tmpdir): + response = self.client.get('/test-db/', HTTP_X_REQUEST_ID=request_id) + + # 1. From _TraceContextMiddleware: Check response header + self.assertIn('X-Request-ID', response) + self.assertEqual(response['X-Request-ID'], request_id) + + # 2. From _ProfileRequestMiddleware: Check profiling headers and filename + self.assertIn('X-API-Time', response) + self.assertIn('X-API-Node', response) + self.assertEqual(response['X-API-Node'], 'test-node-obs') + self.assertIn('X-API-CProfile-File', response) + self.assertIn(request_id, response['X-API-CProfile-File']) + self.assertTrue(os.path.exists(response['X-API-CProfile-File'])) + + # 3. From _SQLProfilingMiddleware: Check SQL headers + self.assertIn('X-API-Query-Count', response) + self.assertIn('X-API-Query-Time', response) diff --git a/test_app/tests/lib/middleware/test_request_context.py b/test_app/tests/lib/middleware/test_request_context.py index d33831030..e67c199b7 100644 --- a/test_app/tests/lib/middleware/test_request_context.py +++ b/test_app/tests/lib/middleware/test_request_context.py @@ -22,7 +22,7 @@ def context_view(request): @override_settings( ROOT_URLCONF=__name__, MIDDLEWARE=[ - 'ansible_base.lib.middleware.request_context.TraceContextMiddleware', + 'ansible_base.lib.middleware.request_context._TraceContextMiddleware', ], ) class TraceContextMiddlewareTest(TestCase): From afd694679088ed7d909a97396d1f720c6f26fac8 Mon Sep 17 00:00:00 2001 From: Elijah DeLee Date: Tue, 12 Aug 2025 17:30:09 -0400 Subject: [PATCH 6/7] fix route in sql comment and also guard against sql injection --- ansible_base/lib/logging/context.py | 1 - .../middleware/profiling/profile_request.py | 44 ++++++++++++++---- .../lib/middleware/request_context.py | 28 +++++++---- .../middleware/test_profiling_middleware.py | 46 ++++++++++++++++--- .../lib/middleware/test_request_context.py | 18 ++++++++ 5 files changed, 112 insertions(+), 25 deletions(-) diff --git a/ansible_base/lib/logging/context.py b/ansible_base/lib/logging/context.py index 89a93fcb2..98b5072e9 100644 --- a/ansible_base/lib/logging/context.py +++ b/ansible_base/lib/logging/context.py @@ -5,7 +5,6 @@ # Providing a default value is important so that they can be accessed # even when the context has not been explicitly set. trace_id_var = contextvars.ContextVar('trace_id', default=None) -route_var = contextvars.ContextVar('route', default=None) origin_var = contextvars.ContextVar('origin', default=None) diff --git a/ansible_base/lib/middleware/profiling/profile_request.py b/ansible_base/lib/middleware/profiling/profile_request.py index 1844c8ed4..a59525c17 100644 --- a/ansible_base/lib/middleware/profiling/profile_request.py +++ b/ansible_base/lib/middleware/profiling/profile_request.py @@ -6,12 +6,13 @@ import time import uuid from typing import Optional, Union +from urllib.parse import quote from django.conf import settings from django.db import connection from django.utils.translation import gettext_lazy as _ -from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var +from ansible_base.lib.logging.context import origin_var, trace_id_var from ansible_base.lib.utils.settings import get_function_from_setting, get_setting logger = logging.getLogger(__name__) @@ -85,23 +86,50 @@ def __call__(self, request): return response +# Define the maximum length for a value in a SQL comment +SQL_COMMENT_MAX_LENGTH = 256 + + +def _sanitize_for_sql_comment(value: str) -> str: + """ + Sanitizes a string for safe inclusion in a SQL comment. + + - URL-encodes the value to handle special characters. + - Escapes the '%' character to prevent conflicts with database placeholders. + - Truncates the string to a maximum length. + """ + # URL-encode the value + quoted_value = quote(str(value)) + # Escape the '%' character for the database driver + sanitized_value = quoted_value.replace('%', '%%') + # Truncate to the maximum length + return sanitized_value[:SQL_COMMENT_MAX_LENGTH] + + class SQLQueryMetrics: - def __init__(self): + def __init__(self, request=None): + self.request = request self.query_count = 0 self.query_time = 0.0 def __call__(self, execute, sql, params, many, context): # Build the context comment context_items = [] + # trace_id is already validated as a UUID, so it is safe if trace_id := trace_id_var.get(): - context_items.append(f"trace_id={trace_id}") - if route := route_var.get(): - context_items.append(f"route={route}") + context_items.append(f"trace_id='{trace_id}'") + + # The route is only available after the URL resolver has run + if self.request and getattr(self.request, 'resolver_match', None): + if route := self.request.resolver_match.route: + context_items.append(f"route='{_sanitize_for_sql_comment(route)}'") + if origin := origin_var.get(): - context_items.append(f"origin={origin}") + context_items.append(f"origin='{_sanitize_for_sql_comment(origin)}'") if context_items: - sql = f"/* {', '.join(context_items)} */ {sql}" + comment = f"/* {', '.join(context_items)} */" + sql = f"{comment} {sql}" start_time = time.time() try: @@ -126,7 +154,7 @@ def __call__(self, request): "Please use the ObservabilityMiddleware instead of including profiling middleware individually." ) - metrics = SQLQueryMetrics() + metrics = SQLQueryMetrics(request) with connection.execute_wrapper(metrics): response = self.get_response(request) diff --git a/ansible_base/lib/middleware/request_context.py b/ansible_base/lib/middleware/request_context.py index 824859b6a..a49142cb7 100644 --- a/ansible_base/lib/middleware/request_context.py +++ b/ansible_base/lib/middleware/request_context.py @@ -1,6 +1,6 @@ import uuid -from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var +from ansible_base.lib.logging.context import origin_var, trace_id_var class _TraceContextMiddleware: @@ -10,13 +10,25 @@ def __init__(self, get_response): def __call__(self, request): # Set the context for the request and store the tokens origin_token = origin_var.set('request') - # .get is case-insensitive, but we'll use lowercase for consistency - trace_id = request.headers.get('x-request-id', str(uuid.uuid4())) - trace_id_token = trace_id_var.set(trace_id) - route_token = None - if request.resolver_match: - route_token = route_var.set(request.resolver_match.route) + # Get the request ID from the header + header_trace_id = request.headers.get('x-request-id') + trace_id = None + + if header_trace_id: + try: + # Validate that the provided header is a valid UUID + uuid.UUID(header_trace_id) + trace_id = header_trace_id + except ValueError: + # If it's not a valid UUID, discard it and we'll generate a new one + pass + + # If no valid trace_id was found, generate a new one + if not trace_id: + trace_id = str(uuid.uuid4()) + + trace_id_token = trace_id_var.set(trace_id) try: response = self.get_response(request) @@ -25,7 +37,5 @@ def __call__(self, request): # Reset the context variables to their previous state origin_var.reset(origin_token) trace_id_var.reset(trace_id_token) - if route_token: - route_var.reset(route_token) return response diff --git a/test_app/tests/lib/middleware/test_profiling_middleware.py b/test_app/tests/lib/middleware/test_profiling_middleware.py index a05f552c3..86e59e62d 100644 --- a/test_app/tests/lib/middleware/test_profiling_middleware.py +++ b/test_app/tests/lib/middleware/test_profiling_middleware.py @@ -137,14 +137,22 @@ def test_logs_warning_if_context_middleware_is_missing(self, mock_logger): class SQLQueryMetricsTest(TestCase): def test_sql_comment_injection(self): - from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var + from django.test.client import RequestFactory + from ansible_base.lib.logging.context import origin_var, trace_id_var + + # 1. Manually set the context, saving the tokens to reset it later. trace_id_token = trace_id_var.set("test-trace-id") - route_token = route_var.set("test/route") origin_token = origin_var.set("test-origin") + # 2. Create a mock request and manually set the resolver_match + factory = RequestFactory() + request = factory.get('/test-db/') + request.resolver_match = type('ResolverMatch', (), {'route': 'test/route'}) + try: - metrics = SQLQueryMetrics() + # 3. Instantiate our metrics class and call it directly. + metrics = SQLQueryMetrics(request) original_sql = "SELECT 1" modified_sql = "" @@ -155,15 +163,16 @@ def mock_execute(sql, params, many, context): metrics(mock_execute, original_sql, [], False, {}) + # 4. Assert that the SQL passed to our mock was correctly modified. self.assertIn("/*", modified_sql) - self.assertIn("trace_id=test-trace-id", modified_sql) - self.assertIn("route=test/route", modified_sql) - self.assertIn("origin=test-origin", modified_sql) + self.assertIn("trace_id='test-trace-id'", modified_sql) + self.assertIn("route='test/route'", modified_sql) + self.assertIn("origin='test-origin'", modified_sql) self.assertIn("*/", modified_sql) self.assertIn(original_sql, modified_sql) finally: + # 5. Reset the context variables to their previous state. trace_id_var.reset(trace_id_token) - route_var.reset(route_token) origin_var.reset(origin_token) @@ -208,3 +217,26 @@ def test_observability_middleware_all_headers(self): # 3. From _SQLProfilingMiddleware: Check SQL headers self.assertIn('X-API-Query-Count', response) self.assertIn('X-API-Query-Time', response) + + +class SQLCommentSanitizationTest(TestCase): + def test_sanitization_escapes_disallowed_chars(self): + from ansible_base.lib.middleware.profiling.profile_request import _sanitize_for_sql_comment + + malicious_string = "*/; DROP TABLE users; --" + sanitized = _sanitize_for_sql_comment(malicious_string) + self.assertEqual(sanitized, "%%2A/%%3B%%20DROP%%20TABLE%%20users%%3B%%20--") + + def test_sanitization_allows_safe_chars(self): + from ansible_base.lib.middleware.profiling.profile_request import _sanitize_for_sql_comment + + safe_string = "a-b_c.d/e123" + sanitized = _sanitize_for_sql_comment(safe_string) + self.assertEqual(sanitized, "a-b_c.d/e123") + + def test_sanitization_truncates_long_strings(self): + from ansible_base.lib.middleware.profiling.profile_request import SQL_COMMENT_MAX_LENGTH, _sanitize_for_sql_comment + + long_string = "a" * (SQL_COMMENT_MAX_LENGTH + 100) + sanitized = _sanitize_for_sql_comment(long_string) + self.assertEqual(len(sanitized), SQL_COMMENT_MAX_LENGTH) diff --git a/test_app/tests/lib/middleware/test_request_context.py b/test_app/tests/lib/middleware/test_request_context.py index e67c199b7..cb75f8038 100644 --- a/test_app/tests/lib/middleware/test_request_context.py +++ b/test_app/tests/lib/middleware/test_request_context.py @@ -86,3 +86,21 @@ def test_context_does_not_bleed_between_requests(self): UUID(trace_id_2, version=4) except ValueError: self.fail("trace_id_2 is not a valid UUID4") + + def test_discards_invalid_uuid_in_header(self): + """ + Test that the middleware discards an invalid UUID in the X-Request-ID + header and generates a new, valid one. + """ + malicious_id = "not-a-uuid' --" + response = self.client.get('/context/', HTTP_X_REQUEST_ID=malicious_id) + self.assertEqual(response.status_code, 200) + + # The trace_id in the response should be a new, valid UUID, not the malicious one. + new_trace_id = response.headers.get("X-Request-ID") + self.assertIsNotNone(new_trace_id) + self.assertNotEqual(new_trace_id, malicious_id) + try: + UUID(new_trace_id, version=4) + except ValueError: + self.fail("The new trace_id is not a valid UUID4") From e7f5ad5a1f638ec6d5533839a1f7c31dd437ac03 Mon Sep 17 00:00:00 2001 From: Elijah DeLee Date: Wed, 13 Aug 2025 14:05:41 -0400 Subject: [PATCH 7/7] fix: Align trace_context with middleware and add tests Refactors the existing `trace_context` manager to align its behavior with the `TraceContextMiddleware`, ensuring consistent and safe observability for non-web tasks. This commit addresses these issues by: - Modifying `trace_context` to only accept `origin` and `trace_id`. - Adding validation to ensure any provided `trace_id` is a valid UUID, matching the middleware's logic. - Removing the ability to create arbitrary context variables via `**kwargs`. Additionally, this change introduces: - Tests for the `trace_context` manager, covering its functionality, validation, thread-safety, and use as a decorator. - Documentation in the `README.md` with usage instructions and an example for background tasks. --- ansible_base/lib/logging/context.py | 22 ++-- .../lib/middleware/profiling/README.md | 32 +++++ test_app/tests/lib/logging/test_context.py | 115 +++++++++++++++++- 3 files changed, 155 insertions(+), 14 deletions(-) diff --git a/ansible_base/lib/logging/context.py b/ansible_base/lib/logging/context.py index 98b5072e9..7c09c2761 100644 --- a/ansible_base/lib/logging/context.py +++ b/ansible_base/lib/logging/context.py @@ -13,23 +13,29 @@ class trace_context: A context manager and decorator to set the trace context for non-web operations. """ - def __init__(self, origin=None, **kwargs): + def __init__(self, origin=None, trace_id=None): self.origin = origin - self.kwargs = kwargs self.tokens = [] + if trace_id: + try: + # Validate that the provided header is a valid UUID + uuid.UUID(trace_id) + self.trace_id = trace_id + except (ValueError, TypeError): + # If it's not a valid UUID, discard it and we'll generate a new one + self.trace_id = str(uuid.uuid4()) + else: + self.trace_id = str(uuid.uuid4()) + def __enter__(self): - # Set a new trace ID for this context - self.tokens.append(trace_id_var.set(str(uuid.uuid4()))) + # Set the trace ID for this context + self.tokens.append(trace_id_var.set(self.trace_id)) # Set the origin (e.g., 'dispatcher') if self.origin: self.tokens.append(origin_var.set(self.origin)) - for key, value in self.kwargs.items(): - var = contextvars.ContextVar(key) - self.tokens.append(var.set(value)) - def __exit__(self, exc_type, exc_value, traceback): # Reset the context variables to their previous state for token in self.tokens: diff --git a/ansible_base/lib/middleware/profiling/README.md b/ansible_base/lib/middleware/profiling/README.md index 9e250c12d..eda260750 100644 --- a/ansible_base/lib/middleware/profiling/README.md +++ b/ansible_base/lib/middleware/profiling/README.md @@ -31,6 +31,8 @@ To enable cProfile support, set the following in your Django settings: ANSIBLE_BASE_CPROFILE_REQUESTS = True ``` +> **Note:** Enabling cProfile has significant performance implications and is intended for temporary, live debugging sessions, not for permanent use in production environments. + ### SQL Profiling Support When the `ANSIBLE_BASE_SQL_PROFILING` setting is enabled, the middleware provides insights into the database queries executed during a request. It adds the following headers to the response: @@ -48,6 +50,8 @@ To enable SQL profiling, set the following in your Django settings: ANSIBLE_BASE_SQL_PROFILING = True ``` +> **Note:** This feature is most effective when used in combination with your database's slow query logging capabilities. For high-traffic environments, consider configuring your database to log only a percentage of queries to manage logging overhead. + ## `DABProfiler` For profiling non-HTTP contexts, such as background tasks or gRPC services, the `DABProfiler` class can be used directly. @@ -76,6 +80,34 @@ def my_background_task(): print(f"Task took {elapsed:.3f}s to complete.") ``` +## `trace_context` for Background Tasks + +For adding observability to non-HTTP contexts without the overhead of the `DABProfiler`, the `trace_context` context manager is the ideal tool. It ensures that background tasks can be traced with a unique request ID, just like the `ObservabilityMiddleware` does for web requests. + +This is particularly useful for background tasks, such as those initiated by the controller's dispatcher, where you want to correlate all log messages for a specific operation. + +### Example Usage + +Here's how you might use the `trace_context` manager in the controller's dispatcher to ensure that all work related to a specific job has a consistent trace ID. + +```python +# In a hypothetical controller dispatcher task +from ansible_base.lib.logging.context import trace_context + +def run_job(job_id, parent_trace_id=None): + """ + A background task that runs a job. + """ + # Use the parent_trace_id if it exists; otherwise, a new one will be generated. + # The origin is a string that identifies the source of the trace. + with trace_context(origin='controller_dispatcher', trace_id=parent_trace_id): + # All logging within this block will now have the same trace_id. + # logger.info(f"Starting job {job_id}") + # ... do work ... + # logger.info(f"Finished job {job_id}") + pass +``` + ## Visualizing Profile Data The `.prof` files generated by the cProfile support can be analyzed with a variety of tools. diff --git a/test_app/tests/lib/logging/test_context.py b/test_app/tests/lib/logging/test_context.py index 4fa68f015..b66d03599 100644 --- a/test_app/tests/lib/logging/test_context.py +++ b/test_app/tests/lib/logging/test_context.py @@ -2,27 +2,31 @@ import threading import time import unittest +import uuid -from ansible_base.lib.logging.context import trace_id_var +import pytest +from ansible_base.lib.logging.context import origin_var, trace_context, trace_id_var + + +class TestTraceContextThreadSafety(unittest.TestCase): + """ + Tests the thread safety of context variables. + """ -class TestContextSafety(unittest.TestCase): def test_trace_id_is_thread_safe(self): """ - Verify that the trace_id context variable is thread-safe. + Verify that the trace_id context variable is thread-safe and does not leak between threads. """ results = [] def target_function(thread_id): # Set a unique trace ID for this thread trace_id_var.set(f"trace-id-{thread_id}") - # Sleep for a random, short duration to encourage thread interleaving time.sleep(random.uniform(0.01, 0.05)) - # Get the trace ID and verify it has not been changed by another thread retrieved_id = trace_id_var.get() - # Store the result of the check for the main thread to verify results.append(retrieved_id == f"trace-id-{thread_id}") @@ -38,3 +42,102 @@ def target_function(thread_id): # Verify that all threads successfully retrieved their own context self.assertEqual(len(results), 10, "Not all threads completed successfully.") self.assertTrue(all(results), "Context leaked between threads.") + + +class TestTraceContext: + """ + Tests the functionality of the trace_context context manager and decorator. + """ + + def test_generates_trace_id(self): + """ + Test that the context manager generates a new trace_id when none is provided. + """ + assert trace_id_var.get() is None + with trace_context(origin='test_origin'): + generated_id = trace_id_var.get() + assert generated_id is not None + assert isinstance(uuid.UUID(generated_id), uuid.UUID) + assert trace_id_var.get() is None + + def test_uses_provided_trace_id(self): + """ + Test that the context manager uses the trace_id that is passed to it. + """ + provided_id = str(uuid.uuid4()) + assert trace_id_var.get() is None + with trace_context(origin='test_origin', trace_id=provided_id): + assert trace_id_var.get() == provided_id + assert trace_id_var.get() is None + + def test_handles_invalid_trace_id(self): + """ + Test that the context manager generates a new trace_id if the provided one is invalid. + """ + invalid_id = 'not-a-uuid' + assert trace_id_var.get() is None + with trace_context(origin='test_origin', trace_id=invalid_id): + generated_id = trace_id_var.get() + assert generated_id is not None + assert generated_id != invalid_id + assert isinstance(uuid.UUID(generated_id), uuid.UUID) + assert trace_id_var.get() is None + + def test_resets_context_on_exception(self): + """ + Test that context variables are reset even if an exception is raised. + """ + assert trace_id_var.get() is None + with pytest.raises(ValueError): + with trace_context(origin='test_exception'): + raise ValueError("Test exception") + assert trace_id_var.get() is None + + def test_as_decorator(self): + """ + Test that the trace_context decorator sets and clears context correctly. + """ + + @trace_context(origin='test_decorator') + def my_function(): + assert trace_id_var.get() is not None + assert origin_var.get() == 'test_decorator' + + assert trace_id_var.get() is None + my_function() + assert trace_id_var.get() is None + + def test_decorator_with_provided_id(self): + """ + Test that the trace_context decorator uses a provided trace_id. + """ + provided_id = str(uuid.uuid4()) + + @trace_context(origin='test_decorator_id', trace_id=provided_id) + def my_function(): + assert trace_id_var.get() == provided_id + assert origin_var.get() == 'test_decorator_id' + + assert trace_id_var.get() is None + my_function() + assert trace_id_var.get() is None + + def test_nested_trace_context(self): + """ + Test that nested trace_context managers work correctly, restoring the previous context. + """ + outer_id = str(uuid.uuid4()) + with trace_context(origin='outer', trace_id=outer_id): + assert trace_id_var.get() == outer_id + assert origin_var.get() == 'outer' + + with trace_context(origin='inner'): + inner_id = trace_id_var.get() + assert inner_id is not None + assert inner_id != outer_id + assert origin_var.get() == 'inner' + + assert trace_id_var.get() == outer_id + assert origin_var.get() == 'outer' + + assert trace_id_var.get() is None