Skip to content

Commit 442a45a

Browse files
committed
Code Cleanup and To Test Dependencies
1 parent 49649e7 commit 442a45a

File tree

1 file changed

+85
-42
lines changed

1 file changed

+85
-42
lines changed

aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py

Lines changed: 85 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,19 @@ def test_inject_trace_context_empty_dict(self):
5050
# Execute - Actually test the mcpinstrumentor method
5151
self.instrumentor._inject_trace_context(request_data, span_ctx)
5252

53-
# Verify
54-
expected = {"params": {"_meta": {"trace_context": {"trace_id": 12345, "span_id": 67890}}}}
55-
self.assertEqual(request_data, expected)
53+
# Verify - now uses traceparent W3C format
54+
self.assertIn("params", request_data)
55+
self.assertIn("_meta", request_data["params"])
56+
self.assertIn("traceparent", request_data["params"]["_meta"])
57+
58+
# Verify traceparent format: "00-{trace_id:032x}-{span_id:016x}-01"
59+
traceparent = request_data["params"]["_meta"]["traceparent"]
60+
self.assertTrue(traceparent.startswith("00-"))
61+
self.assertTrue(traceparent.endswith("-01"))
62+
parts = traceparent.split("-")
63+
self.assertEqual(len(parts), 4)
64+
self.assertEqual(int(parts[1], 16), 12345) # trace_id
65+
self.assertEqual(int(parts[2], 16), 67890) # span_id
5666

5767
def test_inject_trace_context_existing_params(self):
5868
"""Test injecting trace context when params already exist"""
@@ -63,41 +73,52 @@ def test_inject_trace_context_existing_params(self):
6373
# Execute - Actually test the mcpinstrumentor method
6474
self.instrumentor._inject_trace_context(request_data, span_ctx)
6575

66-
# Verify the existing field is preserved and trace context is added
76+
# Verify the existing field is preserved and traceparent is added
6777
self.assertEqual(request_data["params"]["existing_field"], "test_value")
68-
self.assertEqual(request_data["params"]["_meta"]["trace_context"]["trace_id"], 99999)
69-
self.assertEqual(request_data["params"]["_meta"]["trace_context"]["span_id"], 11111)
78+
self.assertIn("_meta", request_data["params"])
79+
self.assertIn("traceparent", request_data["params"]["_meta"])
80+
81+
# Verify traceparent format contains correct trace/span IDs
82+
traceparent = request_data["params"]["_meta"]["traceparent"]
83+
parts = traceparent.split("-")
84+
self.assertEqual(int(parts[1], 16), 99999) # trace_id
85+
self.assertEqual(int(parts[2], 16), 11111) # span_id
7086

7187

7288
class TestTracerProvider(unittest.TestCase):
7389
"""Test the tracer provider kwargs logic in _instrument method"""
7490

7591
def setUp(self):
7692
self.instrumentor = MCPInstrumentor()
77-
# Reset tracer_provider to ensure test isolation
78-
if hasattr(self.instrumentor, "tracer_provider"):
79-
delattr(self.instrumentor, "tracer_provider")
93+
# Reset tracer to ensure test isolation
94+
if hasattr(self.instrumentor, "tracer"):
95+
delattr(self.instrumentor, "tracer")
8096

8197
def test_instrument_without_tracer_provider_kwargs(self):
82-
"""Test _instrument method when no tracer_provider in kwargs - should set to None"""
98+
"""Test _instrument method when no tracer_provider in kwargs - should use default tracer"""
8399
# Execute - Actually test the mcpinstrumentor method
84-
self.instrumentor._instrument()
100+
with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer:
101+
mock_get_tracer.return_value = "default_tracer"
102+
self.instrumentor._instrument()
85103

86-
# Verify - tracer_provider should be None
87-
self.assertTrue(hasattr(self.instrumentor, "tracer_provider"))
88-
self.assertIsNone(self.instrumentor.tracer_provider)
104+
# Verify - tracer should be set from trace.get_tracer
105+
self.assertTrue(hasattr(self.instrumentor, "tracer"))
106+
self.assertEqual(self.instrumentor.tracer, "default_tracer")
107+
mock_get_tracer.assert_called_with("mcp")
89108

90109
def test_instrument_with_tracer_provider_kwargs(self):
91-
"""Test _instrument method when tracer_provider is in kwargs - should set to that value"""
110+
"""Test _instrument method when tracer_provider is in kwargs - should use provider's tracer"""
92111
# Setup
93112
provider = SimpleTracerProvider()
94113

95114
# Execute - Actually test the mcpinstrumentor method
96115
self.instrumentor._instrument(tracer_provider=provider)
97116

98-
# Verify - tracer_provider should be set to the provided value
99-
self.assertTrue(hasattr(self.instrumentor, "tracer_provider"))
100-
self.assertEqual(self.instrumentor.tracer_provider, provider)
117+
# Verify - tracer should be set from the provided tracer_provider
118+
self.assertTrue(hasattr(self.instrumentor, "tracer"))
119+
self.assertEqual(self.instrumentor.tracer, "mock_tracer_from_provider")
120+
self.assertTrue(provider.get_tracer_called)
121+
self.assertEqual(provider.tracer_name, "mcp")
101122

102123

103124
class TestInstrumentationDependencies(unittest.TestCase):
@@ -171,9 +192,13 @@ def __init__(self, name, arguments=None):
171192
# Verify the actual mcpinstrumentor method worked correctly
172193
client_data = modified_request.model_dump()
173194
self.assertIn("_meta", client_data["params"])
174-
self.assertIn("trace_context", client_data["params"]["_meta"])
175-
self.assertEqual(client_data["params"]["_meta"]["trace_context"]["trace_id"], 98765)
176-
self.assertEqual(client_data["params"]["_meta"]["trace_context"]["span_id"], 43210)
195+
self.assertIn("traceparent", client_data["params"]["_meta"])
196+
197+
# Verify traceparent format contains correct trace/span IDs
198+
traceparent = client_data["params"]["_meta"]["traceparent"]
199+
parts = traceparent.split("-")
200+
self.assertEqual(int(parts[1], 16), 98765) # trace_id
201+
self.assertEqual(int(parts[2], 16), 43210) # span_id
177202

178203
# Verify the tool call data is also preserved
179204
self.assertEqual(client_data["params"]["name"], "create_metric")
@@ -185,7 +210,9 @@ class TestInstrumentedMCPServer(unittest.TestCase):
185210

186211
def setUp(self):
187212
self.instrumentor = MCPInstrumentor()
188-
self.instrumentor.tracer_provider = None
213+
# Initialize tracer so the instrumentor can work
214+
mock_tracer = MagicMock()
215+
self.instrumentor.tracer = mock_tracer
189216

190217
def test_no_trace_context_fallback(self):
191218
"""Test graceful handling when no trace context is present on server side"""
@@ -275,15 +302,15 @@ class MCPServerRequestParams:
275302
def __init__(self, params_data):
276303
self.name = params_data["name"]
277304
self.arguments = params_data.get("arguments")
278-
# Extract trace context from _meta if present
279-
if "_meta" in params_data and "trace_context" in params_data["_meta"]:
280-
self.meta = MCPServerRequestMeta(params_data["_meta"]["trace_context"])
305+
# Extract traceparent from _meta if present
306+
if "_meta" in params_data and "traceparent" in params_data["_meta"]:
307+
self.meta = MCPServerRequestMeta(params_data["_meta"]["traceparent"])
281308
else:
282309
self.meta = None
283310

284311
class MCPServerRequestMeta:
285-
def __init__(self, trace_context):
286-
self.trace_context = trace_context
312+
def __init__(self, traceparent):
313+
self.traceparent = traceparent
287314

288315
# Mock client and server that actually communicate
289316
class EndToEndMCPSystem:
@@ -307,13 +334,19 @@ async def server_handle_request(self, session, server_request):
307334
"""Server handles the request it received"""
308335
self.communication_log.append(f"SERVER: Received request for {server_request.params.name}")
309336

310-
# Check if trace context was received
311-
if server_request.params.meta and server_request.params.meta.trace_context:
312-
trace_info = server_request.params.meta.trace_context
313-
self.communication_log.append(
314-
f"SERVER: Found trace context - trace_id: {trace_info['trace_id']}, "
315-
f"span_id: {trace_info['span_id']}"
316-
)
337+
# Check if traceparent was received
338+
if server_request.params.meta and server_request.params.meta.traceparent:
339+
traceparent = server_request.params.meta.traceparent
340+
# Parse traceparent to extract trace_id and span_id
341+
parts = traceparent.split("-")
342+
if len(parts) == 4:
343+
trace_id = int(parts[1], 16)
344+
span_id = int(parts[2], 16)
345+
self.communication_log.append(
346+
f"SERVER: Found trace context - trace_id: {trace_id}, " f"span_id: {span_id}"
347+
)
348+
else:
349+
self.communication_log.append("SERVER: Invalid traceparent format")
317350
else:
318351
self.communication_log.append("SERVER: No trace context found")
319352

@@ -339,6 +372,8 @@ async def server_handle_request(self, session, server_request):
339372
with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict(
340373
"sys.modules", {"mcp.types": MagicMock()}
341374
), unittest.mock.patch.object(self.instrumentor, "handle_attributes"):
375+
# Override the setup tracer with the properly mocked one
376+
self.instrumentor.tracer = mock_tracer
342377

343378
client_result = asyncio.run(
344379
self.instrumentor._wrap_send_request(e2e_system.client_send_request, None, (original_request,), {})
@@ -352,11 +387,15 @@ async def server_handle_request(self, session, server_request):
352387
sent_request = e2e_system.last_sent_request
353388
sent_request_data = sent_request.model_dump()
354389

355-
# Verify trace context was injected by client instrumentation
390+
# Verify traceparent was injected by client instrumentation
356391
self.assertIn("_meta", sent_request_data["params"])
357-
self.assertIn("trace_context", sent_request_data["params"]["_meta"])
358-
self.assertEqual(sent_request_data["params"]["_meta"]["trace_context"]["trace_id"], 12345)
359-
self.assertEqual(sent_request_data["params"]["_meta"]["trace_context"]["span_id"], 67890)
392+
self.assertIn("traceparent", sent_request_data["params"]["_meta"])
393+
394+
# Parse and verify traceparent contains correct trace/span IDs
395+
traceparent = sent_request_data["params"]["_meta"]["traceparent"]
396+
parts = traceparent.split("-")
397+
self.assertEqual(int(parts[1], 16), 12345) # trace_id
398+
self.assertEqual(int(parts[2], 16), 67890) # span_id
360399

361400
# STEP 2: Server receives the EXACT request that client sent
362401
# Create server request from the client's serialized data
@@ -389,11 +428,15 @@ async def server_handle_request(self, session, server_request):
389428
self.assertEqual(server_request.params.arguments["name"], "cpu_usage")
390429
self.assertEqual(server_request.params.arguments["value"], 85)
391430

392-
# Verify the trace context made it through end-to-end
431+
# Verify the traceparent made it through end-to-end
393432
self.assertIsNotNone(server_request.params.meta)
394-
self.assertIsNotNone(server_request.params.meta.trace_context)
395-
self.assertEqual(server_request.params.meta.trace_context["trace_id"], 12345)
396-
self.assertEqual(server_request.params.meta.trace_context["span_id"], 67890)
433+
self.assertIsNotNone(server_request.params.meta.traceparent)
434+
435+
# Parse traceparent and verify trace/span IDs
436+
traceparent = server_request.params.meta.traceparent
437+
parts = traceparent.split("-")
438+
self.assertEqual(int(parts[1], 16), 12345) # trace_id
439+
self.assertEqual(int(parts[2], 16), 67890) # span_id
397440

398441
# Verify complete communication flow
399442
expected_log_entries = [

0 commit comments

Comments
 (0)