@@ -50,9 +50,19 @@ def test_inject_trace_context_empty_dict(self):
5050 # Execute - Actually test the mcpinstrumentor method
5151 self .instrumentor ._inject_trace_context (request_data , span_ctx )
5252
53- # Verify
54- expected = {"params" : {"_meta" : {"trace_context" : {"trace_id" : 12345 , "span_id" : 67890 }}}}
55- self .assertEqual (request_data , expected )
53+ # Verify - now uses traceparent W3C format
54+ self .assertIn ("params" , request_data )
55+ self .assertIn ("_meta" , request_data ["params" ])
56+ self .assertIn ("traceparent" , request_data ["params" ]["_meta" ])
57+
58+ # Verify traceparent format: "00-{trace_id:032x}-{span_id:016x}-01"
59+ traceparent = request_data ["params" ]["_meta" ]["traceparent" ]
60+ self .assertTrue (traceparent .startswith ("00-" ))
61+ self .assertTrue (traceparent .endswith ("-01" ))
62+ parts = traceparent .split ("-" )
63+ self .assertEqual (len (parts ), 4 )
64+ self .assertEqual (int (parts [1 ], 16 ), 12345 ) # trace_id
65+ self .assertEqual (int (parts [2 ], 16 ), 67890 ) # span_id
5666
5767 def test_inject_trace_context_existing_params (self ):
5868 """Test injecting trace context when params already exist"""
@@ -63,41 +73,52 @@ def test_inject_trace_context_existing_params(self):
6373 # Execute - Actually test the mcpinstrumentor method
6474 self .instrumentor ._inject_trace_context (request_data , span_ctx )
6575
66- # Verify the existing field is preserved and trace context is added
76+ # Verify the existing field is preserved and traceparent is added
6777 self .assertEqual (request_data ["params" ]["existing_field" ], "test_value" )
68- self .assertEqual (request_data ["params" ]["_meta" ]["trace_context" ]["trace_id" ], 99999 )
69- self .assertEqual (request_data ["params" ]["_meta" ]["trace_context" ]["span_id" ], 11111 )
78+ self .assertIn ("_meta" , request_data ["params" ])
79+ self .assertIn ("traceparent" , request_data ["params" ]["_meta" ])
80+
81+ # Verify traceparent format contains correct trace/span IDs
82+ traceparent = request_data ["params" ]["_meta" ]["traceparent" ]
83+ parts = traceparent .split ("-" )
84+ self .assertEqual (int (parts [1 ], 16 ), 99999 ) # trace_id
85+ self .assertEqual (int (parts [2 ], 16 ), 11111 ) # span_id
7086
7187
7288class TestTracerProvider (unittest .TestCase ):
7389 """Test the tracer provider kwargs logic in _instrument method"""
7490
7591 def setUp (self ):
7692 self .instrumentor = MCPInstrumentor ()
77- # Reset tracer_provider to ensure test isolation
78- if hasattr (self .instrumentor , "tracer_provider " ):
79- delattr (self .instrumentor , "tracer_provider " )
93+ # Reset tracer to ensure test isolation
94+ if hasattr (self .instrumentor , "tracer " ):
95+ delattr (self .instrumentor , "tracer " )
8096
8197 def test_instrument_without_tracer_provider_kwargs (self ):
82- """Test _instrument method when no tracer_provider in kwargs - should set to None """
98+ """Test _instrument method when no tracer_provider in kwargs - should use default tracer """
8399 # Execute - Actually test the mcpinstrumentor method
84- self .instrumentor ._instrument ()
100+ with unittest .mock .patch ("opentelemetry.trace.get_tracer" ) as mock_get_tracer :
101+ mock_get_tracer .return_value = "default_tracer"
102+ self .instrumentor ._instrument ()
85103
86- # Verify - tracer_provider should be None
87- self .assertTrue (hasattr (self .instrumentor , "tracer_provider" ))
88- self .assertIsNone (self .instrumentor .tracer_provider )
104+ # Verify - tracer should be set from trace.get_tracer
105+ self .assertTrue (hasattr (self .instrumentor , "tracer" ))
106+ self .assertEqual (self .instrumentor .tracer , "default_tracer" )
107+ mock_get_tracer .assert_called_with ("mcp" )
89108
90109 def test_instrument_with_tracer_provider_kwargs (self ):
91- """Test _instrument method when tracer_provider is in kwargs - should set to that value """
110+ """Test _instrument method when tracer_provider is in kwargs - should use provider's tracer """
92111 # Setup
93112 provider = SimpleTracerProvider ()
94113
95114 # Execute - Actually test the mcpinstrumentor method
96115 self .instrumentor ._instrument (tracer_provider = provider )
97116
98- # Verify - tracer_provider should be set to the provided value
99- self .assertTrue (hasattr (self .instrumentor , "tracer_provider" ))
100- self .assertEqual (self .instrumentor .tracer_provider , provider )
117+ # Verify - tracer should be set from the provided tracer_provider
118+ self .assertTrue (hasattr (self .instrumentor , "tracer" ))
119+ self .assertEqual (self .instrumentor .tracer , "mock_tracer_from_provider" )
120+ self .assertTrue (provider .get_tracer_called )
121+ self .assertEqual (provider .tracer_name , "mcp" )
101122
102123
103124class TestInstrumentationDependencies (unittest .TestCase ):
@@ -171,9 +192,13 @@ def __init__(self, name, arguments=None):
171192 # Verify the actual mcpinstrumentor method worked correctly
172193 client_data = modified_request .model_dump ()
173194 self .assertIn ("_meta" , client_data ["params" ])
174- self .assertIn ("trace_context" , client_data ["params" ]["_meta" ])
175- self .assertEqual (client_data ["params" ]["_meta" ]["trace_context" ]["trace_id" ], 98765 )
176- self .assertEqual (client_data ["params" ]["_meta" ]["trace_context" ]["span_id" ], 43210 )
195+ self .assertIn ("traceparent" , client_data ["params" ]["_meta" ])
196+
197+ # Verify traceparent format contains correct trace/span IDs
198+ traceparent = client_data ["params" ]["_meta" ]["traceparent" ]
199+ parts = traceparent .split ("-" )
200+ self .assertEqual (int (parts [1 ], 16 ), 98765 ) # trace_id
201+ self .assertEqual (int (parts [2 ], 16 ), 43210 ) # span_id
177202
178203 # Verify the tool call data is also preserved
179204 self .assertEqual (client_data ["params" ]["name" ], "create_metric" )
@@ -185,7 +210,9 @@ class TestInstrumentedMCPServer(unittest.TestCase):
185210
186211 def setUp (self ):
187212 self .instrumentor = MCPInstrumentor ()
188- self .instrumentor .tracer_provider = None
213+ # Initialize tracer so the instrumentor can work
214+ mock_tracer = MagicMock ()
215+ self .instrumentor .tracer = mock_tracer
189216
190217 def test_no_trace_context_fallback (self ):
191218 """Test graceful handling when no trace context is present on server side"""
@@ -275,15 +302,15 @@ class MCPServerRequestParams:
275302 def __init__ (self , params_data ):
276303 self .name = params_data ["name" ]
277304 self .arguments = params_data .get ("arguments" )
278- # Extract trace context from _meta if present
279- if "_meta" in params_data and "trace_context " in params_data ["_meta" ]:
280- self .meta = MCPServerRequestMeta (params_data ["_meta" ]["trace_context " ])
305+ # Extract traceparent from _meta if present
306+ if "_meta" in params_data and "traceparent " in params_data ["_meta" ]:
307+ self .meta = MCPServerRequestMeta (params_data ["_meta" ]["traceparent " ])
281308 else :
282309 self .meta = None
283310
284311 class MCPServerRequestMeta :
285- def __init__ (self , trace_context ):
286- self .trace_context = trace_context
312+ def __init__ (self , traceparent ):
313+ self .traceparent = traceparent
287314
288315 # Mock client and server that actually communicate
289316 class EndToEndMCPSystem :
@@ -307,13 +334,19 @@ async def server_handle_request(self, session, server_request):
307334 """Server handles the request it received"""
308335 self .communication_log .append (f"SERVER: Received request for { server_request .params .name } " )
309336
310- # Check if trace context was received
311- if server_request .params .meta and server_request .params .meta .trace_context :
312- trace_info = server_request .params .meta .trace_context
313- self .communication_log .append (
314- f"SERVER: Found trace context - trace_id: { trace_info ['trace_id' ]} , "
315- f"span_id: { trace_info ['span_id' ]} "
316- )
337+ # Check if traceparent was received
338+ if server_request .params .meta and server_request .params .meta .traceparent :
339+ traceparent = server_request .params .meta .traceparent
340+ # Parse traceparent to extract trace_id and span_id
341+ parts = traceparent .split ("-" )
342+ if len (parts ) == 4 :
343+ trace_id = int (parts [1 ], 16 )
344+ span_id = int (parts [2 ], 16 )
345+ self .communication_log .append (
346+ f"SERVER: Found trace context - trace_id: { trace_id } , " f"span_id: { span_id } "
347+ )
348+ else :
349+ self .communication_log .append ("SERVER: Invalid traceparent format" )
317350 else :
318351 self .communication_log .append ("SERVER: No trace context found" )
319352
@@ -339,6 +372,8 @@ async def server_handle_request(self, session, server_request):
339372 with unittest .mock .patch ("opentelemetry.trace.get_tracer" , return_value = mock_tracer ), unittest .mock .patch .dict (
340373 "sys.modules" , {"mcp.types" : MagicMock ()}
341374 ), unittest .mock .patch .object (self .instrumentor , "handle_attributes" ):
375+ # Override the setup tracer with the properly mocked one
376+ self .instrumentor .tracer = mock_tracer
342377
343378 client_result = asyncio .run (
344379 self .instrumentor ._wrap_send_request (e2e_system .client_send_request , None , (original_request ,), {})
@@ -352,11 +387,15 @@ async def server_handle_request(self, session, server_request):
352387 sent_request = e2e_system .last_sent_request
353388 sent_request_data = sent_request .model_dump ()
354389
355- # Verify trace context was injected by client instrumentation
390+ # Verify traceparent was injected by client instrumentation
356391 self .assertIn ("_meta" , sent_request_data ["params" ])
357- self .assertIn ("trace_context" , sent_request_data ["params" ]["_meta" ])
358- self .assertEqual (sent_request_data ["params" ]["_meta" ]["trace_context" ]["trace_id" ], 12345 )
359- self .assertEqual (sent_request_data ["params" ]["_meta" ]["trace_context" ]["span_id" ], 67890 )
392+ self .assertIn ("traceparent" , sent_request_data ["params" ]["_meta" ])
393+
394+ # Parse and verify traceparent contains correct trace/span IDs
395+ traceparent = sent_request_data ["params" ]["_meta" ]["traceparent" ]
396+ parts = traceparent .split ("-" )
397+ self .assertEqual (int (parts [1 ], 16 ), 12345 ) # trace_id
398+ self .assertEqual (int (parts [2 ], 16 ), 67890 ) # span_id
360399
361400 # STEP 2: Server receives the EXACT request that client sent
362401 # Create server request from the client's serialized data
@@ -389,11 +428,15 @@ async def server_handle_request(self, session, server_request):
389428 self .assertEqual (server_request .params .arguments ["name" ], "cpu_usage" )
390429 self .assertEqual (server_request .params .arguments ["value" ], 85 )
391430
392- # Verify the trace context made it through end-to-end
431+ # Verify the traceparent made it through end-to-end
393432 self .assertIsNotNone (server_request .params .meta )
394- self .assertIsNotNone (server_request .params .meta .trace_context )
395- self .assertEqual (server_request .params .meta .trace_context ["trace_id" ], 12345 )
396- self .assertEqual (server_request .params .meta .trace_context ["span_id" ], 67890 )
433+ self .assertIsNotNone (server_request .params .meta .traceparent )
434+
435+ # Parse traceparent and verify trace/span IDs
436+ traceparent = server_request .params .meta .traceparent
437+ parts = traceparent .split ("-" )
438+ self .assertEqual (int (parts [1 ], 16 ), 12345 ) # trace_id
439+ self .assertEqual (int (parts [2 ], 16 ), 67890 ) # span_id
397440
398441 # Verify complete communication flow
399442 expected_log_entries = [
0 commit comments