@@ -266,6 +266,7 @@ def __init__(self, name: str) -> None:
266
266
# Should not create traced spans when no trace context is present
267
267
mock_tracer .start_as_current_span .assert_not_called ()
268
268
269
+ # pylint: disable=too-many-locals,too-many-statements
269
270
def test_end_to_end_client_server_communication (
270
271
self ,
271
272
) -> None :
@@ -465,3 +466,145 @@ async def server_handle_request(self, session: Any, server_request: Any) -> Dict
465
466
any (expected_entry in log_entry for log_entry in e2e_system .communication_log ),
466
467
f"Expected log entry '{ expected_entry } ' not found in: { e2e_system .communication_log } " ,
467
468
)
469
+
470
+
471
+ class TestMCPInstrumentorEdgeCases (unittest .TestCase ):
472
+ """Test edge cases and error conditions for MCP instrumentor"""
473
+
474
+ def setUp (self ) -> None :
475
+ self .instrumentor = MCPInstrumentor ()
476
+
477
+ def test_invalid_traceparent_format (self ) -> None :
478
+ """Test handling of malformed traceparent headers"""
479
+ invalid_formats = [
480
+ "invalid-format" ,
481
+ "00-invalid-hex-01" ,
482
+ "00-12345-67890" , # Missing part
483
+ "00-12345-67890-01-extra" , # Too many parts
484
+ "" , # Empty string
485
+ ]
486
+
487
+ for invalid_format in invalid_formats :
488
+ with unittest .mock .patch .dict ("sys.modules" , {"mcp.types" : MagicMock ()}):
489
+ result = self .instrumentor ._extract_span_context_from_traceparent (invalid_format )
490
+ self .assertIsNone (result , f"Should return None for invalid format: { invalid_format } " )
491
+
492
+ def test_version_import (self ) -> None :
493
+ """Test that version can be imported"""
494
+ from amazon .opentelemetry .distro .instrumentation .mcp import version
495
+
496
+ self .assertIsNotNone (version )
497
+
498
+ def test_constants_import (self ) -> None :
499
+ """Test that constants can be imported"""
500
+ from amazon .opentelemetry .distro .instrumentation .mcp .constants import MCPEnvironmentVariables
501
+
502
+ self .assertIsNotNone (MCPEnvironmentVariables .SERVER_NAME )
503
+
504
+ def test_add_client_attributes_default_server_name (self ) -> None :
505
+ """Test _add_client_attributes uses default server name"""
506
+ mock_span = MagicMock ()
507
+
508
+ class MockRequest :
509
+ def __init__ (self ) -> None :
510
+ self .params = MockParams ()
511
+
512
+ class MockParams :
513
+ def __init__ (self ) -> None :
514
+ self .name = "test_tool"
515
+
516
+ request = MockRequest ()
517
+ self .instrumentor ._add_client_attributes (mock_span , "test_operation" , request )
518
+
519
+ # Verify default server name is used
520
+ mock_span .set_attribute .assert_any_call ("rpc.service" , "mcp server" )
521
+ mock_span .set_attribute .assert_any_call ("rpc.method" , "test_operation" )
522
+ mock_span .set_attribute .assert_any_call ("mcp.tool.name" , "test_tool" )
523
+
524
+ def test_add_client_attributes_without_tool_name (self ) -> None :
525
+ """Test _add_client_attributes when request has no tool name"""
526
+ mock_span = MagicMock ()
527
+
528
+ class MockRequestNoTool :
529
+ def __init__ (self ) -> None :
530
+ self .params = None
531
+
532
+ request = MockRequestNoTool ()
533
+ self .instrumentor ._add_client_attributes (mock_span , "test_operation" , request )
534
+
535
+ # Should still set service and method, but not tool name
536
+ mock_span .set_attribute .assert_any_call ("rpc.service" , "mcp server" )
537
+ mock_span .set_attribute .assert_any_call ("rpc.method" , "test_operation" )
538
+
539
+ def test_add_server_attributes_without_tool_name (self ) -> None :
540
+ """Test _add_server_attributes when request has no tool name"""
541
+ mock_span = MagicMock ()
542
+
543
+ class MockRequestNoTool :
544
+ def __init__ (self ) -> None :
545
+ self .params = None
546
+
547
+ request = MockRequestNoTool ()
548
+ self .instrumentor ._add_server_attributes (mock_span , "test_operation" , request )
549
+
550
+ # Should not set any attributes for server when no tool name
551
+ mock_span .set_attribute .assert_not_called ()
552
+
553
+ def test_inject_trace_context_empty_request (self ) -> None :
554
+ """Test trace context injection with minimal request data"""
555
+ request_data = {}
556
+ span_ctx = SimpleSpanContext (trace_id = 111 , span_id = 222 )
557
+
558
+ self .instrumentor ._inject_trace_context (request_data , span_ctx )
559
+
560
+ # Should create params and _meta structure
561
+ self .assertIn ("params" , request_data )
562
+ self .assertIn ("_meta" , request_data ["params" ])
563
+ self .assertIn ("traceparent" , request_data ["params" ]["_meta" ])
564
+
565
+ # Verify traceparent format
566
+ traceparent = request_data ["params" ]["_meta" ]["traceparent" ]
567
+ parts = traceparent .split ("-" )
568
+ self .assertEqual (len (parts ), 4 )
569
+ self .assertEqual (int (parts [1 ], 16 ), 111 ) # trace_id
570
+ self .assertEqual (int (parts [2 ], 16 ), 222 ) # span_id
571
+
572
+ def test_uninstrument (self ) -> None :
573
+ """Test _uninstrument method removes instrumentation"""
574
+ with unittest .mock .patch (
575
+ "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap"
576
+ ) as mock_unwrap :
577
+ self .instrumentor ._uninstrument ()
578
+
579
+ # Verify both unwrap calls are made
580
+ self .assertEqual (mock_unwrap .call_count , 2 )
581
+ mock_unwrap .assert_any_call ("mcp.shared.session" , "BaseSession.send_request" )
582
+ mock_unwrap .assert_any_call ("mcp.server.lowlevel.server" , "Server._handle_request" )
583
+
584
+ def test_extract_span_context_valid_traceparent (self ) -> None :
585
+ """Test _extract_span_context_from_traceparent with valid format"""
586
+ # Use correct hex values: 12345 = 0x3039, 67890 = 0x10932
587
+ valid_traceparent = "00-0000000000003039-0000000000010932-01"
588
+ result = self .instrumentor ._extract_span_context_from_traceparent (valid_traceparent )
589
+
590
+ self .assertIsNotNone (result )
591
+ self .assertEqual (result .trace_id , 12345 )
592
+ self .assertEqual (result .span_id , 67890 )
593
+ self .assertTrue (result .is_remote )
594
+
595
+ def test_extract_span_context_value_error (self ) -> None :
596
+ """Test _extract_span_context_from_traceparent with invalid hex values"""
597
+ invalid_hex_traceparent = "00-invalid-hex-values-01"
598
+ result = self .instrumentor ._extract_span_context_from_traceparent (invalid_hex_traceparent )
599
+
600
+ self .assertIsNone (result )
601
+
602
+ def test_instrument_method_coverage (self ) -> None :
603
+ """Test _instrument method registers hooks"""
604
+ with unittest .mock .patch (
605
+ "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook"
606
+ ) as mock_register :
607
+ self .instrumentor ._instrument ()
608
+
609
+ # Should register two hooks
610
+ self .assertEqual (mock_register .call_count , 2 )
0 commit comments