Skip to content

Commit e7f5ad5

Browse files
committed
fix: Align trace_context with middleware and add tests
Refactors the existing `trace_context` manager to align its behavior with the `TraceContextMiddleware`, ensuring consistent and safe observability for non-web tasks. This commit addresses these issues by: - Modifying `trace_context` to only accept `origin` and `trace_id`. - Adding validation to ensure any provided `trace_id` is a valid UUID, matching the middleware's logic. - Removing the ability to create arbitrary context variables via `**kwargs`. Additionally, this change introduces: - Tests for the `trace_context` manager, covering its functionality, validation, thread-safety, and use as a decorator. - Documentation in the `README.md` with usage instructions and an example for background tasks.
1 parent afd6946 commit e7f5ad5

File tree

3 files changed

+155
-14
lines changed

3 files changed

+155
-14
lines changed

ansible_base/lib/logging/context.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,29 @@ class trace_context:
1313
A context manager and decorator to set the trace context for non-web operations.
1414
"""
1515

16-
def __init__(self, origin=None, **kwargs):
16+
def __init__(self, origin=None, trace_id=None):
1717
self.origin = origin
18-
self.kwargs = kwargs
1918
self.tokens = []
2019

20+
if trace_id:
21+
try:
22+
# Validate that the provided header is a valid UUID
23+
uuid.UUID(trace_id)
24+
self.trace_id = trace_id
25+
except (ValueError, TypeError):
26+
# If it's not a valid UUID, discard it and we'll generate a new one
27+
self.trace_id = str(uuid.uuid4())
28+
else:
29+
self.trace_id = str(uuid.uuid4())
30+
2131
def __enter__(self):
22-
# Set a new trace ID for this context
23-
self.tokens.append(trace_id_var.set(str(uuid.uuid4())))
32+
# Set the trace ID for this context
33+
self.tokens.append(trace_id_var.set(self.trace_id))
2434

2535
# Set the origin (e.g., 'dispatcher')
2636
if self.origin:
2737
self.tokens.append(origin_var.set(self.origin))
2838

29-
for key, value in self.kwargs.items():
30-
var = contextvars.ContextVar(key)
31-
self.tokens.append(var.set(value))
32-
3339
def __exit__(self, exc_type, exc_value, traceback):
3440
# Reset the context variables to their previous state
3541
for token in self.tokens:

ansible_base/lib/middleware/profiling/README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ To enable cProfile support, set the following in your Django settings:
3131
ANSIBLE_BASE_CPROFILE_REQUESTS = True
3232
```
3333

34+
> **Note:** Enabling cProfile has significant performance implications and is intended for temporary, live debugging sessions, not for permanent use in production environments.
35+
3436
### SQL Profiling Support
3537

3638
When the `ANSIBLE_BASE_SQL_PROFILING` setting is enabled, the middleware provides insights into the database queries executed during a request. It adds the following headers to the response:
@@ -48,6 +50,8 @@ To enable SQL profiling, set the following in your Django settings:
4850
ANSIBLE_BASE_SQL_PROFILING = True
4951
```
5052

53+
> **Note:** This feature is most effective when used in combination with your database's slow query logging capabilities. For high-traffic environments, consider configuring your database to log only a percentage of queries to manage logging overhead.
54+
5155
## `DABProfiler`
5256

5357
For profiling non-HTTP contexts, such as background tasks or gRPC services, the `DABProfiler` class can be used directly.
@@ -76,6 +80,34 @@ def my_background_task():
7680
print(f"Task took {elapsed:.3f}s to complete.")
7781
```
7882

83+
## `trace_context` for Background Tasks
84+
85+
For adding observability to non-HTTP contexts without the overhead of the `DABProfiler`, the `trace_context` context manager is the ideal tool. It ensures that background tasks can be traced with a unique request ID, just like the `ObservabilityMiddleware` does for web requests.
86+
87+
This is particularly useful for background tasks, such as those initiated by the controller's dispatcher, where you want to correlate all log messages for a specific operation.
88+
89+
### Example Usage
90+
91+
Here's how you might use the `trace_context` manager in the controller's dispatcher to ensure that all work related to a specific job has a consistent trace ID.
92+
93+
```python
94+
# In a hypothetical controller dispatcher task
95+
from ansible_base.lib.logging.context import trace_context
96+
97+
def run_job(job_id, parent_trace_id=None):
98+
"""
99+
A background task that runs a job.
100+
"""
101+
# Use the parent_trace_id if it exists; otherwise, a new one will be generated.
102+
# The origin is a string that identifies the source of the trace.
103+
with trace_context(origin='controller_dispatcher', trace_id=parent_trace_id):
104+
# All logging within this block will now have the same trace_id.
105+
# logger.info(f"Starting job {job_id}")
106+
# ... do work ...
107+
# logger.info(f"Finished job {job_id}")
108+
pass
109+
```
110+
79111
## Visualizing Profile Data
80112

81113
The `.prof` files generated by the cProfile support can be analyzed with a variety of tools.

test_app/tests/lib/logging/test_context.py

Lines changed: 109 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,31 @@
22
import threading
33
import time
44
import unittest
5+
import uuid
56

6-
from ansible_base.lib.logging.context import trace_id_var
7+
import pytest
78

9+
from ansible_base.lib.logging.context import origin_var, trace_context, trace_id_var
10+
11+
12+
class TestTraceContextThreadSafety(unittest.TestCase):
13+
"""
14+
Tests the thread safety of context variables.
15+
"""
816

9-
class TestContextSafety(unittest.TestCase):
1017
def test_trace_id_is_thread_safe(self):
1118
"""
12-
Verify that the trace_id context variable is thread-safe.
19+
Verify that the trace_id context variable is thread-safe and does not leak between threads.
1320
"""
1421
results = []
1522

1623
def target_function(thread_id):
1724
# Set a unique trace ID for this thread
1825
trace_id_var.set(f"trace-id-{thread_id}")
19-
2026
# Sleep for a random, short duration to encourage thread interleaving
2127
time.sleep(random.uniform(0.01, 0.05))
22-
2328
# Get the trace ID and verify it has not been changed by another thread
2429
retrieved_id = trace_id_var.get()
25-
2630
# Store the result of the check for the main thread to verify
2731
results.append(retrieved_id == f"trace-id-{thread_id}")
2832

@@ -38,3 +42,102 @@ def target_function(thread_id):
3842
# Verify that all threads successfully retrieved their own context
3943
self.assertEqual(len(results), 10, "Not all threads completed successfully.")
4044
self.assertTrue(all(results), "Context leaked between threads.")
45+
46+
47+
class TestTraceContext:
48+
"""
49+
Tests the functionality of the trace_context context manager and decorator.
50+
"""
51+
52+
def test_generates_trace_id(self):
53+
"""
54+
Test that the context manager generates a new trace_id when none is provided.
55+
"""
56+
assert trace_id_var.get() is None
57+
with trace_context(origin='test_origin'):
58+
generated_id = trace_id_var.get()
59+
assert generated_id is not None
60+
assert isinstance(uuid.UUID(generated_id), uuid.UUID)
61+
assert trace_id_var.get() is None
62+
63+
def test_uses_provided_trace_id(self):
64+
"""
65+
Test that the context manager uses the trace_id that is passed to it.
66+
"""
67+
provided_id = str(uuid.uuid4())
68+
assert trace_id_var.get() is None
69+
with trace_context(origin='test_origin', trace_id=provided_id):
70+
assert trace_id_var.get() == provided_id
71+
assert trace_id_var.get() is None
72+
73+
def test_handles_invalid_trace_id(self):
74+
"""
75+
Test that the context manager generates a new trace_id if the provided one is invalid.
76+
"""
77+
invalid_id = 'not-a-uuid'
78+
assert trace_id_var.get() is None
79+
with trace_context(origin='test_origin', trace_id=invalid_id):
80+
generated_id = trace_id_var.get()
81+
assert generated_id is not None
82+
assert generated_id != invalid_id
83+
assert isinstance(uuid.UUID(generated_id), uuid.UUID)
84+
assert trace_id_var.get() is None
85+
86+
def test_resets_context_on_exception(self):
87+
"""
88+
Test that context variables are reset even if an exception is raised.
89+
"""
90+
assert trace_id_var.get() is None
91+
with pytest.raises(ValueError):
92+
with trace_context(origin='test_exception'):
93+
raise ValueError("Test exception")
94+
assert trace_id_var.get() is None
95+
96+
def test_as_decorator(self):
97+
"""
98+
Test that the trace_context decorator sets and clears context correctly.
99+
"""
100+
101+
@trace_context(origin='test_decorator')
102+
def my_function():
103+
assert trace_id_var.get() is not None
104+
assert origin_var.get() == 'test_decorator'
105+
106+
assert trace_id_var.get() is None
107+
my_function()
108+
assert trace_id_var.get() is None
109+
110+
def test_decorator_with_provided_id(self):
111+
"""
112+
Test that the trace_context decorator uses a provided trace_id.
113+
"""
114+
provided_id = str(uuid.uuid4())
115+
116+
@trace_context(origin='test_decorator_id', trace_id=provided_id)
117+
def my_function():
118+
assert trace_id_var.get() == provided_id
119+
assert origin_var.get() == 'test_decorator_id'
120+
121+
assert trace_id_var.get() is None
122+
my_function()
123+
assert trace_id_var.get() is None
124+
125+
def test_nested_trace_context(self):
126+
"""
127+
Test that nested trace_context managers work correctly, restoring the previous context.
128+
"""
129+
outer_id = str(uuid.uuid4())
130+
with trace_context(origin='outer', trace_id=outer_id):
131+
assert trace_id_var.get() == outer_id
132+
assert origin_var.get() == 'outer'
133+
134+
with trace_context(origin='inner'):
135+
inner_id = trace_id_var.get()
136+
assert inner_id is not None
137+
assert inner_id != outer_id
138+
assert origin_var.get() == 'inner'
139+
140+
assert trace_id_var.get() == outer_id
141+
assert origin_var.get() == 'outer'
142+
143+
assert trace_id_var.get() is None

0 commit comments

Comments
 (0)