Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions ansible_base/lib/logging/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import contextvars
import functools
import uuid

# Context variables that carry the trace information for the current execution
# context. Each defaults to None so it can be read safely even when no trace
# context has been explicitly established.
trace_id_var = contextvars.ContextVar('trace_id', default=None)
origin_var = contextvars.ContextVar('origin', default=None)


class trace_context:
    """
    A context manager and decorator to set the trace context for non-web operations.

    On entry ``trace_id_var`` is set to this object's trace ID and, when an
    origin was given, ``origin_var`` is set to it. On exit both variables are
    restored to the values they held before entry.

    The trace ID is fixed when the object is constructed, so a function
    decorated with one ``trace_context`` instance sees the same trace ID on
    every call.

    :param origin: Optional string identifying the source of the trace (e.g. 'dispatcher').
    :param trace_id: Optional UUID (string or ``uuid.UUID``) to reuse. A value that is
        not a valid UUID is discarded and a new UUID4 is generated instead.
    """

    def __init__(self, origin=None, trace_id=None):
        self.origin = origin
        # Stack of token groups, one tuple per active __enter__, so the manager
        # can be nested or re-entered without resetting any token twice.
        self.tokens = []
        self.trace_id = self._validated_trace_id(trace_id)

    @staticmethod
    def _validated_trace_id(trace_id):
        """Return trace_id as a string when it is a valid UUID, otherwise a new UUID4 string."""
        if isinstance(trace_id, uuid.UUID):
            return str(trace_id)
        if trace_id:
            try:
                uuid.UUID(trace_id)
            except (ValueError, TypeError, AttributeError):
                # AttributeError covers non-string input (e.g. an int), which
                # uuid.UUID does not reject with ValueError/TypeError.
                pass
            else:
                return trace_id
        return str(uuid.uuid4())

    def _set_context(self):
        """Set the context variables and return the tokens needed to undo it."""
        tokens = [trace_id_var.set(self.trace_id)]
        # Set the origin (e.g., 'dispatcher') only when one was provided
        if self.origin:
            tokens.append(origin_var.set(self.origin))
        return tuple(tokens)

    @staticmethod
    def _reset_context(tokens):
        """Restore the context variables, undoing the most recent set() first."""
        for token in reversed(tokens):
            token.var.reset(token)

    def __enter__(self):
        self.tokens.append(self._set_context())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Pop the group so each token is reset exactly once; a token may not be
        # reused, so leaving it in the list would break the next exit.
        if self.tokens:
            self._reset_context(self.tokens.pop())

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Tokens are kept local to the call so that concurrent or recursive
            # invocations of the decorated function never share token state.
            tokens = self._set_context()
            try:
                return func(*args, **kwargs)
            finally:
                self._reset_context(tokens)

        return wrapper
27 changes: 27 additions & 0 deletions ansible_base/lib/middleware/observability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
A single middleware to provide a unified observability layer, ensuring that context,
profiling, and SQL metrics are captured in the correct order.
"""

from .profiling.profile_request import _ProfileRequestMiddleware, _SQLProfilingMiddleware
from .request_context import _TraceContextMiddleware


class ObservabilityMiddleware:
    """
    One-stop middleware that bundles the observability stack.

    Wraps ``get_response`` with the SQL profiling, request profiling, and
    trace context middleware, so a single settings entry replaces the three
    individual ones.
    """

    def __init__(self, get_response):
        # Wrap from the innermost layer outwards: the last class listed ends up
        # outermost, so a request passes through
        # _TraceContextMiddleware -> _ProfileRequestMiddleware -> _SQLProfilingMiddleware.
        wrapped = get_response
        for layer in (_SQLProfilingMiddleware, _ProfileRequestMiddleware, _TraceContextMiddleware):
            wrapped = layer(wrapped)
        self.handler = wrapped

    def __call__(self, request):
        # Delegate to the outermost middleware of the composed chain
        respond = self.handler
        return respond(request)
139 changes: 139 additions & 0 deletions ansible_base/lib/middleware/profiling/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Request Profiling and Observability

The `ObservabilityMiddleware` provides a simple way to gain performance and debugging insights into your Django application. It acts as a single entry point for several underlying middleware components, ensuring they are always used in the correct order.

## `ObservabilityMiddleware`

This single middleware bundles tracing, request profiling, and SQL query analysis. To use it, add it to the top of your `MIDDLEWARE` list in your Django settings.

```python
# settings.py
MIDDLEWARE = [
'ansible_base.lib.middleware.observability.ObservabilityMiddleware',
...
]
```

The middleware always adds the following headers to the response:

* `X-Request-ID`: A unique identifier for the request. If the incoming request includes an `X-Request-ID` header, that value will be used; otherwise, a new UUID will be generated.
* `X-API-Time`: The total time taken to process the request, in seconds, formatted with an `s` suffix (for example, `0.123s`).
* `X-API-Node`: The cluster host ID of the node that served the request.

### cProfile Support

When the `ANSIBLE_BASE_CPROFILE_REQUESTS` setting is enabled, the middleware will also perform a cProfile analysis for each request. The resulting `.prof` file is saved to a temporary directory on the node that served the request, and its path is returned in the `X-API-CProfile-File` response header. The filename will include the request's `X-Request-ID`.

To enable cProfile support, set the following in your Django settings:

```python
# settings.py
ANSIBLE_BASE_CPROFILE_REQUESTS = True
```

> **Note:** Enabling cProfile has significant performance implications and is intended for temporary, live debugging sessions, not for permanent use in production environments.

### SQL Profiling Support

When the `ANSIBLE_BASE_SQL_PROFILING` setting is enabled, the middleware provides insights into the database queries executed during a request. It adds the following headers to the response:

* `X-API-Query-Count`: The total number of database queries executed during the request.
* `X-API-Query-Time`: The total time spent on database queries, in seconds, formatted with an `s` suffix (for example, `0.042s`).

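It also injects contextual information as a comment into each SQL query, which is invaluable for debugging and tracing. Each value is single-quoted and URL-encoded, so special characters in the route pattern appear percent-encoded. For example:
`/* trace_id='b71696ed-c483-408d-9740-2e7935b4f2d9', route='api/v2/users/%3Cint%3Apk%3E/', origin='request' */ SELECT ...`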

To enable SQL profiling, set the following in your Django settings:

```python
# settings.py
ANSIBLE_BASE_SQL_PROFILING = True
```

> **Note:** This feature is most effective when used in combination with your database's slow query logging capabilities. For high-traffic environments, consider configuring your database to log only a percentage of queries to manage logging overhead.

## `DABProfiler`

For profiling non-HTTP contexts, such as background tasks or gRPC services, the `DABProfiler` class can be used directly.

The profiler's cProfile functionality is controlled by the `ANSIBLE_BASE_CPROFILE_REQUESTS` setting.

- When the setting is `True`, `profiler.stop()` returns a tuple of `(elapsed_time, cprofile_filename)`.
- When the setting is `False`, `profiler.stop()` returns `(elapsed_time, None)`.

### Example Usage

```python
from ansible_base.lib.middleware.profiling.profile_request import DABProfiler

def my_background_task():
    profiler = DABProfiler()
    profiler.start()

    # Your code here

    elapsed, cprofile_filename = profiler.stop()

    if cprofile_filename:
        print(f"cProfile data saved to: {cprofile_filename}")

    print(f"Task took {elapsed:.3f}s to complete.")
```

## `trace_context` for Background Tasks

For adding observability to non-HTTP contexts without the overhead of the `DABProfiler`, the `trace_context` context manager is the ideal tool. It ensures that background tasks can be traced with a unique request ID, just like the `ObservabilityMiddleware` does for web requests.

This is particularly useful for background tasks, such as those initiated by the controller's dispatcher, where you want to correlate all log messages for a specific operation.

### Example Usage

Here's how you might use the `trace_context` manager in the controller's dispatcher to ensure that all work related to a specific job has a consistent trace ID.

```python
# In a hypothetical controller dispatcher task
from ansible_base.lib.logging.context import trace_context

def run_job(job_id, parent_trace_id=None):
    """
    A background task that runs a job.
    """
    # Use the parent_trace_id if it exists; otherwise, a new one will be generated.
    # The origin is a string that identifies the source of the trace.
    with trace_context(origin='controller_dispatcher', trace_id=parent_trace_id):
        # All logging within this block will now have the same trace_id.
        # logger.info(f"Starting job {job_id}")
        # ... do work ...
        # logger.info(f"Finished job {job_id}")
        pass
```

## Visualizing Profile Data

The `.prof` files generated by the cProfile support can be analyzed with a variety of tools.

### SnakeViz

[SnakeViz](https://jiffyclub.github.io/snakeviz/) is a browser-based graphical viewer for the output of Python profilers.

You can install it with pip:
```bash
pip install snakeviz
```

To visualize a profile file, run:
```bash
snakeviz /path/to/your/profile.prof
```

### pstats

The standard library `pstats` module can also be used to read and manipulate profile data.

```python
import pstats

p = pstats.Stats('/path/to/your/profile.prof')
p.sort_stats('cumulative').print_stats(10)
```

163 changes: 163 additions & 0 deletions ansible_base/lib/middleware/profiling/profile_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import cProfile
import logging
import os
import tempfile
import threading
import time
import uuid
from typing import Optional, Union
from urllib.parse import quote

from django.conf import settings
from django.db import connection
from django.utils.translation import gettext_lazy as _

from ansible_base.lib.logging.context import origin_var, trace_id_var
from ansible_base.lib.utils.settings import get_function_from_setting, get_setting

logger = logging.getLogger(__name__)


class DABProfiler:
    """
    Measures elapsed time and, optionally, collects cProfile data for a unit of work.

    cProfile collection is enabled when the ANSIBLE_BASE_CPROFILE_REQUESTS
    setting is truthy at construction time. Call start() before the work and
    stop() after it; the profiler can then be started again.
    """

    def __init__(self, *args, **kwargs):
        self.cprofiling = bool(get_setting('ANSIBLE_BASE_CPROFILE_REQUESTS', False))
        self.prof = None
        # Wall-clock time of the last start(); None whenever no run is in progress
        self.start_time = None
        # Monotonic counter reading taken at start(), used to compute the elapsed time
        self._perf_start = None

    def start(self):
        """Begin timing and, when cProfile support is enabled, begin profiling."""
        if self.prof is not None:
            # A profiler abandoned by an earlier start() would otherwise stay
            # enabled and may interfere with the one created below.
            self.prof.disable()
            self.prof = None

        self.start_time = time.time()
        self._perf_start = time.perf_counter()
        if self.cprofiling:
            self.prof = cProfile.Profile()
            self.prof.enable()

    def stop(self, profile_id: Optional[Union[str, uuid.UUID]] = None):
        """
        Finish the current run.

        :param profile_id: Identifier embedded in the ``.prof`` filename; a random
            UUID is used when it is omitted or empty.
        :return: ``(elapsed_seconds, cprofile_filename)``. The filename is None when
            no profile was written. ``(None, None)`` is returned when start() was
            not called beforehand.
        """
        if self.start_time is None:
            logger.debug("Attempting to stop profiling without having started...")
            return None, None

        if self._perf_start is not None:
            # perf_counter is unaffected by system clock adjustments
            elapsed = time.perf_counter() - self._perf_start
        else:
            # start_time was assigned without going through start()
            elapsed = time.time() - self.start_time

        # Clear the run state first so a repeated stop() cannot re-dump stale data
        prof, self.prof = self.prof, None
        self.start_time = None
        self._perf_start = None

        cprofile_filename = None

        if prof is not None:
            # Always disable, even if cprofiling was switched off after start()
            prof.disable()
            if self.cprofiling:
                if not profile_id:
                    profile_id = uuid.uuid4()
                cprofile_filename = os.path.join(tempfile.gettempdir(), f"cprofile-{profile_id}.prof")
                prof.dump_stats(cprofile_filename)

        return elapsed, cprofile_filename


class _ProfileRequestMiddleware(threading.local):
    """
    Middleware that times each request and optionally writes a cProfile dump.

    Adds the X-API-Time and X-API-Node response headers, plus
    X-API-CProfile-File when the profiler produced a file. The class derives
    from threading.local, so attribute state (including the profiler) is kept
    per thread.
    """

    def __init__(self, get_response=None):
        self.get_response = get_response
        self.profiler = DABProfiler()

    def __call__(self, request):
        # Logic before the view (formerly process_request)
        self.profiler.start()
        # Captured up front so the profile file is named after this request's trace ID
        request_id = trace_id_var.get()

        # Call the next middleware or the view
        try:
            response = self.get_response(request)
        except Exception:
            # Stop the profiler so a failed request does not leave it running
            # for whatever this thread handles next.
            self.profiler.stop(profile_id=request_id)
            raise

        # Logic after the view (formerly process_response)
        if getattr(self.profiler, 'start_time', None) is None:
            return response

        elapsed, cprofile_filename = self.profiler.stop(profile_id=request_id)

        if elapsed is not None:
            response['X-API-Time'] = f'{elapsed:.3f}s'
        # Do not overwrite a node ID that something downstream already set
        if 'X-API-Node' not in response:
            response['X-API-Node'] = get_setting('CLUSTER_HOST_ID', _('Unknown'))

        if cprofile_filename:
            response['X-API-CProfile-File'] = cprofile_filename
            logger.debug(
                f'request: {request}, cprofile_file: {response["X-API-CProfile-File"]}',
                extra=dict(python_objects=dict(request=request, response=response, X_API_CPROFILE_FILE=response["X-API-CProfile-File"])),
            )

        return response


# Define the maximum length for a value in a SQL comment
SQL_COMMENT_MAX_LENGTH = 256


def _sanitize_for_sql_comment(value: str) -> str:
"""
Sanitizes a string for safe inclusion in a SQL comment.

- URL-encodes the value to handle special characters.
- Escapes the '%' character to prevent conflicts with database placeholders.
- Truncates the string to a maximum length.
"""
# URL-encode the value
quoted_value = quote(str(value))
# Escape the '%' character for the database driver
sanitized_value = quoted_value.replace('%', '%%')
# Truncate to the maximum length
return sanitized_value[:SQL_COMMENT_MAX_LENGTH]


class SQLQueryMetrics:
    """
    Database execute wrapper that counts queries and accumulates their run time.

    Each textual query is also prefixed with a SQL comment carrying the current
    trace ID, the resolved route of the request (when available), and the
    origin, so that database-side logs can be correlated with the operation.

    :param request: Optional request object whose ``resolver_match.route`` is
        added to the comment once URL resolution has happened.
    """

    def __init__(self, request=None):
        self.request = request
        self.query_count = 0
        self.query_time = 0.0

    def _build_comment(self):
        """Return the ``/* ... */`` context comment, or None when there is no context to report."""
        context_items = []

        # The trace ID is normally a validated UUID, which sanitizing leaves
        # untouched; sanitize anyway because the context variable can be set
        # by any caller.
        if trace_id := trace_id_var.get():
            context_items.append(f"trace_id='{_sanitize_for_sql_comment(trace_id)}'")

        # The route is only available after the URL resolver has run
        if self.request and getattr(self.request, 'resolver_match', None):
            if route := self.request.resolver_match.route:
                context_items.append(f"route='{_sanitize_for_sql_comment(route)}'")

        if origin := origin_var.get():
            context_items.append(f"origin='{_sanitize_for_sql_comment(origin)}'")

        if not context_items:
            return None
        return f"/* {', '.join(context_items)} */"

    def __call__(self, execute, sql, params, many, context):
        # Only plain-text SQL can be prefixed; formatting any other query
        # object into a string would replace it with its string form.
        if isinstance(sql, str):
            comment = self._build_comment()
            if comment:
                sql = f"{comment} {sql}"

        # perf_counter is unaffected by system clock adjustments
        started = time.perf_counter()
        try:
            return execute(sql, params, many, context)
        finally:
            # Count the query and its duration even when execution raises
            self.query_count += 1
            self.query_time += time.perf_counter() - started


class _SQLProfilingMiddleware:
    """
    Middleware that reports per-request SQL metrics.

    When the ANSIBLE_BASE_SQL_PROFILING setting is truthy, every query run
    while the request is handled goes through a SQLQueryMetrics wrapper, and
    the totals are returned in the X-API-Query-Count and X-API-Query-Time
    response headers. Otherwise the request is passed straight through.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        profiling_enabled = get_setting('ANSIBLE_BASE_SQL_PROFILING', False)
        if not profiling_enabled:
            return self.get_response(request)

        # Without a trace ID the injected SQL comments cannot be correlated,
        # so flag the misconfiguration.
        trace_missing = trace_id_var.get() is None
        if trace_missing:
            logger.warning(
                "ANSIBLE_BASE_SQL_PROFILING is enabled, but the trace context is not set. "
                "Please use the ObservabilityMiddleware instead of including profiling middleware individually."
            )

        query_metrics = SQLQueryMetrics(request)
        with connection.execute_wrapper(query_metrics):
            response = self.get_response(request)

        total_queries = query_metrics.query_count
        total_seconds = query_metrics.query_time
        response['X-API-Query-Count'] = total_queries
        response['X-API-Query-Time'] = f'{total_seconds:.3f}s'
        return response
Loading
Loading