@@ -266,6 +266,7 @@ def __init__(self, name: str) -> None:
266266 # Should not create traced spans when no trace context is present
267267 mock_tracer .start_as_current_span .assert_not_called ()
268268
269+ # pylint: disable=too-many-locals,too-many-statements
269270 def test_end_to_end_client_server_communication (
270271 self ,
271272 ) -> None :
@@ -465,3 +466,145 @@ async def server_handle_request(self, session: Any, server_request: Any) -> Dict
465466 any (expected_entry in log_entry for log_entry in e2e_system .communication_log ),
466467 f"Expected log entry '{ expected_entry } ' not found in: { e2e_system .communication_log } " ,
467468 )
469+
470+
471+ class TestMCPInstrumentorEdgeCases (unittest .TestCase ):
472+ """Test edge cases and error conditions for MCP instrumentor"""
473+
474+ def setUp (self ) -> None :
475+ self .instrumentor = MCPInstrumentor ()
476+
477+ def test_invalid_traceparent_format (self ) -> None :
478+ """Test handling of malformed traceparent headers"""
479+ invalid_formats = [
480+ "invalid-format" ,
481+ "00-invalid-hex-01" ,
482+ "00-12345-67890" , # Missing part
483+ "00-12345-67890-01-extra" , # Too many parts
484+ "" , # Empty string
485+ ]
486+
487+ for invalid_format in invalid_formats :
488+ with unittest .mock .patch .dict ("sys.modules" , {"mcp.types" : MagicMock ()}):
489+ result = self .instrumentor ._extract_span_context_from_traceparent (invalid_format )
490+ self .assertIsNone (result , f"Should return None for invalid format: { invalid_format } " )
491+
492+ def test_version_import (self ) -> None :
493+ """Test that version can be imported"""
494+ from amazon .opentelemetry .distro .instrumentation .mcp import version
495+
496+ self .assertIsNotNone (version )
497+
498+ def test_constants_import (self ) -> None :
499+ """Test that constants can be imported"""
500+ from amazon .opentelemetry .distro .instrumentation .mcp .constants import MCPEnvironmentVariables
501+
502+ self .assertIsNotNone (MCPEnvironmentVariables .SERVER_NAME )
503+
504+ def test_add_client_attributes_default_server_name (self ) -> None :
505+ """Test _add_client_attributes uses default server name"""
506+ mock_span = MagicMock ()
507+
508+ class MockRequest :
509+ def __init__ (self ) -> None :
510+ self .params = MockParams ()
511+
512+ class MockParams :
513+ def __init__ (self ) -> None :
514+ self .name = "test_tool"
515+
516+ request = MockRequest ()
517+ self .instrumentor ._add_client_attributes (mock_span , "test_operation" , request )
518+
519+ # Verify default server name is used
520+ mock_span .set_attribute .assert_any_call ("rpc.service" , "mcp server" )
521+ mock_span .set_attribute .assert_any_call ("rpc.method" , "test_operation" )
522+ mock_span .set_attribute .assert_any_call ("mcp.tool.name" , "test_tool" )
523+
524+ def test_add_client_attributes_without_tool_name (self ) -> None :
525+ """Test _add_client_attributes when request has no tool name"""
526+ mock_span = MagicMock ()
527+
528+ class MockRequestNoTool :
529+ def __init__ (self ) -> None :
530+ self .params = None
531+
532+ request = MockRequestNoTool ()
533+ self .instrumentor ._add_client_attributes (mock_span , "test_operation" , request )
534+
535+ # Should still set service and method, but not tool name
536+ mock_span .set_attribute .assert_any_call ("rpc.service" , "mcp server" )
537+ mock_span .set_attribute .assert_any_call ("rpc.method" , "test_operation" )
538+
539+ def test_add_server_attributes_without_tool_name (self ) -> None :
540+ """Test _add_server_attributes when request has no tool name"""
541+ mock_span = MagicMock ()
542+
543+ class MockRequestNoTool :
544+ def __init__ (self ) -> None :
545+ self .params = None
546+
547+ request = MockRequestNoTool ()
548+ self .instrumentor ._add_server_attributes (mock_span , "test_operation" , request )
549+
550+ # Should not set any attributes for server when no tool name
551+ mock_span .set_attribute .assert_not_called ()
552+
553+ def test_inject_trace_context_empty_request (self ) -> None :
554+ """Test trace context injection with minimal request data"""
555+ request_data = {}
556+ span_ctx = SimpleSpanContext (trace_id = 111 , span_id = 222 )
557+
558+ self .instrumentor ._inject_trace_context (request_data , span_ctx )
559+
560+ # Should create params and _meta structure
561+ self .assertIn ("params" , request_data )
562+ self .assertIn ("_meta" , request_data ["params" ])
563+ self .assertIn ("traceparent" , request_data ["params" ]["_meta" ])
564+
565+ # Verify traceparent format
566+ traceparent = request_data ["params" ]["_meta" ]["traceparent" ]
567+ parts = traceparent .split ("-" )
568+ self .assertEqual (len (parts ), 4 )
569+ self .assertEqual (int (parts [1 ], 16 ), 111 ) # trace_id
570+ self .assertEqual (int (parts [2 ], 16 ), 222 ) # span_id
571+
572+ def test_uninstrument (self ) -> None :
573+ """Test _uninstrument method removes instrumentation"""
574+ with unittest .mock .patch (
575+ "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap"
576+ ) as mock_unwrap :
577+ self .instrumentor ._uninstrument ()
578+
579+ # Verify both unwrap calls are made
580+ self .assertEqual (mock_unwrap .call_count , 2 )
581+ mock_unwrap .assert_any_call ("mcp.shared.session" , "BaseSession.send_request" )
582+ mock_unwrap .assert_any_call ("mcp.server.lowlevel.server" , "Server._handle_request" )
583+
584+ def test_extract_span_context_valid_traceparent (self ) -> None :
585+ """Test _extract_span_context_from_traceparent with valid format"""
586+ # Use correct hex values: 12345 = 0x3039, 67890 = 0x10932
587+ valid_traceparent = "00-0000000000003039-0000000000010932-01"
588+ result = self .instrumentor ._extract_span_context_from_traceparent (valid_traceparent )
589+
590+ self .assertIsNotNone (result )
591+ self .assertEqual (result .trace_id , 12345 )
592+ self .assertEqual (result .span_id , 67890 )
593+ self .assertTrue (result .is_remote )
594+
595+ def test_extract_span_context_value_error (self ) -> None :
596+ """Test _extract_span_context_from_traceparent with invalid hex values"""
597+ invalid_hex_traceparent = "00-invalid-hex-values-01"
598+ result = self .instrumentor ._extract_span_context_from_traceparent (invalid_hex_traceparent )
599+
600+ self .assertIsNone (result )
601+
602+ def test_instrument_method_coverage (self ) -> None :
603+ """Test _instrument method registers hooks"""
604+ with unittest .mock .patch (
605+ "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook"
606+ ) as mock_register :
607+ self .instrumentor ._instrument ()
608+
609+ # Should register two hooks
610+ self .assertEqual (mock_register .call_count , 2 )
0 commit comments