Skip to content

Commit 9806246

Browse files
committed
fix route in sql comment and also guard against sql injection
1 parent 15aaec6 commit 9806246

File tree

5 files changed

+112
-25
lines changed

5 files changed

+112
-25
lines changed

ansible_base/lib/logging/context.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# Providing a default value is important so that they can be accessed
66
# even when the context has not been explicitly set.
77
trace_id_var = contextvars.ContextVar('trace_id', default=None)
8-
route_var = contextvars.ContextVar('route', default=None)
98
origin_var = contextvars.ContextVar('origin', default=None)
109

1110

ansible_base/lib/middleware/profiling/profile_request.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
import time
77
import uuid
88
from typing import Optional, Union
9+
from urllib.parse import quote
910

1011
from django.conf import settings
1112
from django.db import connection
1213
from django.utils.translation import gettext_lazy as _
1314

14-
from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var
15+
from ansible_base.lib.logging.context import origin_var, trace_id_var
1516
from ansible_base.lib.utils.settings import get_function_from_setting, get_setting
1617

1718
logger = logging.getLogger(__name__)
@@ -85,23 +86,50 @@ def __call__(self, request):
8586
return response
8687

8788

89+
# Define the maximum length for a value in a SQL comment
90+
SQL_COMMENT_MAX_LENGTH = 256
91+
92+
93+
def _sanitize_for_sql_comment(value: str) -> str:
94+
"""
95+
Sanitizes a string for safe inclusion in a SQL comment.
96+
97+
- URL-encodes the value to handle special characters.
98+
- Escapes the '%' character to prevent conflicts with database placeholders.
99+
- Truncates the string to a maximum length.
100+
"""
101+
# URL-encode the value
102+
quoted_value = quote(str(value))
103+
# Escape the '%' character for the database driver
104+
sanitized_value = quoted_value.replace('%', '%%')
105+
# Truncate to the maximum length
106+
return sanitized_value[:SQL_COMMENT_MAX_LENGTH]
107+
108+
88109
class SQLQueryMetrics:
89-
def __init__(self):
110+
def __init__(self, request=None):
111+
self.request = request
90112
self.query_count = 0
91113
self.query_time = 0.0
92114

93115
def __call__(self, execute, sql, params, many, context):
94116
# Build the context comment
95117
context_items = []
118+
# trace_id is already validated as a UUID, so it is safe
96119
if trace_id := trace_id_var.get():
97-
context_items.append(f"trace_id={trace_id}")
98-
if route := route_var.get():
99-
context_items.append(f"route={route}")
120+
context_items.append(f"trace_id='{trace_id}'")
121+
122+
# The route is only available after the URL resolver has run
123+
if self.request and getattr(self.request, 'resolver_match', None):
124+
if route := self.request.resolver_match.route:
125+
context_items.append(f"route='{_sanitize_for_sql_comment(route)}'")
126+
100127
if origin := origin_var.get():
101-
context_items.append(f"origin={origin}")
128+
context_items.append(f"origin='{_sanitize_for_sql_comment(origin)}'")
102129

103130
if context_items:
104-
sql = f"/* {', '.join(context_items)} */ {sql}"
131+
comment = f"/* {', '.join(context_items)} */"
132+
sql = f"{comment} {sql}"
105133

106134
start_time = time.time()
107135
try:
@@ -126,7 +154,7 @@ def __call__(self, request):
126154
"Please use the ObservabilityMiddleware instead of including profiling middleware individually."
127155
)
128156

129-
metrics = SQLQueryMetrics()
157+
metrics = SQLQueryMetrics(request)
130158
with connection.execute_wrapper(metrics):
131159
response = self.get_response(request)
132160

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import uuid
22

3-
from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var
3+
from ansible_base.lib.logging.context import origin_var, trace_id_var
44

55

66
class _TraceContextMiddleware:
@@ -10,13 +10,25 @@ def __init__(self, get_response):
1010
def __call__(self, request):
1111
# Set the context for the request and store the tokens
1212
origin_token = origin_var.set('request')
13-
# .get is case-insensitive, but we'll use lowercase for consistency
14-
trace_id = request.headers.get('x-request-id', str(uuid.uuid4()))
15-
trace_id_token = trace_id_var.set(trace_id)
1613

17-
route_token = None
18-
if request.resolver_match:
19-
route_token = route_var.set(request.resolver_match.route)
14+
# Get the request ID from the header
15+
header_trace_id = request.headers.get('x-request-id')
16+
trace_id = None
17+
18+
if header_trace_id:
19+
try:
20+
# Validate that the provided header is a valid UUID
21+
uuid.UUID(header_trace_id)
22+
trace_id = header_trace_id
23+
except ValueError:
24+
# If it's not a valid UUID, discard it and we'll generate a new one
25+
pass
26+
27+
# If no valid trace_id was found, generate a new one
28+
if not trace_id:
29+
trace_id = str(uuid.uuid4())
30+
31+
trace_id_token = trace_id_var.set(trace_id)
2032

2133
try:
2234
response = self.get_response(request)
@@ -25,7 +37,5 @@ def __call__(self, request):
2537
# Reset the context variables to their previous state
2638
origin_var.reset(origin_token)
2739
trace_id_var.reset(trace_id_token)
28-
if route_token:
29-
route_var.reset(route_token)
3040

3141
return response

test_app/tests/lib/middleware/test_profiling_middleware.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,22 @@ def test_logs_warning_if_context_middleware_is_missing(self, mock_logger):
137137

138138
class SQLQueryMetricsTest(TestCase):
139139
def test_sql_comment_injection(self):
140-
from ansible_base.lib.logging.context import origin_var, route_var, trace_id_var
140+
from django.test.client import RequestFactory
141141

142+
from ansible_base.lib.logging.context import origin_var, trace_id_var
143+
144+
# 1. Manually set the context, saving the tokens to reset it later.
142145
trace_id_token = trace_id_var.set("test-trace-id")
143-
route_token = route_var.set("test/route")
144146
origin_token = origin_var.set("test-origin")
145147

148+
# 2. Create a mock request and manually set the resolver_match
149+
factory = RequestFactory()
150+
request = factory.get('/test-db/')
151+
request.resolver_match = type('ResolverMatch', (), {'route': 'test/route'})
152+
146153
try:
147-
metrics = SQLQueryMetrics()
154+
# 3. Instantiate our metrics class and call it directly.
155+
metrics = SQLQueryMetrics(request)
148156
original_sql = "SELECT 1"
149157
modified_sql = ""
150158

@@ -155,15 +163,16 @@ def mock_execute(sql, params, many, context):
155163

156164
metrics(mock_execute, original_sql, [], False, {})
157165

166+
# 4. Assert that the SQL passed to our mock was correctly modified.
158167
self.assertIn("/*", modified_sql)
159-
self.assertIn("trace_id=test-trace-id", modified_sql)
160-
self.assertIn("route=test/route", modified_sql)
161-
self.assertIn("origin=test-origin", modified_sql)
168+
self.assertIn("trace_id='test-trace-id'", modified_sql)
169+
self.assertIn("route='test/route'", modified_sql)
170+
self.assertIn("origin='test-origin'", modified_sql)
162171
self.assertIn("*/", modified_sql)
163172
self.assertIn(original_sql, modified_sql)
164173
finally:
174+
# 5. Reset the context variables to their previous state.
165175
trace_id_var.reset(trace_id_token)
166-
route_var.reset(route_token)
167176
origin_var.reset(origin_token)
168177

169178

@@ -208,3 +217,26 @@ def test_observability_middleware_all_headers(self):
208217
# 3. From _SQLProfilingMiddleware: Check SQL headers
209218
self.assertIn('X-API-Query-Count', response)
210219
self.assertIn('X-API-Query-Time', response)
220+
221+
222+
class SQLCommentSanitizationTest(TestCase):
223+
def test_sanitization_escapes_disallowed_chars(self):
224+
from ansible_base.lib.middleware.profiling.profile_request import _sanitize_for_sql_comment
225+
226+
malicious_string = "*/; DROP TABLE users; --"
227+
sanitized = _sanitize_for_sql_comment(malicious_string)
228+
self.assertEqual(sanitized, "%%2A/%%3B%%20DROP%%20TABLE%%20users%%3B%%20--")
229+
230+
def test_sanitization_allows_safe_chars(self):
231+
from ansible_base.lib.middleware.profiling.profile_request import _sanitize_for_sql_comment
232+
233+
safe_string = "a-b_c.d/e123"
234+
sanitized = _sanitize_for_sql_comment(safe_string)
235+
self.assertEqual(sanitized, "a-b_c.d/e123")
236+
237+
def test_sanitization_truncates_long_strings(self):
238+
from ansible_base.lib.middleware.profiling.profile_request import SQL_COMMENT_MAX_LENGTH, _sanitize_for_sql_comment
239+
240+
long_string = "a" * (SQL_COMMENT_MAX_LENGTH + 100)
241+
sanitized = _sanitize_for_sql_comment(long_string)
242+
self.assertEqual(len(sanitized), SQL_COMMENT_MAX_LENGTH)

test_app/tests/lib/middleware/test_request_context.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,21 @@ def test_context_does_not_bleed_between_requests(self):
8686
UUID(trace_id_2, version=4)
8787
except ValueError:
8888
self.fail("trace_id_2 is not a valid UUID4")
89+
90+
def test_discards_invalid_uuid_in_header(self):
91+
"""
92+
Test that the middleware discards an invalid UUID in the X-Request-ID
93+
header and generates a new, valid one.
94+
"""
95+
malicious_id = "not-a-uuid' --"
96+
response = self.client.get('/context/', HTTP_X_REQUEST_ID=malicious_id)
97+
self.assertEqual(response.status_code, 200)
98+
99+
# The trace_id in the response should be a new, valid UUID, not the malicious one.
100+
new_trace_id = response.headers.get("X-Request-ID")
101+
self.assertIsNotNone(new_trace_id)
102+
self.assertNotEqual(new_trace_id, malicious_id)
103+
try:
104+
UUID(new_trace_id, version=4)
105+
except ValueError:
106+
self.fail("The new trace_id is not a valid UUID4")

0 commit comments

Comments
 (0)