From d8a0ee3545dd10fd19a5dcaaeaa6069432406126 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Tue, 22 Jul 2025 18:02:11 -0700 Subject: [PATCH 01/41] Unit Test and MCP Instrumentor --- .../distro/mcpinstrumentor/README.md | 17 + .../distro/mcpinstrumentor/loggertwo.log | 0 .../distro/mcpinstrumentor/mcpinstrumentor.py | 187 +++++++ .../distro/mcpinstrumentor/pyproject.toml | 48 ++ .../distro/test_mcpinstrumentor.py | 499 ++++++++++++++++++ 5 files changed, 751 insertions(+) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/loggertwo.log create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md new file mode 100644 index 000000000..0e7b72007 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md @@ -0,0 +1,17 @@ +# MCP Instrumentor + +OpenTelemetry MCP instrumentation package. + +## Installation + +```bash +pip install mcpinstrumentor +``` + +## Usage + +```python +from mcpinstrumentor import MCPInstrumentor + +MCPInstrumentor().instrument() +``` \ No newline at end of file diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/loggertwo.log b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/loggertwo.log new file mode 100644 index 000000000..e69de29bb diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py new file mode 100644 index 000000000..a3a2898ff --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -0,0 +1,187 @@ +import logging +from typing import Any, AsyncGenerator, Callable, Collection, Tuple, cast + +from openinference.instrumentation.mcp.package import _instruments +from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper + +from opentelemetry import context, propagate, trace +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor +from opentelemetry.instrumentation.utils import unwrap +from opentelemetry.sdk.resources import Resource + + +def setup_loggertwo(): + logger = logging.getLogger("loggertwo") + logger.setLevel(logging.DEBUG) + handler = logging.FileHandler("loggertwo.log", mode="w") + handler.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + if not logger.handlers: + logger.addHandler(handler) + return logger + + +loggertwo = setup_loggertwo() + + +class MCPInstrumentor(BaseInstrumentor): + """ + An instrumenter for MCP. + """ + + def instrumentation_dependencies(self) -> Collection[str]: + return _instruments + + def _instrument(self, **kwargs: Any) -> None: + tracer_provider = kwargs.get("tracer_provider") # Move this line up + if tracer_provider: + self.tracer_provider = tracer_provider + else: + self.tracer_provider = None + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.shared.session", + "BaseSession.send_request", + self._send_request_wrapper, + ), + "mcp.shared.session", + ) + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.server.lowlevel.server", + "Server._handle_request", + self._server_handle_request_wrapper, + ), + "mcp.server.lowlevel.server", + ) + + def _uninstrument(self, **kwargs: Any) -> None: + unwrap("mcp.shared.session", "BaseSession.send_request") + unwrap("mcp.server.lowlevel.server", "Server._handle_request") + + def handle_attributes(self, span, request, is_client=True): + import mcp.types as types + + operation = "Server Handle Request" + if isinstance(request, types.ListToolsRequest): + operation = "ListTool" + span.set_attribute("mcp.list_tools", True) + elif isinstance(request, types.CallToolRequest): + if hasattr(request, "params") and hasattr(request.params, "name"): + operation = request.params.name + span.set_attribute("mcp.call_tool", True) + if is_client: + self._add_client_attributes(span, operation, request) + else: + self._add_server_attributes(span, operation, request) + + + + def _add_client_attributes(self, span, operation, request): + span.set_attribute("span.kind", "CLIENT") + span.set_attribute("aws.remote.service", "Appsignals MCP Server") + span.set_attribute("aws.remote.operation", operation) + if hasattr(request, "params") and hasattr(request.params, "name"): + span.set_attribute("tool.name", request.params.name) + + def _add_server_attributes(self, span, operation, request): + span.set_attribute("server_side", True) + span.set_attribute("aws.span.kind", "SERVER") + if hasattr(request, "params") and hasattr(request.params, "name"): + span.set_attribute("tool.name", request.params.name) + + def _inject_trace_context(self, request_data, span_ctx): + if "params" not in request_data: + request_data["params"] = {} + if "_meta" not in request_data["params"]: + request_data["params"]["_meta"] = {} + request_data["params"]["_meta"]["trace_context"] = {"trace_id": span_ctx.trace_id, "span_id": span_ctx.span_id} + + # Send Request Wrapper + def _send_request_wrapper(self, wrapped, instance, args, kwargs): + """ + Changes made: + The wrapper intercepts the request before sending, injects distributed tracing context into the + request's params._meta field and creates OpenTelemetry spans. The wrapper does not change anything else from the original function's + behavior because it reconstructs the request object with the same type and calling the original function with identical parameters. + """ + + async def async_wrapper(): + if self.tracer_provider is None: + tracer = trace.get_tracer("mcp.client") + else: + tracer = self.tracer_provider.get_tracer("mcp.client") + with tracer.start_as_current_span( + "client.send_request", kind=trace.SpanKind.CLIENT + ) as span: + span_ctx = span.get_span_context() + request = args[0] if len(args) > 0 else kwargs.get("request") + if request: + req_root = request.root if hasattr(request, "root") else request + + self.handle_attributes(span, req_root, True) + request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + self._inject_trace_context(request_data, span_ctx) + # Reconstruct request object with injected trace context + modified_request = type(request).model_validate(request_data) + if len(args) > 0: + new_args = (modified_request,) + args[1:] + result = await wrapped(*new_args, **kwargs) + else: + kwargs["request"] = modified_request + result = await wrapped(*args, **kwargs) + else: + result = await wrapped(*args, **kwargs) + return result + + return async_wrapper() + + def getname(self, req): + span_name = "unknown" + import mcp.types as types + + if isinstance(req, types.ListToolsRequest): + span_name = "tools/list" + elif isinstance(req, types.CallToolRequest): + if hasattr(req, "params") and hasattr(req.params, "name"): + span_name = f"tools/{req.params.name}" + else: + span_name = "unknown" + return span_name + + # Handle Request Wrapper + async def _server_handle_request_wrapper(self, wrapped, instance, args, kwargs): + """ + Changes made: + This wrapper intercepts requests before processing, extracts distributed tracing context from + the request's params._meta field, and creates server-side OpenTelemetry spans linked to the client spans. The wrapper + also does not change the original function's behavior by calling it with identical parameters + ensuring no breaking changes to the MCP server functionality. + """ + req = args[1] if len(args) > 1 else None + trace_context = None + + if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: + trace_context = req.params.meta.trace_context + if trace_context: + + if self.tracer_provider is None: + tracer = trace.get_tracer("mcp.server") + else: + tracer = self.tracer_provider.get_tracer("mcp.server") + trace_id = trace_context.get("trace_id") + span_id = trace_context.get("span_id") + span_context = trace.SpanContext(trace_id=trace_id, span_id=span_id, is_remote=True,trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),trace_state=trace.TraceState()) + span_name = self.getname(req) + with tracer.start_as_current_span( + span_name, + kind=trace.SpanKind.SERVER, + context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)), + ) as span: + self.handle_attributes(span, req, False) + result = await wrapped(*args, **kwargs) + return result + else: + return await wrapped(*args, **kwargs,) + diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml new file mode 100644 index 000000000..b15828005 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml @@ -0,0 +1,48 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "amazon-opentelemetry-distro-mcpinstrumentor" +version = "0.1.0" +description = "OpenTelemetry MCP instrumentation for AWS Distro" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.9" +authors = [ + { name = "Johnny Lin", email = "jonzilin@amazon.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +dependencies = [ + "opentelemetry-api", + "opentelemetry-instrumentation", + "opentelemetry-semantic-conventions", + "wrapt", + "opentelemetry-sdk", +] + +[project.optional-dependencies] +instruments = ["mcp"] + +[project.entry-points.opentelemetry_instrumentor] +mcp = "mcpinstrumentor:MCPInstrumentor" + +[tool.hatch.build.targets.sdist] +include = [ + "mcpinstrumentor.py", + "README.md" +] + +[tool.hatch.build.targets.wheel] +packages = ["."] \ No newline at end of file diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py new file mode 100644 index 000000000..ef03fcaef --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -0,0 +1,499 @@ +""" +Super simple unit test for MCPInstrumentor - no mocking, just one functionality +Testing the _inject_trace_context method +""" + +import unittest +import sys +import os + +# Import the module under test +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../../../src')) +from amazon.opentelemetry.distro.mcpinstrumentor.mcpinstrumentor import MCPInstrumentor + + +class SimpleSpanContext: + """Simple mock span context without using MagicMock""" + def __init__(self, trace_id, span_id): + self.trace_id = trace_id + self.span_id = span_id + + +class SimpleTracerProvider: + """Simple mock tracer provider without using MagicMock""" + def __init__(self): + self.get_tracer_called = False + self.tracer_name = None + + def get_tracer(self, name): + self.get_tracer_called = True + self.tracer_name = name + return "mock_tracer_from_provider" + + +class TestInjectTraceContext(unittest.TestCase): + """Test the _inject_trace_context method - simple functionality""" + + def setUp(self): + self.instrumentor = MCPInstrumentor() + + def test_inject_trace_context_empty_dict(self): + """Test injecting trace context into empty dictionary""" + # Setup + request_data = {} + span_ctx = SimpleSpanContext(trace_id=12345, span_id=67890) + + # Execute + self.instrumentor._inject_trace_context(request_data, span_ctx) + + # Verify + expected = { + "params": { + "_meta": { + "trace_context": { + "trace_id": 12345, + "span_id": 67890 + } + } + } + } + self.assertEqual(request_data, expected) + + def test_inject_trace_context_existing_params(self): + """Test injecting trace context when params already exist""" + # Setup + request_data = {"params": {"existing_field": "test_value"}} + span_ctx = SimpleSpanContext(trace_id=99999, span_id=11111) + + # Execute + self.instrumentor._inject_trace_context(request_data, span_ctx) + + # Verify the existing field is preserved and trace context is added + self.assertEqual(request_data["params"]["existing_field"], "test_value") + self.assertEqual(request_data["params"]["_meta"]["trace_context"]["trace_id"], 99999) + self.assertEqual(request_data["params"]["_meta"]["trace_context"]["span_id"], 11111) + + +class TestTracerProvider(unittest.TestCase): + """Test the tracer provider kwargs logic in _instrument method""" + + def setUp(self): + self.instrumentor = MCPInstrumentor() + # Reset tracer_provider to ensure test isolation + if hasattr(self.instrumentor, 'tracer_provider'): + delattr(self.instrumentor, 'tracer_provider') + + def test_instrument_without_tracer_provider_kwargs(self): + """Test _instrument method when no tracer_provider in kwargs - should set to None""" + # Execute - call _instrument without tracer_provider in kwargs + self.instrumentor._instrument() + + # Verify - tracer_provider should be None + self.assertTrue(hasattr(self.instrumentor, 'tracer_provider')) + self.assertIsNone(self.instrumentor.tracer_provider) + + def test_instrument_with_tracer_provider_kwargs(self): + """Test _instrument method when tracer_provider is in kwargs - should set to that value""" + # Setup + provider = SimpleTracerProvider() + + # Execute - call _instrument with tracer_provider in kwargs + self.instrumentor._instrument(tracer_provider=provider) + + # Verify - tracer_provider should be set to the provided value + self.assertTrue(hasattr(self.instrumentor, 'tracer_provider')) + self.assertEqual(self.instrumentor.tracer_provider, provider) + + +class TestAppSignalToolCallRequest(unittest.TestCase): + """Test with realistic AppSignal MCP server tool call requests""" + + def setUp(self): + self.instrumentor = MCPInstrumentor() + self.instrumentor.tracer_provider = None + + def test_appsignal_call_tool_request(self): + """Test with a realistic AppSignal MCP server CallToolRequest""" + # Create a realistic CallToolRequest for AppSignal MCP server + class CallToolRequest: + def __init__(self, tool_name, arguments): + self.root = self + self.params = CallToolParams(tool_name, arguments) + + def model_dump(self, **kwargs): + return { + "method": "call_tool", + "params": { + "name": self.params.name, + "arguments": self.params.arguments + } + } + + class CallToolParams: + def __init__(self, name, arguments): + self.name = name + self.arguments = arguments + + # Create an AppSignal tool call request + appsignal_request = CallToolRequest( + tool_name="create_metric", + arguments={ + "metric_name": "response_time", + "value": 250, + "tags": {"endpoint": "/api/users", "method": "GET"} + } + ) + + # Verify the tool call request structure matches AppSignal expectations + request_data = appsignal_request.model_dump() + self.assertEqual(request_data["method"], "call_tool") + self.assertEqual(request_data["params"]["name"], "create_metric") + self.assertEqual(request_data["params"]["arguments"]["metric_name"], "response_time") + self.assertEqual(request_data["params"]["arguments"]["value"], 250) + self.assertIn("endpoint", request_data["params"]["arguments"]["tags"]) + self.assertEqual(request_data["params"]["arguments"]["tags"]["endpoint"], "/api/users") + + # Test expected AppSignal response structure + expected_appsignal_result = { + "success": True, + "metric_id": "metric_12345", + "message": "Metric 'response_time' created successfully", + "metadata": { + "timestamp": "2025-01-22T21:19:00Z", + "tags_applied": ["endpoint:/api/users", "method:GET"], + "value_recorded": 250 + } + } + + self.assertTrue(expected_appsignal_result["success"]) + self.assertIn("metric_id", expected_appsignal_result) + self.assertEqual(expected_appsignal_result["metadata"]["value_recorded"], 250) + + def test_appsignal_tool_without_arguments(self): + """Test AppSignal tool call that doesn't require arguments""" + # Create a realistic CallToolRequest for tools without arguments + class CallToolRequestNoArgs: + def __init__(self, tool_name): + self.root = self + self.params = CallToolParamsNoArgs(tool_name) + + def model_dump(self, **kwargs): + return { + "method": "call_tool", + "params": { + "name": self.params.name + # No arguments field for this tool + } + } + + @classmethod + def model_validate(cls, data): + return cls(data["params"]["name"]) + + class CallToolParamsNoArgs: + def __init__(self, name): + self.name = name + # No arguments attribute for this tool + + # Create an AppSignal tool call without arguments + list_apps_request = CallToolRequestNoArgs(tool_name="list_applications") + + # Test argument detection + args_with_request = (list_apps_request,) + kwargs_empty = {} + extracted_request = args_with_request[0] if len(args_with_request) > 0 else kwargs_empty.get("request") + + self.assertEqual(extracted_request, list_apps_request) + self.assertEqual(extracted_request.params.name, "list_applications") + + # Verify the request structure for tools without arguments + request_data = list_apps_request.model_dump() + self.assertEqual(request_data["method"], "call_tool") + self.assertEqual(request_data["params"]["name"], "list_applications") + self.assertNotIn("arguments", request_data["params"]) # No arguments field + + # Test expected result for list_applications tool + expected_list_apps_result = { + "success": True, + "applications": [ + { + "id": "app_001", + "name": "web-frontend", + "environment": "production" + }, + { + "id": "app_002", + "name": "api-backend", + "environment": "staging" + } + ], + "total_count": 2 + } + + # Verify expected response structure + self.assertTrue(expected_list_apps_result["success"]) + self.assertIn("applications", expected_list_apps_result) + self.assertEqual(expected_list_apps_result["total_count"], 2) + self.assertEqual(len(expected_list_apps_result["applications"]), 2) + self.assertEqual(expected_list_apps_result["applications"][0]["name"], "web-frontend") + + def test_send_request_wrapper_argument_reconstruction(self): + """Test the argument logic: if len(args) > 0 vs else path""" + # Create a realistic AppSignal request + class CallToolRequest: + def __init__(self, tool_name, arguments=None): + self.root = self + self.params = CallToolParams(tool_name, arguments) + + def model_dump(self, by_alias=True, mode="json", exclude_none=True): + result = { + "method": "call_tool", + "params": { + "name": self.params.name + } + } + if self.params.arguments: + result["params"]["arguments"] = self.params.arguments + return result + + @classmethod + def model_validate(cls, data): + return cls( + data["params"]["name"], + data["params"].get("arguments") + ) + + class CallToolParams: + def __init__(self, name, arguments=None): + self.name = name + self.arguments = arguments + + request = CallToolRequest("create_metric", {"metric_name": "test", "value": 100}) + + # Test 1: len(args) > 0 path - should trigger new_args = (modified_request,) + args[1:] + args_with_request = (request, "extra_arg1", "extra_arg2") + kwargs_test = {"extra_kwarg": "test"} + + # Simulate what the wrapper logic does + if len(args_with_request) > 0: + # This tests: new_args = (modified_request,) + args[1:] + new_args = ("modified_request_placeholder",) + args_with_request[1:] + result_args = new_args + result_kwargs = kwargs_test + else: + # This shouldn't happen in this test + result_args = args_with_request + result_kwargs = kwargs_test.copy() + result_kwargs["request"] = "modified_request_placeholder" + + # Verify args path reconstruction + self.assertEqual(len(result_args), 3) # modified_request + 2 extra args + self.assertEqual(result_args[0], "modified_request_placeholder") + self.assertEqual(result_args[1], "extra_arg1") # args[1:] preserved + self.assertEqual(result_args[2], "extra_arg2") # args[1:] preserved + self.assertEqual(result_kwargs["extra_kwarg"], "test") + self.assertNotIn("request", result_kwargs) # Should NOT modify kwargs in args path + + # Test 2: len(args) == 0 path - should trigger kwargs["request"] = modified_request + args_empty = () + kwargs_with_request = {"request": request, "other_param": "value"} + + # Simulate what the wrapper logic does + if len(args_empty) > 0: + # This shouldn't happen in this test + new_args = ("modified_request_placeholder",) + args_empty[1:] + result_args = new_args + result_kwargs = kwargs_with_request + else: + # This tests: kwargs["request"] = modified_request + result_args = args_empty + result_kwargs = kwargs_with_request.copy() + result_kwargs["request"] = "modified_request_placeholder" + + # Verify kwargs path reconstruction + self.assertEqual(len(result_args), 0) # No positional args + self.assertEqual(result_kwargs["request"], "modified_request_placeholder") + self.assertEqual(result_kwargs["other_param"], "value") # Other kwargs preserved + + def test_server_handle_request_wrapper_logic(self): + """Test the _server_handle_request_wrapper""" + # Create realistic server request structures + class ServerRequest: + def __init__(self, has_trace_context=False): + if has_trace_context: + self.params = ServerRequestParams(has_meta=True) + else: + self.params = ServerRequestParams(has_meta=False) + + class ServerRequestParams: + def __init__(self, has_meta=False): + if has_meta: + self.meta = ServerRequestMeta() + else: + self.meta = None + + class ServerRequestMeta: + def __init__(self): + self.trace_context = { + "trace_id": 12345, + "span_id": 67890 + } + + # Test 1: Request WITHOUT trace context - should take else path + request_no_trace = ServerRequest(has_trace_context=False) + + # wrapper's request extraction logic + args_with_request = (None, request_no_trace) # args[1] is the request + req = args_with_request[1] if len(args_with_request) > 1 else None + + # Check trace context extraction logic + trace_context = None + if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: + trace_context = req.params.meta.trace_context + + # Verify - should NOT find trace context + self.assertIsNotNone(req) + self.assertIsNotNone(req.params) + self.assertIsNone(req.params.meta) # No meta field + self.assertIsNone(trace_context) + + # Test 2: Request WITH trace context - should take if path + request_with_trace = ServerRequest(has_trace_context=True) + + # Simulate the wrapper's request extraction logic + args_with_trace = (None, request_with_trace) # args[1] is the request + req2 = args_with_trace[1] if len(args_with_trace) > 1 else None + + # Check trace context extraction logic + trace_context2 = None + if req2 and hasattr(req2, "params") and req2.params and hasattr(req2.params, "meta") and req2.params.meta: + trace_context2 = req2.params.meta.trace_context + + # Verify - should find trace context + self.assertIsNotNone(req2) + self.assertIsNotNone(req2.params) + self.assertIsNotNone(req2.params.meta) # Has meta field + self.assertIsNotNone(trace_context2) + self.assertEqual(trace_context2["trace_id"], 12345) + self.assertEqual(trace_context2["span_id"], 67890) + + # Test 3: No request at all (args[1] doesn't exist) + args_no_request = (None,) # Only one arg, no request + req3 = args_no_request[1] if len(args_no_request) > 1 else None + + # Verify - should handle missing request gracefully + self.assertIsNone(req3) + + def test_end_to_end_trace_context_propagation(self): + """Test client sending trace context and server receiving the same trace context""" + # STEP 1: CLIENT SIDE - Create and prepare request with trace context + + # Create a realistic AppSignal request (what client would send) + class CallToolRequest: + def __init__(self, tool_name, arguments=None): + self.root = self + self.params = CallToolParams(tool_name, arguments) + + def model_dump(self, by_alias=True, mode="json", exclude_none=True): + result = { + "method": "call_tool", + "params": { + "name": self.params.name + } + } + if self.params.arguments: + result["params"]["arguments"] = self.params.arguments + # Include _meta if it exists (trace context injection point) + if hasattr(self.params, '_meta') and self.params._meta: + result["params"]["_meta"] = self.params._meta + return result + + @classmethod + def model_validate(cls, data): + instance = cls( + data["params"]["name"], + data["params"].get("arguments") + ) + # Restore _meta field if present + if "_meta" in data["params"]: + instance.params._meta = data["params"]["_meta"] + return instance + + class CallToolParams: + def __init__(self, name, arguments=None): + self.name = name + self.arguments = arguments + self._meta = None # Will hold trace context + + # Client creates original request + client_request = CallToolRequest("create_metric", {"metric_name": "response_time", "value": 250}) + + # Client injects trace context (what _send_request_wrapper does) + original_trace_context = SimpleSpanContext(trace_id=98765, span_id=43210) + + # Get request data and inject trace context + request_data = client_request.model_dump() + self.instrumentor._inject_trace_context(request_data, original_trace_context) + + # Create modified request with trace context (what client sends over network) + modified_request = CallToolRequest.model_validate(request_data) + + # Verify client successfully injected trace context + client_data = modified_request.model_dump() + self.assertIn("_meta", client_data["params"]) + self.assertIn("trace_context", client_data["params"]["_meta"]) + self.assertEqual(client_data["params"]["_meta"]["trace_context"]["trace_id"], 98765) + self.assertEqual(client_data["params"]["_meta"]["trace_context"]["span_id"], 43210) + + # STEP 2: SERVER SIDE - Receive and extract trace context + + # Create server request structure (what server receives) + class ServerRequest: + def __init__(self, client_request_data): + self.params = ServerRequestParams(client_request_data["params"]) + + class ServerRequestParams: + def __init__(self, params_data): + self.name = params_data["name"] + if "arguments" in params_data: + self.arguments = params_data["arguments"] + # Extract meta field (trace context) + if "_meta" in params_data: + self.meta = ServerRequestMeta(params_data["_meta"]) + else: + self.meta = None + + class ServerRequestMeta: + def __init__(self, meta_data): + self.trace_context = meta_data["trace_context"] + + # Server receives the request (simulating network transmission) + server_request = ServerRequest(client_data) + + # Server extracts trace context (what _server_handle_request_wrapper does) + args_with_request = (None, server_request) # args[1] is the request + req = args_with_request[1] if len(args_with_request) > 1 else None + + # Extract trace context using server logic + extracted_trace_context = None + if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: + extracted_trace_context = req.params.meta.trace_context + + # STEP 3: VERIFY END-TO-END PROPAGATION + + # Verify server successfully received the trace context + self.assertIsNotNone(extracted_trace_context) + self.assertEqual(extracted_trace_context["trace_id"], 98765) + self.assertEqual(extracted_trace_context["span_id"], 43210) + + # Verify it's the SAME trace context that client sent + self.assertEqual(extracted_trace_context["trace_id"], original_trace_context.trace_id) + self.assertEqual(extracted_trace_context["span_id"], original_trace_context.span_id) + + # Verify the tool call data is also preserved + self.assertEqual(server_request.params.name, "create_metric") + self.assertEqual(server_request.params.arguments["metric_name"], "response_time") + + +if __name__ == '__main__': + unittest.main() From b2bad7c2283aade66b115b99bb2c63b62a5db2c9 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Wed, 23 Jul 2025 10:52:38 -0700 Subject: [PATCH 02/41] Fix lint errors --- .flake8 | 3 + aws-opentelemetry-distro/pyproject.toml | 3 +- .../distro/mcpinstrumentor/mcpinstrumentor.py | 44 +-- .../distro/test_mcpinstrumentor.py | 259 ++++++++---------- 4 files changed, 139 insertions(+), 170 deletions(-) diff --git a/.flake8 b/.flake8 index e55e479dd..206a12e12 100644 --- a/.flake8 +++ b/.flake8 @@ -19,6 +19,9 @@ exclude = CVS .venv*/ venv*/ + **/venv*/ + **/.venv*/ + aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/venv target __pycache__ mock_collector_service_pb2.py diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index 414b09221..a831eacb4 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "opentelemetry-instrumentation-boto == 0.54b1", "opentelemetry-instrumentation-boto3sqs == 0.54b1", "opentelemetry-instrumentation-botocore == 0.54b1", + "opentelemetry-instrumentation-cassandra == 0.54b1", "opentelemetry-instrumentation-celery == 0.54b1", "opentelemetry-instrumentation-confluent-kafka == 0.54b1", "opentelemetry-instrumentation-dbapi == 0.54b1", @@ -81,9 +82,9 @@ dependencies = [ "opentelemetry-instrumentation-urllib == 0.54b1", "opentelemetry-instrumentation-urllib3 == 0.54b1", "opentelemetry-instrumentation-wsgi == 0.54b1", - "opentelemetry-instrumentation-cassandra == 0.54b1", ] + [project.optional-dependencies] # The 'patch' optional dependency is used for applying patches to specific libraries. # If a new patch is added into the list, it must also be added into tox.ini, dev-requirements.txt and _instrumentation_patch diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index a3a2898ff..7780e97d4 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -1,13 +1,12 @@ import logging -from typing import Any, AsyncGenerator, Callable, Collection, Tuple, cast +from typing import Any, Collection from openinference.instrumentation.mcp.package import _instruments -from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper +from wrapt import register_post_import_hook, wrap_function_wrapper -from opentelemetry import context, propagate, trace +from opentelemetry import trace from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap -from opentelemetry.sdk.resources import Resource def setup_loggertwo(): @@ -69,15 +68,13 @@ def handle_attributes(self, span, request, is_client=True): span.set_attribute("mcp.list_tools", True) elif isinstance(request, types.CallToolRequest): if hasattr(request, "params") and hasattr(request.params, "name"): - operation = request.params.name + operation = request.params.name span.set_attribute("mcp.call_tool", True) if is_client: self._add_client_attributes(span, operation, request) else: self._add_server_attributes(span, operation, request) - - def _add_client_attributes(self, span, operation, request): span.set_attribute("span.kind", "CLIENT") span.set_attribute("aws.remote.service", "Appsignals MCP Server") @@ -103,8 +100,9 @@ def _send_request_wrapper(self, wrapped, instance, args, kwargs): """ Changes made: The wrapper intercepts the request before sending, injects distributed tracing context into the - request's params._meta field and creates OpenTelemetry spans. The wrapper does not change anything else from the original function's - behavior because it reconstructs the request object with the same type and calling the original function with identical parameters. + request's params._meta field and creates OpenTelemetry spans. The wrapper does not change anything + else from the original function's behavior because it reconstructs the request object with the same + type and calling the original function with identical parameters. """ async def async_wrapper(): @@ -112,14 +110,12 @@ async def async_wrapper(): tracer = trace.get_tracer("mcp.client") else: tracer = self.tracer_provider.get_tracer("mcp.client") - with tracer.start_as_current_span( - "client.send_request", kind=trace.SpanKind.CLIENT - ) as span: + with tracer.start_as_current_span("client.send_request", kind=trace.SpanKind.CLIENT) as span: span_ctx = span.get_span_context() request = args[0] if len(args) > 0 else kwargs.get("request") if request: req_root = request.root if hasattr(request, "root") else request - + self.handle_attributes(span, req_root, True) request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) self._inject_trace_context(request_data, span_ctx) @@ -155,24 +151,30 @@ async def _server_handle_request_wrapper(self, wrapped, instance, args, kwargs): """ Changes made: This wrapper intercepts requests before processing, extracts distributed tracing context from - the request's params._meta field, and creates server-side OpenTelemetry spans linked to the client spans. The wrapper - also does not change the original function's behavior by calling it with identical parameters + the request's params._meta field, and creates server-side OpenTelemetry spans linked to the client spans. + The wrapper also does not change the original function's behavior by calling it with identical parameters ensuring no breaking changes to the MCP server functionality. """ req = args[1] if len(args) > 1 else None trace_context = None - + if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: trace_context = req.params.meta.trace_context if trace_context: - + if self.tracer_provider is None: tracer = trace.get_tracer("mcp.server") else: tracer = self.tracer_provider.get_tracer("mcp.server") trace_id = trace_context.get("trace_id") span_id = trace_context.get("span_id") - span_context = trace.SpanContext(trace_id=trace_id, span_id=span_id, is_remote=True,trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),trace_state=trace.TraceState()) + span_context = trace.SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=True, + trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED), + trace_state=trace.TraceState(), + ) span_name = self.getname(req) with tracer.start_as_current_span( span_name, @@ -183,5 +185,7 @@ async def _server_handle_request_wrapper(self, wrapped, instance, args, kwargs): result = await wrapped(*args, **kwargs) return result else: - return await wrapped(*args, **kwargs,) - + return await wrapped( + *args, + **kwargs, + ) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index ef03fcaef..ca3673567 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -3,17 +3,17 @@ Testing the _inject_trace_context method """ -import unittest -import sys import os +import sys +import unittest -# Import the module under test -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../../../src')) -from amazon.opentelemetry.distro.mcpinstrumentor.mcpinstrumentor import MCPInstrumentor +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../../../src")) +from amazon.opentelemetry.distro.mcpinstrumentor.mcpinstrumentor import MCPInstrumentor # noqa: E402 class SimpleSpanContext: """Simple mock span context without using MagicMock""" + def __init__(self, trace_id, span_id): self.trace_id = trace_id self.span_id = span_id @@ -21,10 +21,11 @@ def __init__(self, trace_id, span_id): class SimpleTracerProvider: """Simple mock tracer provider without using MagicMock""" + def __init__(self): self.get_tracer_called = False self.tracer_name = None - + def get_tracer(self, name): self.get_tracer_called = True self.tracer_name = name @@ -33,41 +34,32 @@ def get_tracer(self, name): class TestInjectTraceContext(unittest.TestCase): """Test the _inject_trace_context method - simple functionality""" - + def setUp(self): self.instrumentor = MCPInstrumentor() - + def test_inject_trace_context_empty_dict(self): """Test injecting trace context into empty dictionary""" # Setup request_data = {} span_ctx = SimpleSpanContext(trace_id=12345, span_id=67890) - + # Execute self.instrumentor._inject_trace_context(request_data, span_ctx) - + # Verify - expected = { - "params": { - "_meta": { - "trace_context": { - "trace_id": 12345, - "span_id": 67890 - } - } - } - } + expected = {"params": {"_meta": {"trace_context": {"trace_id": 12345, "span_id": 67890}}}} self.assertEqual(request_data, expected) - + def test_inject_trace_context_existing_params(self): """Test injecting trace context when params already exist""" # Setup request_data = {"params": {"existing_field": "test_value"}} span_ctx = SimpleSpanContext(trace_id=99999, span_id=11111) - + # Execute self.instrumentor._inject_trace_context(request_data, span_ctx) - + # Verify the existing field is preserved and trace context is added self.assertEqual(request_data["params"]["existing_field"], "test_value") self.assertEqual(request_data["params"]["_meta"]["trace_context"]["trace_id"], 99999) @@ -76,74 +68,69 @@ def test_inject_trace_context_existing_params(self): class TestTracerProvider(unittest.TestCase): """Test the tracer provider kwargs logic in _instrument method""" - + def setUp(self): self.instrumentor = MCPInstrumentor() # Reset tracer_provider to ensure test isolation - if hasattr(self.instrumentor, 'tracer_provider'): - delattr(self.instrumentor, 'tracer_provider') - + if hasattr(self.instrumentor, "tracer_provider"): + delattr(self.instrumentor, "tracer_provider") + def test_instrument_without_tracer_provider_kwargs(self): """Test _instrument method when no tracer_provider in kwargs - should set to None""" # Execute - call _instrument without tracer_provider in kwargs self.instrumentor._instrument() - + # Verify - tracer_provider should be None - self.assertTrue(hasattr(self.instrumentor, 'tracer_provider')) + self.assertTrue(hasattr(self.instrumentor, "tracer_provider")) self.assertIsNone(self.instrumentor.tracer_provider) - + def test_instrument_with_tracer_provider_kwargs(self): """Test _instrument method when tracer_provider is in kwargs - should set to that value""" # Setup provider = SimpleTracerProvider() - + # Execute - call _instrument with tracer_provider in kwargs self.instrumentor._instrument(tracer_provider=provider) - + # Verify - tracer_provider should be set to the provided value - self.assertTrue(hasattr(self.instrumentor, 'tracer_provider')) + self.assertTrue(hasattr(self.instrumentor, "tracer_provider")) self.assertEqual(self.instrumentor.tracer_provider, provider) class TestAppSignalToolCallRequest(unittest.TestCase): """Test with realistic AppSignal MCP server tool call requests""" - + def setUp(self): self.instrumentor = MCPInstrumentor() self.instrumentor.tracer_provider = None - + def test_appsignal_call_tool_request(self): """Test with a realistic AppSignal MCP server CallToolRequest""" + # Create a realistic CallToolRequest for AppSignal MCP server class CallToolRequest: def __init__(self, tool_name, arguments): self.root = self self.params = CallToolParams(tool_name, arguments) - + def model_dump(self, **kwargs): - return { - "method": "call_tool", - "params": { - "name": self.params.name, - "arguments": self.params.arguments - } - } - + return {"method": "call_tool", "params": {"name": self.params.name, "arguments": self.params.arguments}} + class CallToolParams: def __init__(self, name, arguments): self.name = name self.arguments = arguments - + # Create an AppSignal tool call request appsignal_request = CallToolRequest( tool_name="create_metric", arguments={ "metric_name": "response_time", "value": 250, - "tags": {"endpoint": "/api/users", "method": "GET"} - } + "tags": {"endpoint": "/api/users", "method": "GET"}, + }, ) - + # Verify the tool call request structure matches AppSignal expectations request_data = appsignal_request.model_dump() self.assertEqual(request_data["method"], "call_tool") @@ -152,7 +139,7 @@ def __init__(self, name, arguments): self.assertEqual(request_data["params"]["arguments"]["value"], 250) self.assertIn("endpoint", request_data["params"]["arguments"]["tags"]) self.assertEqual(request_data["params"]["arguments"]["tags"]["endpoint"], "/api/users") - + # Test expected AppSignal response structure expected_appsignal_result = { "success": True, @@ -161,119 +148,105 @@ def __init__(self, name, arguments): "metadata": { "timestamp": "2025-01-22T21:19:00Z", "tags_applied": ["endpoint:/api/users", "method:GET"], - "value_recorded": 250 - } + "value_recorded": 250, + }, } - + self.assertTrue(expected_appsignal_result["success"]) self.assertIn("metric_id", expected_appsignal_result) self.assertEqual(expected_appsignal_result["metadata"]["value_recorded"], 250) - + def test_appsignal_tool_without_arguments(self): """Test AppSignal tool call that doesn't require arguments""" + # Create a realistic CallToolRequest for tools without arguments class CallToolRequestNoArgs: def __init__(self, tool_name): self.root = self self.params = CallToolParamsNoArgs(tool_name) - + def model_dump(self, **kwargs): return { "method": "call_tool", "params": { "name": self.params.name # No arguments field for this tool - } + }, } - + @classmethod def model_validate(cls, data): return cls(data["params"]["name"]) - + class CallToolParamsNoArgs: def __init__(self, name): self.name = name # No arguments attribute for this tool - + # Create an AppSignal tool call without arguments list_apps_request = CallToolRequestNoArgs(tool_name="list_applications") - + # Test argument detection args_with_request = (list_apps_request,) kwargs_empty = {} extracted_request = args_with_request[0] if len(args_with_request) > 0 else kwargs_empty.get("request") - + self.assertEqual(extracted_request, list_apps_request) self.assertEqual(extracted_request.params.name, "list_applications") - + # Verify the request structure for tools without arguments request_data = list_apps_request.model_dump() - self.assertEqual(request_data["method"], "call_tool") + self.assertEqual(request_data["method"], "call_tool") self.assertEqual(request_data["params"]["name"], "list_applications") self.assertNotIn("arguments", request_data["params"]) # No arguments field - + # Test expected result for list_applications tool expected_list_apps_result = { "success": True, "applications": [ - { - "id": "app_001", - "name": "web-frontend", - "environment": "production" - }, - { - "id": "app_002", - "name": "api-backend", - "environment": "staging" - } + {"id": "app_001", "name": "web-frontend", "environment": "production"}, + {"id": "app_002", "name": "api-backend", "environment": "staging"}, ], - "total_count": 2 + "total_count": 2, } - + # Verify expected response structure self.assertTrue(expected_list_apps_result["success"]) self.assertIn("applications", expected_list_apps_result) self.assertEqual(expected_list_apps_result["total_count"], 2) self.assertEqual(len(expected_list_apps_result["applications"]), 2) self.assertEqual(expected_list_apps_result["applications"][0]["name"], "web-frontend") - + def test_send_request_wrapper_argument_reconstruction(self): """Test the argument logic: if len(args) > 0 vs else path""" + # Create a realistic AppSignal request class CallToolRequest: def __init__(self, tool_name, arguments=None): self.root = self self.params = CallToolParams(tool_name, arguments) - + def model_dump(self, by_alias=True, mode="json", exclude_none=True): - result = { - "method": "call_tool", - "params": { - "name": self.params.name - } - } + result = {"method": "call_tool", "params": {"name": self.params.name}} if self.params.arguments: result["params"]["arguments"] = self.params.arguments return result - + @classmethod def model_validate(cls, data): - return cls( - data["params"]["name"], - data["params"].get("arguments") - ) - + return cls(data["params"]["name"], data["params"].get("arguments")) + class CallToolParams: def __init__(self, name, arguments=None): self.name = name self.arguments = arguments - + request = CallToolRequest("create_metric", {"metric_name": "test", "value": 100}) - + # Test 1: len(args) > 0 path - should trigger new_args = (modified_request,) + args[1:] args_with_request = (request, "extra_arg1", "extra_arg2") kwargs_test = {"extra_kwarg": "test"} - + # Simulate what the wrapper logic does if len(args_with_request) > 0: # This tests: new_args = (modified_request,) + args[1:] @@ -285,7 +258,7 @@ def __init__(self, name, arguments=None): result_args = args_with_request result_kwargs = kwargs_test.copy() result_kwargs["request"] = "modified_request_placeholder" - + # Verify args path reconstruction self.assertEqual(len(result_args), 3) # modified_request + 2 extra args self.assertEqual(result_args[0], "modified_request_placeholder") @@ -293,11 +266,11 @@ def __init__(self, name, arguments=None): self.assertEqual(result_args[2], "extra_arg2") # args[1:] preserved self.assertEqual(result_kwargs["extra_kwarg"], "test") self.assertNotIn("request", result_kwargs) # Should NOT modify kwargs in args path - + # Test 2: len(args) == 0 path - should trigger kwargs["request"] = modified_request args_empty = () kwargs_with_request = {"request": request, "other_param": "value"} - + # Simulate what the wrapper logic does if len(args_empty) > 0: # This shouldn't happen in this test @@ -309,14 +282,15 @@ def __init__(self, name, arguments=None): result_args = args_empty result_kwargs = kwargs_with_request.copy() result_kwargs["request"] = "modified_request_placeholder" - + # Verify kwargs path reconstruction self.assertEqual(len(result_args), 0) # No positional args self.assertEqual(result_kwargs["request"], "modified_request_placeholder") self.assertEqual(result_kwargs["other_param"], "value") # Other kwargs preserved - + def test_server_handle_request_wrapper_logic(self): """Test the _server_handle_request_wrapper""" + # Create realistic server request structures class ServerRequest: def __init__(self, has_trace_context=False): @@ -324,51 +298,48 @@ def __init__(self, has_trace_context=False): self.params = ServerRequestParams(has_meta=True) else: self.params = ServerRequestParams(has_meta=False) - + class ServerRequestParams: def __init__(self, has_meta=False): if has_meta: self.meta = ServerRequestMeta() else: self.meta = None - + class ServerRequestMeta: def __init__(self): - self.trace_context = { - "trace_id": 12345, - "span_id": 67890 - } - + self.trace_context = {"trace_id": 12345, "span_id": 67890} + # Test 1: Request WITHOUT trace context - should take else path request_no_trace = ServerRequest(has_trace_context=False) - + # wrapper's request extraction logic args_with_request = (None, request_no_trace) # args[1] is the request req = args_with_request[1] if len(args_with_request) > 1 else None - + # Check trace context extraction logic trace_context = None if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: trace_context = req.params.meta.trace_context - + # Verify - should NOT find trace context self.assertIsNotNone(req) self.assertIsNotNone(req.params) self.assertIsNone(req.params.meta) # No meta field self.assertIsNone(trace_context) - - # Test 2: Request WITH trace context - should take if path + + # Test 2: Request WITH trace context - should take if path request_with_trace = ServerRequest(has_trace_context=True) - + # Simulate the wrapper's request extraction logic args_with_trace = (None, request_with_trace) # args[1] is the request req2 = args_with_trace[1] if len(args_with_trace) > 1 else None - + # Check trace context extraction logic trace_context2 = None if req2 and hasattr(req2, "params") and req2.params and hasattr(req2.params, "meta") and req2.params.meta: trace_context2 = req2.params.meta.trace_context - + # Verify - should find trace context self.assertIsNotNone(req2) self.assertIsNotNone(req2.params) @@ -376,82 +347,74 @@ def __init__(self): self.assertIsNotNone(trace_context2) self.assertEqual(trace_context2["trace_id"], 12345) self.assertEqual(trace_context2["span_id"], 67890) - + # Test 3: No request at all (args[1] doesn't exist) args_no_request = (None,) # Only one arg, no request req3 = args_no_request[1] if len(args_no_request) > 1 else None - + # Verify - should handle missing request gracefully self.assertIsNone(req3) - + def test_end_to_end_trace_context_propagation(self): """Test client sending trace context and server receiving the same trace context""" # STEP 1: CLIENT SIDE - Create and prepare request with trace context - + # Create a realistic AppSignal request (what client would send) class CallToolRequest: def __init__(self, tool_name, arguments=None): self.root = self self.params = CallToolParams(tool_name, arguments) - + def model_dump(self, by_alias=True, mode="json", exclude_none=True): - result = { - "method": "call_tool", - "params": { - "name": self.params.name - } - } + result = {"method": "call_tool", "params": {"name": self.params.name}} if self.params.arguments: result["params"]["arguments"] = self.params.arguments # Include _meta if it exists (trace context injection point) - if hasattr(self.params, '_meta') and self.params._meta: + if hasattr(self.params, "_meta") and self.params._meta: result["params"]["_meta"] = self.params._meta return result - - @classmethod + + @classmethod def model_validate(cls, data): - instance = cls( - data["params"]["name"], - data["params"].get("arguments") - ) + instance = cls(data["params"]["name"], data["params"].get("arguments")) # Restore _meta field if present if "_meta" in data["params"]: instance.params._meta = data["params"]["_meta"] return instance - + class CallToolParams: def __init__(self, name, arguments=None): self.name = name self.arguments = arguments self._meta = None # Will hold trace context - + # Client creates original request client_request = CallToolRequest("create_metric", {"metric_name": "response_time", "value": 250}) - + # Client injects trace context (what _send_request_wrapper does) original_trace_context = SimpleSpanContext(trace_id=98765, span_id=43210) - + # Get request data and inject trace context request_data = client_request.model_dump() self.instrumentor._inject_trace_context(request_data, original_trace_context) - + # Create modified request with trace context (what client sends over network) modified_request = CallToolRequest.model_validate(request_data) - + # Verify client successfully injected trace context client_data = modified_request.model_dump() self.assertIn("_meta", client_data["params"]) self.assertIn("trace_context", client_data["params"]["_meta"]) self.assertEqual(client_data["params"]["_meta"]["trace_context"]["trace_id"], 98765) self.assertEqual(client_data["params"]["_meta"]["trace_context"]["span_id"], 43210) - + # STEP 2: SERVER SIDE - Receive and extract trace context - + # Create server request structure (what server receives) class ServerRequest: def __init__(self, client_request_data): self.params = ServerRequestParams(client_request_data["params"]) - + class ServerRequestParams: def __init__(self, params_data): self.name = params_data["name"] @@ -462,38 +425,36 @@ def __init__(self, params_data): self.meta = ServerRequestMeta(params_data["_meta"]) else: self.meta = None - + class ServerRequestMeta: def __init__(self, meta_data): self.trace_context = meta_data["trace_context"] - + # Server receives the request (simulating network transmission) server_request = ServerRequest(client_data) - + # Server extracts trace context (what _server_handle_request_wrapper does) args_with_request = (None, server_request) # args[1] is the request req = args_with_request[1] if len(args_with_request) > 1 else None - + # Extract trace context using server logic extracted_trace_context = None if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: extracted_trace_context = req.params.meta.trace_context - - # STEP 3: VERIFY END-TO-END PROPAGATION - + # Verify server successfully received the trace context self.assertIsNotNone(extracted_trace_context) self.assertEqual(extracted_trace_context["trace_id"], 98765) self.assertEqual(extracted_trace_context["span_id"], 43210) - + # Verify it's the SAME trace context that client sent self.assertEqual(extracted_trace_context["trace_id"], original_trace_context.trace_id) self.assertEqual(extracted_trace_context["span_id"], original_trace_context.span_id) - + # Verify the tool call data is also preserved self.assertEqual(server_request.params.name, "create_metric") self.assertEqual(server_request.params.arguments["metric_name"], "response_time") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From 49649e7d7889245b8467a881d33f3051520566ce Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Wed, 23 Jul 2025 19:18:09 -0700 Subject: [PATCH 03/41] Testing coverage and dependency issues --- .flake8 | 3 - .../distro/mcpinstrumentor => }/loggertwo.log | 0 aws-opentelemetry-distro/pyproject.toml | 5 +- .../distro/mcpinstrumentor/mcpinstrumentor.py | 23 +- .../distro/mcpinstrumentor/pyproject.toml | 4 +- .../amazon/opentelemetry/distro/loggertwo.log | 0 .../distro/test_mcpinstrumentor.py | 558 ++++++++---------- 7 files changed, 270 insertions(+), 323 deletions(-) rename aws-opentelemetry-distro/{src/amazon/opentelemetry/distro/mcpinstrumentor => }/loggertwo.log (100%) create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/loggertwo.log diff --git a/.flake8 b/.flake8 index 206a12e12..e55e479dd 100644 --- a/.flake8 +++ b/.flake8 @@ -19,9 +19,6 @@ exclude = CVS .venv*/ venv*/ - **/venv*/ - **/.venv*/ - aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/venv target __pycache__ mock_collector_service_pb2.py diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/loggertwo.log b/aws-opentelemetry-distro/loggertwo.log similarity index 100% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/loggertwo.log rename to aws-opentelemetry-distro/loggertwo.log diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index a831eacb4..947d0e22f 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -48,7 +48,6 @@ dependencies = [ "opentelemetry-instrumentation-boto == 0.54b1", "opentelemetry-instrumentation-boto3sqs == 0.54b1", "opentelemetry-instrumentation-botocore == 0.54b1", - "opentelemetry-instrumentation-cassandra == 0.54b1", "opentelemetry-instrumentation-celery == 0.54b1", "opentelemetry-instrumentation-confluent-kafka == 0.54b1", "opentelemetry-instrumentation-dbapi == 0.54b1", @@ -82,9 +81,9 @@ dependencies = [ "opentelemetry-instrumentation-urllib == 0.54b1", "opentelemetry-instrumentation-urllib3 == 0.54b1", "opentelemetry-instrumentation-wsgi == 0.54b1", + "opentelemetry-instrumentation-cassandra == 0.54b1", ] - [project.optional-dependencies] # The 'patch' optional dependency is used for applying patches to specific libraries. # If a new patch is added into the list, it must also be added into tox.ini, dev-requirements.txt and _instrumentation_patch @@ -112,4 +111,4 @@ include = [ ] [tool.hatch.build.targets.wheel] -packages = ["src/amazon"] +packages = ["src/amazon"] \ No newline at end of file diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index 7780e97d4..8df12cb26 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -9,7 +9,7 @@ from opentelemetry.instrumentation.utils import unwrap -def setup_loggertwo(): +def setup_logger_two(): logger = logging.getLogger("loggertwo") logger.setLevel(logging.DEBUG) handler = logging.FileHandler("loggertwo.log", mode="w") @@ -21,7 +21,7 @@ def setup_loggertwo(): return logger -loggertwo = setup_loggertwo() +logger_two = setup_logger_two() class MCPInstrumentor(BaseInstrumentor): @@ -33,7 +33,7 @@ def instrumentation_dependencies(self) -> Collection[str]: return _instruments def _instrument(self, **kwargs: Any) -> None: - tracer_provider = kwargs.get("tracer_provider") # Move this line up + tracer_provider = kwargs.get("tracer_provider") if tracer_provider: self.tracer_provider = tracer_provider else: @@ -42,7 +42,7 @@ def _instrument(self, **kwargs: Any) -> None: lambda _: wrap_function_wrapper( "mcp.shared.session", "BaseSession.send_request", - self._send_request_wrapper, + self._wrap_send_request, ), "mcp.shared.session", ) @@ -50,7 +50,7 @@ def _instrument(self, **kwargs: Any) -> None: lambda _: wrap_function_wrapper( "mcp.server.lowlevel.server", "Server._handle_request", - self._server_handle_request_wrapper, + self._wrap_handle_request, ), "mcp.server.lowlevel.server", ) @@ -96,7 +96,7 @@ def _inject_trace_context(self, request_data, span_ctx): request_data["params"]["_meta"]["trace_context"] = {"trace_id": span_ctx.trace_id, "span_id": span_ctx.span_id} # Send Request Wrapper - def _send_request_wrapper(self, wrapped, instance, args, kwargs): + def _wrap_send_request(self, wrapped, instance, args, kwargs): """ Changes made: The wrapper intercepts the request before sending, injects distributed tracing context into the @@ -133,7 +133,7 @@ async def async_wrapper(): return async_wrapper() - def getname(self, req): + def _get_span_name(self, req): span_name = "unknown" import mcp.types as types @@ -147,7 +147,7 @@ def getname(self, req): return span_name # Handle Request Wrapper - async def _server_handle_request_wrapper(self, wrapped, instance, args, kwargs): + async def _wrap_handle_request(self, wrapped, instance, args, kwargs): """ Changes made: This wrapper intercepts requests before processing, extracts distributed tracing context from @@ -175,7 +175,7 @@ async def _server_handle_request_wrapper(self, wrapped, instance, args, kwargs): trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED), trace_state=trace.TraceState(), ) - span_name = self.getname(req) + span_name = self._get_span_name(req) with tracer.start_as_current_span( span_name, kind=trace.SpanKind.SERVER, @@ -185,7 +185,4 @@ async def _server_handle_request_wrapper(self, wrapped, instance, args, kwargs): result = await wrapped(*args, **kwargs) return result else: - return await wrapped( - *args, - **kwargs, - ) + return await wrapped(*args, **kwargs) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml index b15828005..0ca34dd39 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml @@ -25,11 +25,11 @@ classifiers = [ "Programming Language :: Python :: 3.13", ] dependencies = [ + "openinference-instrumentation-mcp", "opentelemetry-api", "opentelemetry-instrumentation", - "opentelemetry-semantic-conventions", - "wrapt", "opentelemetry-sdk", + "wrapt" ] [project.optional-dependencies] diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/loggertwo.log b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/loggertwo.log new file mode 100644 index 000000000..e69de29bb diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index ca3673567..fc4921c36 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -1,13 +1,16 @@ """ -Super simple unit test for MCPInstrumentor - no mocking, just one functionality -Testing the _inject_trace_context method +Unit tests for MCPInstrumentor - testing actual mcpinstrumentor methods """ +import asyncio import os import sys import unittest +from unittest.mock import MagicMock -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../../../src")) +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..")) +src_path = os.path.join(project_root, "src") +sys.path.insert(0, src_path) from amazon.opentelemetry.distro.mcpinstrumentor.mcpinstrumentor import MCPInstrumentor # noqa: E402 @@ -33,7 +36,7 @@ def get_tracer(self, name): class TestInjectTraceContext(unittest.TestCase): - """Test the _inject_trace_context method - simple functionality""" + """Test the _inject_trace_context method""" def setUp(self): self.instrumentor = MCPInstrumentor() @@ -44,7 +47,7 @@ def test_inject_trace_context_empty_dict(self): request_data = {} span_ctx = SimpleSpanContext(trace_id=12345, span_id=67890) - # Execute + # Execute - Actually test the mcpinstrumentor method self.instrumentor._inject_trace_context(request_data, span_ctx) # Verify @@ -57,7 +60,7 @@ def test_inject_trace_context_existing_params(self): request_data = {"params": {"existing_field": "test_value"}} span_ctx = SimpleSpanContext(trace_id=99999, span_id=11111) - # Execute + # Execute - Actually test the mcpinstrumentor method self.instrumentor._inject_trace_context(request_data, span_ctx) # Verify the existing field is preserved and trace context is added @@ -77,7 +80,7 @@ def setUp(self): def test_instrument_without_tracer_provider_kwargs(self): """Test _instrument method when no tracer_provider in kwargs - should set to None""" - # Execute - call _instrument without tracer_provider in kwargs + # Execute - Actually test the mcpinstrumentor method self.instrumentor._instrument() # Verify - tracer_provider should be None @@ -89,7 +92,7 @@ def test_instrument_with_tracer_provider_kwargs(self): # Setup provider = SimpleTracerProvider() - # Execute - call _instrument with tracer_provider in kwargs + # Execute - Actually test the mcpinstrumentor method self.instrumentor._instrument(tracer_provider=provider) # Verify - tracer_provider should be set to the provided value @@ -97,269 +100,33 @@ def test_instrument_with_tracer_provider_kwargs(self): self.assertEqual(self.instrumentor.tracer_provider, provider) -class TestAppSignalToolCallRequest(unittest.TestCase): - """Test with realistic AppSignal MCP server tool call requests""" +class TestInstrumentationDependencies(unittest.TestCase): + """Test the instrumentation_dependencies method""" def setUp(self): self.instrumentor = MCPInstrumentor() - self.instrumentor.tracer_provider = None - - def test_appsignal_call_tool_request(self): - """Test with a realistic AppSignal MCP server CallToolRequest""" - - # Create a realistic CallToolRequest for AppSignal MCP server - class CallToolRequest: - def __init__(self, tool_name, arguments): - self.root = self - self.params = CallToolParams(tool_name, arguments) - - def model_dump(self, **kwargs): - return {"method": "call_tool", "params": {"name": self.params.name, "arguments": self.params.arguments}} - - class CallToolParams: - def __init__(self, name, arguments): - self.name = name - self.arguments = arguments - # Create an AppSignal tool call request - appsignal_request = CallToolRequest( - tool_name="create_metric", - arguments={ - "metric_name": "response_time", - "value": 250, - "tags": {"endpoint": "/api/users", "method": "GET"}, - }, - ) - - # Verify the tool call request structure matches AppSignal expectations - request_data = appsignal_request.model_dump() - self.assertEqual(request_data["method"], "call_tool") - self.assertEqual(request_data["params"]["name"], "create_metric") - self.assertEqual(request_data["params"]["arguments"]["metric_name"], "response_time") - self.assertEqual(request_data["params"]["arguments"]["value"], 250) - self.assertIn("endpoint", request_data["params"]["arguments"]["tags"]) - self.assertEqual(request_data["params"]["arguments"]["tags"]["endpoint"], "/api/users") - - # Test expected AppSignal response structure - expected_appsignal_result = { - "success": True, - "metric_id": "metric_12345", - "message": "Metric 'response_time' created successfully", - "metadata": { - "timestamp": "2025-01-22T21:19:00Z", - "tags_applied": ["endpoint:/api/users", "method:GET"], - "value_recorded": 250, - }, - } - - self.assertTrue(expected_appsignal_result["success"]) - self.assertIn("metric_id", expected_appsignal_result) - self.assertEqual(expected_appsignal_result["metadata"]["value_recorded"], 250) - - def test_appsignal_tool_without_arguments(self): - """Test AppSignal tool call that doesn't require arguments""" - - # Create a realistic CallToolRequest for tools without arguments - class CallToolRequestNoArgs: - def __init__(self, tool_name): - self.root = self - self.params = CallToolParamsNoArgs(tool_name) - - def model_dump(self, **kwargs): - return { - "method": "call_tool", - "params": { - "name": self.params.name - # No arguments field for this tool - }, - } - - @classmethod - def model_validate(cls, data): - return cls(data["params"]["name"]) - - class CallToolParamsNoArgs: - def __init__(self, name): - self.name = name - # No arguments attribute for this tool - - # Create an AppSignal tool call without arguments - list_apps_request = CallToolRequestNoArgs(tool_name="list_applications") - - # Test argument detection - args_with_request = (list_apps_request,) - kwargs_empty = {} - extracted_request = args_with_request[0] if len(args_with_request) > 0 else kwargs_empty.get("request") - - self.assertEqual(extracted_request, list_apps_request) - self.assertEqual(extracted_request.params.name, "list_applications") - - # Verify the request structure for tools without arguments - request_data = list_apps_request.model_dump() - self.assertEqual(request_data["method"], "call_tool") - self.assertEqual(request_data["params"]["name"], "list_applications") - self.assertNotIn("arguments", request_data["params"]) # No arguments field - - # Test expected result for list_applications tool - expected_list_apps_result = { - "success": True, - "applications": [ - {"id": "app_001", "name": "web-frontend", "environment": "production"}, - {"id": "app_002", "name": "api-backend", "environment": "staging"}, - ], - "total_count": 2, - } - - # Verify expected response structure - self.assertTrue(expected_list_apps_result["success"]) - self.assertIn("applications", expected_list_apps_result) - self.assertEqual(expected_list_apps_result["total_count"], 2) - self.assertEqual(len(expected_list_apps_result["applications"]), 2) - self.assertEqual(expected_list_apps_result["applications"][0]["name"], "web-frontend") - - def test_send_request_wrapper_argument_reconstruction(self): - """Test the argument logic: if len(args) > 0 vs else path""" - - # Create a realistic AppSignal request - class CallToolRequest: - def __init__(self, tool_name, arguments=None): - self.root = self - self.params = CallToolParams(tool_name, arguments) + def test_instrumentation_dependencies(self): + """Test that instrumentation_dependencies method returns the expected dependencies""" + # Execute - Actually test the mcpinstrumentor method + dependencies = self.instrumentor.instrumentation_dependencies() - def model_dump(self, by_alias=True, mode="json", exclude_none=True): - result = {"method": "call_tool", "params": {"name": self.params.name}} - if self.params.arguments: - result["params"]["arguments"] = self.params.arguments - return result + # Verify - should return the _instruments collection + self.assertIsNotNone(dependencies) + # The dependencies come from openinference.instrumentation.mcp.package._instruments + # which should be a collection - @classmethod - def model_validate(cls, data): - return cls(data["params"]["name"], data["params"].get("arguments")) - class CallToolParams: - def __init__(self, name, arguments=None): - self.name = name - self.arguments = arguments +class TestTraceContextInjection(unittest.TestCase): + """Test trace context injection using actual mcpinstrumentor methods""" - request = CallToolRequest("create_metric", {"metric_name": "test", "value": 100}) - - # Test 1: len(args) > 0 path - should trigger new_args = (modified_request,) + args[1:] - args_with_request = (request, "extra_arg1", "extra_arg2") - kwargs_test = {"extra_kwarg": "test"} - - # Simulate what the wrapper logic does - if len(args_with_request) > 0: - # This tests: new_args = (modified_request,) + args[1:] - new_args = ("modified_request_placeholder",) + args_with_request[1:] - result_args = new_args - result_kwargs = kwargs_test - else: - # This shouldn't happen in this test - result_args = args_with_request - result_kwargs = kwargs_test.copy() - result_kwargs["request"] = "modified_request_placeholder" - - # Verify args path reconstruction - self.assertEqual(len(result_args), 3) # modified_request + 2 extra args - self.assertEqual(result_args[0], "modified_request_placeholder") - self.assertEqual(result_args[1], "extra_arg1") # args[1:] preserved - self.assertEqual(result_args[2], "extra_arg2") # args[1:] preserved - self.assertEqual(result_kwargs["extra_kwarg"], "test") - self.assertNotIn("request", result_kwargs) # Should NOT modify kwargs in args path - - # Test 2: len(args) == 0 path - should trigger kwargs["request"] = modified_request - args_empty = () - kwargs_with_request = {"request": request, "other_param": "value"} - - # Simulate what the wrapper logic does - if len(args_empty) > 0: - # This shouldn't happen in this test - new_args = ("modified_request_placeholder",) + args_empty[1:] - result_args = new_args - result_kwargs = kwargs_with_request - else: - # This tests: kwargs["request"] = modified_request - result_args = args_empty - result_kwargs = kwargs_with_request.copy() - result_kwargs["request"] = "modified_request_placeholder" - - # Verify kwargs path reconstruction - self.assertEqual(len(result_args), 0) # No positional args - self.assertEqual(result_kwargs["request"], "modified_request_placeholder") - self.assertEqual(result_kwargs["other_param"], "value") # Other kwargs preserved - - def test_server_handle_request_wrapper_logic(self): - """Test the _server_handle_request_wrapper""" - - # Create realistic server request structures - class ServerRequest: - def __init__(self, has_trace_context=False): - if has_trace_context: - self.params = ServerRequestParams(has_meta=True) - else: - self.params = ServerRequestParams(has_meta=False) + def setUp(self): + self.instrumentor = MCPInstrumentor() - class ServerRequestParams: - def __init__(self, has_meta=False): - if has_meta: - self.meta = ServerRequestMeta() - else: - self.meta = None + def test_trace_context_injection_with_realistic_request(self): + """Test actual trace context injection using mcpinstrumentor._inject_trace_context with realistic MCP request""" - class ServerRequestMeta: - def __init__(self): - self.trace_context = {"trace_id": 12345, "span_id": 67890} - - # Test 1: Request WITHOUT trace context - should take else path - request_no_trace = ServerRequest(has_trace_context=False) - - # wrapper's request extraction logic - args_with_request = (None, request_no_trace) # args[1] is the request - req = args_with_request[1] if len(args_with_request) > 1 else None - - # Check trace context extraction logic - trace_context = None - if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: - trace_context = req.params.meta.trace_context - - # Verify - should NOT find trace context - self.assertIsNotNone(req) - self.assertIsNotNone(req.params) - self.assertIsNone(req.params.meta) # No meta field - self.assertIsNone(trace_context) - - # Test 2: Request WITH trace context - should take if path - request_with_trace = ServerRequest(has_trace_context=True) - - # Simulate the wrapper's request extraction logic - args_with_trace = (None, request_with_trace) # args[1] is the request - req2 = args_with_trace[1] if len(args_with_trace) > 1 else None - - # Check trace context extraction logic - trace_context2 = None - if req2 and hasattr(req2, "params") and req2.params and hasattr(req2.params, "meta") and req2.params.meta: - trace_context2 = req2.params.meta.trace_context - - # Verify - should find trace context - self.assertIsNotNone(req2) - self.assertIsNotNone(req2.params) - self.assertIsNotNone(req2.params.meta) # Has meta field - self.assertIsNotNone(trace_context2) - self.assertEqual(trace_context2["trace_id"], 12345) - self.assertEqual(trace_context2["span_id"], 67890) - - # Test 3: No request at all (args[1] doesn't exist) - args_no_request = (None,) # Only one arg, no request - req3 = args_no_request[1] if len(args_no_request) > 1 else None - - # Verify - should handle missing request gracefully - self.assertIsNone(req3) - - def test_end_to_end_trace_context_propagation(self): - """Test client sending trace context and server receiving the same trace context""" - # STEP 1: CLIENT SIDE - Create and prepare request with trace context - - # Create a realistic AppSignal request (what client would send) + # Create a realistic MCP request structure class CallToolRequest: def __init__(self, tool_name, arguments=None): self.root = self @@ -391,69 +158,256 @@ def __init__(self, name, arguments=None): # Client creates original request client_request = CallToolRequest("create_metric", {"metric_name": "response_time", "value": 250}) - # Client injects trace context (what _send_request_wrapper does) + # Client injects trace context using ACTUAL mcpinstrumentor method original_trace_context = SimpleSpanContext(trace_id=98765, span_id=43210) - - # Get request data and inject trace context request_data = client_request.model_dump() + + # This is the actual mcpinstrumentor method we're testing self.instrumentor._inject_trace_context(request_data, original_trace_context) - # Create modified request with trace context (what client sends over network) + # Create modified request with trace context modified_request = CallToolRequest.model_validate(request_data) - # Verify client successfully injected trace context + # Verify the actual mcpinstrumentor method worked correctly client_data = modified_request.model_dump() self.assertIn("_meta", client_data["params"]) self.assertIn("trace_context", client_data["params"]["_meta"]) self.assertEqual(client_data["params"]["_meta"]["trace_context"]["trace_id"], 98765) self.assertEqual(client_data["params"]["_meta"]["trace_context"]["span_id"], 43210) - # STEP 2: SERVER SIDE - Receive and extract trace context + # Verify the tool call data is also preserved + self.assertEqual(client_data["params"]["name"], "create_metric") + self.assertEqual(client_data["params"]["arguments"]["metric_name"], "response_time") - # Create server request structure (what server receives) - class ServerRequest: - def __init__(self, client_request_data): - self.params = ServerRequestParams(client_request_data["params"]) - class ServerRequestParams: - def __init__(self, params_data): - self.name = params_data["name"] - if "arguments" in params_data: - self.arguments = params_data["arguments"] - # Extract meta field (trace context) - if "_meta" in params_data: - self.meta = ServerRequestMeta(params_data["_meta"]) - else: - self.meta = None +class TestInstrumentedMCPServer(unittest.TestCase): + """Test mcpinstrumentor with a mock MCP server to verify end-to-end functionality""" - class ServerRequestMeta: - def __init__(self, meta_data): - self.trace_context = meta_data["trace_context"] + def setUp(self): + self.instrumentor = MCPInstrumentor() + self.instrumentor.tracer_provider = None - # Server receives the request (simulating network transmission) - server_request = ServerRequest(client_data) + def test_no_trace_context_fallback(self): + """Test graceful handling when no trace context is present on server side""" - # Server extracts trace context (what _server_handle_request_wrapper does) - args_with_request = (None, server_request) # args[1] is the request - req = args_with_request[1] if len(args_with_request) > 1 else None + class MockServerNoTrace: + async def _handle_request(self, session, request): + return {"success": True, "handled_without_trace": True} - # Extract trace context using server logic - extracted_trace_context = None - if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: - extracted_trace_context = req.params.meta.trace_context + class MockServerRequestNoTrace: + def __init__(self, tool_name): + self.params = MockServerRequestParamsNoTrace(tool_name) - # Verify server successfully received the trace context - self.assertIsNotNone(extracted_trace_context) - self.assertEqual(extracted_trace_context["trace_id"], 98765) - self.assertEqual(extracted_trace_context["span_id"], 43210) + class MockServerRequestParamsNoTrace: + def __init__(self, name): + self.name = name + self.meta = None # No trace context + + mock_server = MockServerNoTrace() + server_request = MockServerRequestNoTrace("create_metric") + + # Setup mocks + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + + # Test server handling without trace context (fallback scenario) + with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( + "sys.modules", {"mcp.types": MagicMock()} + ), unittest.mock.patch.object(self.instrumentor, "handle_attributes"), unittest.mock.patch.object( + self.instrumentor, "_get_span_name", return_value="tools/create_metric" + ): + + result = asyncio.run( + self.instrumentor._wrap_handle_request(mock_server._handle_request, None, (None, server_request), {}) + ) + + # Verify graceful fallback - no tracing spans should be created when no trace context + # The wrapper should call the original function without creating distributed trace spans + self.assertEqual(result["success"], True) + self.assertEqual(result["handled_without_trace"], True) + + # Should not create traced spans when no trace context is present + mock_tracer.start_as_current_span.assert_not_called() + + def test_end_to_end_client_server_communication(self): + """Test where server actually receives what client sends (including injected trace context)""" + + # Create realistic request/response classes + class MCPRequest: + def __init__(self, tool_name, arguments=None, method="call_tool"): + self.root = self + self.params = MCPRequestParams(tool_name, arguments) + self.method = method - # Verify it's the SAME trace context that client sent - self.assertEqual(extracted_trace_context["trace_id"], original_trace_context.trace_id) - self.assertEqual(extracted_trace_context["span_id"], original_trace_context.span_id) + def model_dump(self, by_alias=True, mode="json", exclude_none=True): + result = {"method": self.method, "params": {"name": self.params.name}} + if self.params.arguments: + result["params"]["arguments"] = self.params.arguments + # Include _meta if it exists (for trace context) + if hasattr(self.params, "_meta") and self.params._meta: + result["params"]["_meta"] = self.params._meta + return result - # Verify the tool call data is also preserved + @classmethod + def model_validate(cls, data): + method = data.get("method", "call_tool") + instance = cls(data["params"]["name"], data["params"].get("arguments"), method) + # Restore _meta field if present + if "_meta" in data["params"]: + instance.params._meta = data["params"]["_meta"] + return instance + + class MCPRequestParams: + def __init__(self, name, arguments=None): + self.name = name + self.arguments = arguments + self._meta = None + + class MCPServerRequest: + def __init__(self, client_request_data): + """Server request created from client's serialized data""" + self.method = client_request_data.get("method", "call_tool") + self.params = MCPServerRequestParams(client_request_data["params"]) + + class MCPServerRequestParams: + def __init__(self, params_data): + self.name = params_data["name"] + self.arguments = params_data.get("arguments") + # Extract trace context from _meta if present + if "_meta" in params_data and "trace_context" in params_data["_meta"]: + self.meta = MCPServerRequestMeta(params_data["_meta"]["trace_context"]) + else: + self.meta = None + + class MCPServerRequestMeta: + def __init__(self, trace_context): + self.trace_context = trace_context + + # Mock client and server that actually communicate + class EndToEndMCPSystem: + def __init__(self): + self.communication_log = [] + self.last_sent_request = None + + async def client_send_request(self, request): + """Client sends request - captures what gets sent""" + self.communication_log.append("CLIENT: Preparing to send request") + self.last_sent_request = request # Capture the modified request + + # Simulate sending over network - serialize the request + serialized_request = request.model_dump() + self.communication_log.append(f"CLIENT: Sent {serialized_request}") + + # Return client response + return {"success": True, "client_response": "Request sent successfully"} + + async def server_handle_request(self, session, server_request): + """Server handles the request it received""" + self.communication_log.append(f"SERVER: Received request for {server_request.params.name}") + + # Check if trace context was received + if server_request.params.meta and server_request.params.meta.trace_context: + trace_info = server_request.params.meta.trace_context + self.communication_log.append( + f"SERVER: Found trace context - trace_id: {trace_info['trace_id']}, " + f"span_id: {trace_info['span_id']}" + ) + else: + self.communication_log.append("SERVER: No trace context found") + + return {"success": True, "server_response": f"Handled {server_request.params.name}"} + + # Create the end-to-end system + e2e_system = EndToEndMCPSystem() + + # Create original client request + original_request = MCPRequest("create_metric", {"name": "cpu_usage", "value": 85}) + + # Setup OpenTelemetry mocks + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_span_context = MagicMock() + mock_span_context.trace_id = 12345 + mock_span_context.span_id = 67890 + mock_span.get_span_context.return_value = mock_span_context + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + + # STEP 1: Client sends request through instrumentation + with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( + "sys.modules", {"mcp.types": MagicMock()} + ), unittest.mock.patch.object(self.instrumentor, "handle_attributes"): + + client_result = asyncio.run( + self.instrumentor._wrap_send_request(e2e_system.client_send_request, None, (original_request,), {}) + ) + + # Verify client side worked + self.assertEqual(client_result["success"], True) + self.assertIn("CLIENT: Preparing to send request", e2e_system.communication_log) + + # Get the request that was actually sent (with trace context injected) + sent_request = e2e_system.last_sent_request + sent_request_data = sent_request.model_dump() + + # Verify trace context was injected by client instrumentation + self.assertIn("_meta", sent_request_data["params"]) + self.assertIn("trace_context", sent_request_data["params"]["_meta"]) + self.assertEqual(sent_request_data["params"]["_meta"]["trace_context"]["trace_id"], 12345) + self.assertEqual(sent_request_data["params"]["_meta"]["trace_context"]["span_id"], 67890) + + # STEP 2: Server receives the EXACT request that client sent + # Create server request from the client's serialized data + server_request = MCPServerRequest(sent_request_data) + + # Reset tracer mock for server side + mock_tracer.reset_mock() + + # Server processes the request it received + with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( + "sys.modules", {"mcp.types": MagicMock()} + ), unittest.mock.patch.object(self.instrumentor, "handle_attributes"), unittest.mock.patch.object( + self.instrumentor, "_get_span_name", return_value="tools/create_metric" + ): + + server_result = asyncio.run( + self.instrumentor._wrap_handle_request( + e2e_system.server_handle_request, None, (None, server_request), {} + ) + ) + + # Verify server side worked + self.assertEqual(server_result["success"], True) + + # Verify end-to-end trace context propagation + self.assertIn("SERVER: Found trace context - trace_id: 12345, span_id: 67890", e2e_system.communication_log) + + # Verify the server received the exact same data the client sent self.assertEqual(server_request.params.name, "create_metric") - self.assertEqual(server_request.params.arguments["metric_name"], "response_time") + self.assertEqual(server_request.params.arguments["name"], "cpu_usage") + self.assertEqual(server_request.params.arguments["value"], 85) + + # Verify the trace context made it through end-to-end + self.assertIsNotNone(server_request.params.meta) + self.assertIsNotNone(server_request.params.meta.trace_context) + self.assertEqual(server_request.params.meta.trace_context["trace_id"], 12345) + self.assertEqual(server_request.params.meta.trace_context["span_id"], 67890) + + # Verify complete communication flow + expected_log_entries = [ + "CLIENT: Preparing to send request", + "CLIENT: Sent", # Part of the serialized request log + "SERVER: Received request for create_metric", + "SERVER: Found trace context - trace_id: 12345, span_id: 67890", + ] + + for expected_entry in expected_log_entries: + self.assertTrue( + any(expected_entry in log_entry for log_entry in e2e_system.communication_log), + f"Expected log entry '{expected_entry}' not found in: {e2e_system.communication_log}", + ) if __name__ == "__main__": From 442a45a02b46e65bc46e840a746aa29efd334abf Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Thu, 24 Jul 2025 15:15:36 -0700 Subject: [PATCH 04/41] Code Cleanup and To Test Dependencies --- .../distro/test_mcpinstrumentor.py | 127 ++++++++++++------ 1 file changed, 85 insertions(+), 42 deletions(-) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index fc4921c36..a17a73d0a 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -50,9 +50,19 @@ def test_inject_trace_context_empty_dict(self): # Execute - Actually test the mcpinstrumentor method self.instrumentor._inject_trace_context(request_data, span_ctx) - # Verify - expected = {"params": {"_meta": {"trace_context": {"trace_id": 12345, "span_id": 67890}}}} - self.assertEqual(request_data, expected) + # Verify - now uses traceparent W3C format + self.assertIn("params", request_data) + self.assertIn("_meta", request_data["params"]) + self.assertIn("traceparent", request_data["params"]["_meta"]) + + # Verify traceparent format: "00-{trace_id:032x}-{span_id:016x}-01" + traceparent = request_data["params"]["_meta"]["traceparent"] + self.assertTrue(traceparent.startswith("00-")) + self.assertTrue(traceparent.endswith("-01")) + parts = traceparent.split("-") + self.assertEqual(len(parts), 4) + self.assertEqual(int(parts[1], 16), 12345) # trace_id + self.assertEqual(int(parts[2], 16), 67890) # span_id def test_inject_trace_context_existing_params(self): """Test injecting trace context when params already exist""" @@ -63,10 +73,16 @@ def test_inject_trace_context_existing_params(self): # Execute - Actually test the mcpinstrumentor method self.instrumentor._inject_trace_context(request_data, span_ctx) - # Verify the existing field is preserved and trace context is added + # Verify the existing field is preserved and traceparent is added self.assertEqual(request_data["params"]["existing_field"], "test_value") - self.assertEqual(request_data["params"]["_meta"]["trace_context"]["trace_id"], 99999) - self.assertEqual(request_data["params"]["_meta"]["trace_context"]["span_id"], 11111) + self.assertIn("_meta", request_data["params"]) + self.assertIn("traceparent", request_data["params"]["_meta"]) + + # Verify traceparent format contains correct trace/span IDs + traceparent = request_data["params"]["_meta"]["traceparent"] + parts = traceparent.split("-") + self.assertEqual(int(parts[1], 16), 99999) # trace_id + self.assertEqual(int(parts[2], 16), 11111) # span_id class TestTracerProvider(unittest.TestCase): @@ -74,30 +90,35 @@ class TestTracerProvider(unittest.TestCase): def setUp(self): self.instrumentor = MCPInstrumentor() - # Reset tracer_provider to ensure test isolation - if hasattr(self.instrumentor, "tracer_provider"): - delattr(self.instrumentor, "tracer_provider") + # Reset tracer to ensure test isolation + if hasattr(self.instrumentor, "tracer"): + delattr(self.instrumentor, "tracer") def test_instrument_without_tracer_provider_kwargs(self): - """Test _instrument method when no tracer_provider in kwargs - should set to None""" + """Test _instrument method when no tracer_provider in kwargs - should use default tracer""" # Execute - Actually test the mcpinstrumentor method - self.instrumentor._instrument() + with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer: + mock_get_tracer.return_value = "default_tracer" + self.instrumentor._instrument() - # Verify - tracer_provider should be None - self.assertTrue(hasattr(self.instrumentor, "tracer_provider")) - self.assertIsNone(self.instrumentor.tracer_provider) + # Verify - tracer should be set from trace.get_tracer + self.assertTrue(hasattr(self.instrumentor, "tracer")) + self.assertEqual(self.instrumentor.tracer, "default_tracer") + mock_get_tracer.assert_called_with("mcp") def test_instrument_with_tracer_provider_kwargs(self): - """Test _instrument method when tracer_provider is in kwargs - should set to that value""" + """Test _instrument method when tracer_provider is in kwargs - should use provider's tracer""" # Setup provider = SimpleTracerProvider() # Execute - Actually test the mcpinstrumentor method self.instrumentor._instrument(tracer_provider=provider) - # Verify - tracer_provider should be set to the provided value - self.assertTrue(hasattr(self.instrumentor, "tracer_provider")) - self.assertEqual(self.instrumentor.tracer_provider, provider) + # Verify - tracer should be set from the provided tracer_provider + self.assertTrue(hasattr(self.instrumentor, "tracer")) + self.assertEqual(self.instrumentor.tracer, "mock_tracer_from_provider") + self.assertTrue(provider.get_tracer_called) + self.assertEqual(provider.tracer_name, "mcp") class TestInstrumentationDependencies(unittest.TestCase): @@ -171,9 +192,13 @@ def __init__(self, name, arguments=None): # Verify the actual mcpinstrumentor method worked correctly client_data = modified_request.model_dump() self.assertIn("_meta", client_data["params"]) - self.assertIn("trace_context", client_data["params"]["_meta"]) - self.assertEqual(client_data["params"]["_meta"]["trace_context"]["trace_id"], 98765) - self.assertEqual(client_data["params"]["_meta"]["trace_context"]["span_id"], 43210) + self.assertIn("traceparent", client_data["params"]["_meta"]) + + # Verify traceparent format contains correct trace/span IDs + traceparent = client_data["params"]["_meta"]["traceparent"] + parts = traceparent.split("-") + self.assertEqual(int(parts[1], 16), 98765) # trace_id + self.assertEqual(int(parts[2], 16), 43210) # span_id # Verify the tool call data is also preserved self.assertEqual(client_data["params"]["name"], "create_metric") @@ -185,7 +210,9 @@ class TestInstrumentedMCPServer(unittest.TestCase): def setUp(self): self.instrumentor = MCPInstrumentor() - self.instrumentor.tracer_provider = None + # Initialize tracer so the instrumentor can work + mock_tracer = MagicMock() + self.instrumentor.tracer = mock_tracer def test_no_trace_context_fallback(self): """Test graceful handling when no trace context is present on server side""" @@ -275,15 +302,15 @@ class MCPServerRequestParams: def __init__(self, params_data): self.name = params_data["name"] self.arguments = params_data.get("arguments") - # Extract trace context from _meta if present - if "_meta" in params_data and "trace_context" in params_data["_meta"]: - self.meta = MCPServerRequestMeta(params_data["_meta"]["trace_context"]) + # Extract traceparent from _meta if present + if "_meta" in params_data and "traceparent" in params_data["_meta"]: + self.meta = MCPServerRequestMeta(params_data["_meta"]["traceparent"]) else: self.meta = None class MCPServerRequestMeta: - def __init__(self, trace_context): - self.trace_context = trace_context + def __init__(self, traceparent): + self.traceparent = traceparent # Mock client and server that actually communicate class EndToEndMCPSystem: @@ -307,13 +334,19 @@ async def server_handle_request(self, session, server_request): """Server handles the request it received""" self.communication_log.append(f"SERVER: Received request for {server_request.params.name}") - # Check if trace context was received - if server_request.params.meta and server_request.params.meta.trace_context: - trace_info = server_request.params.meta.trace_context - self.communication_log.append( - f"SERVER: Found trace context - trace_id: {trace_info['trace_id']}, " - f"span_id: {trace_info['span_id']}" - ) + # Check if traceparent was received + if server_request.params.meta and server_request.params.meta.traceparent: + traceparent = server_request.params.meta.traceparent + # Parse traceparent to extract trace_id and span_id + parts = traceparent.split("-") + if len(parts) == 4: + trace_id = int(parts[1], 16) + span_id = int(parts[2], 16) + self.communication_log.append( + f"SERVER: Found trace context - trace_id: {trace_id}, " f"span_id: {span_id}" + ) + else: + self.communication_log.append("SERVER: Invalid traceparent format") else: self.communication_log.append("SERVER: No trace context found") @@ -339,6 +372,8 @@ async def server_handle_request(self, session, server_request): with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( "sys.modules", {"mcp.types": MagicMock()} ), unittest.mock.patch.object(self.instrumentor, "handle_attributes"): + # Override the setup tracer with the properly mocked one + self.instrumentor.tracer = mock_tracer client_result = asyncio.run( self.instrumentor._wrap_send_request(e2e_system.client_send_request, None, (original_request,), {}) @@ -352,11 +387,15 @@ async def server_handle_request(self, session, server_request): sent_request = e2e_system.last_sent_request sent_request_data = sent_request.model_dump() - # Verify trace context was injected by client instrumentation + # Verify traceparent was injected by client instrumentation self.assertIn("_meta", sent_request_data["params"]) - self.assertIn("trace_context", sent_request_data["params"]["_meta"]) - self.assertEqual(sent_request_data["params"]["_meta"]["trace_context"]["trace_id"], 12345) - self.assertEqual(sent_request_data["params"]["_meta"]["trace_context"]["span_id"], 67890) + self.assertIn("traceparent", sent_request_data["params"]["_meta"]) + + # Parse and verify traceparent contains correct trace/span IDs + traceparent = sent_request_data["params"]["_meta"]["traceparent"] + parts = traceparent.split("-") + self.assertEqual(int(parts[1], 16), 12345) # trace_id + self.assertEqual(int(parts[2], 16), 67890) # span_id # STEP 2: Server receives the EXACT request that client sent # Create server request from the client's serialized data @@ -389,11 +428,15 @@ async def server_handle_request(self, session, server_request): self.assertEqual(server_request.params.arguments["name"], "cpu_usage") self.assertEqual(server_request.params.arguments["value"], 85) - # Verify the trace context made it through end-to-end + # Verify the traceparent made it through end-to-end self.assertIsNotNone(server_request.params.meta) - self.assertIsNotNone(server_request.params.meta.trace_context) - self.assertEqual(server_request.params.meta.trace_context["trace_id"], 12345) - self.assertEqual(server_request.params.meta.trace_context["span_id"], 67890) + self.assertIsNotNone(server_request.params.meta.traceparent) + + # Parse traceparent and verify trace/span IDs + traceparent = server_request.params.meta.traceparent + parts = traceparent.split("-") + self.assertEqual(int(parts[1], 16), 12345) # trace_id + self.assertEqual(int(parts[2], 16), 67890) # span_id # Verify complete communication flow expected_log_entries = [ From 10d91ac320dc7fb204c2fe1d9dbcd0e8d5736720 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Thu, 24 Jul 2025 15:21:18 -0700 Subject: [PATCH 05/41] Fixed mcp supported python --- aws-opentelemetry-distro/loggertwo.log | 0 .../distro/mcpinstrumentor/mcpinstrumentor.py | 216 +++++++++--------- .../distro/mcpinstrumentor/package.py | 1 + .../distro/mcpinstrumentor/pyproject.toml | 6 +- .../distro/mcpinstrumentor/version.py | 1 + 5 files changed, 115 insertions(+), 109 deletions(-) delete mode 100644 aws-opentelemetry-distro/loggertwo.log create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/version.py diff --git a/aws-opentelemetry-distro/loggertwo.log b/aws-opentelemetry-distro/loggertwo.log deleted file mode 100644 index e69de29bb..000000000 diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index 8df12cb26..4c77b95ec 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -1,13 +1,14 @@ import logging from typing import Any, Collection -from openinference.instrumentation.mcp.package import _instruments from wrapt import register_post_import_hook, wrap_function_wrapper from opentelemetry import trace from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap +from .package import _instruments + def setup_logger_two(): logger = logging.getLogger("loggertwo") @@ -21,80 +22,11 @@ def setup_logger_two(): return logger -logger_two = setup_logger_two() - - class MCPInstrumentor(BaseInstrumentor): """ An instrumenter for MCP. """ - def instrumentation_dependencies(self) -> Collection[str]: - return _instruments - - def _instrument(self, **kwargs: Any) -> None: - tracer_provider = kwargs.get("tracer_provider") - if tracer_provider: - self.tracer_provider = tracer_provider - else: - self.tracer_provider = None - register_post_import_hook( - lambda _: wrap_function_wrapper( - "mcp.shared.session", - "BaseSession.send_request", - self._wrap_send_request, - ), - "mcp.shared.session", - ) - register_post_import_hook( - lambda _: wrap_function_wrapper( - "mcp.server.lowlevel.server", - "Server._handle_request", - self._wrap_handle_request, - ), - "mcp.server.lowlevel.server", - ) - - def _uninstrument(self, **kwargs: Any) -> None: - unwrap("mcp.shared.session", "BaseSession.send_request") - unwrap("mcp.server.lowlevel.server", "Server._handle_request") - - def handle_attributes(self, span, request, is_client=True): - import mcp.types as types - - operation = "Server Handle Request" - if isinstance(request, types.ListToolsRequest): - operation = "ListTool" - span.set_attribute("mcp.list_tools", True) - elif isinstance(request, types.CallToolRequest): - if hasattr(request, "params") and hasattr(request.params, "name"): - operation = request.params.name - span.set_attribute("mcp.call_tool", True) - if is_client: - self._add_client_attributes(span, operation, request) - else: - self._add_server_attributes(span, operation, request) - - def _add_client_attributes(self, span, operation, request): - span.set_attribute("span.kind", "CLIENT") - span.set_attribute("aws.remote.service", "Appsignals MCP Server") - span.set_attribute("aws.remote.operation", operation) - if hasattr(request, "params") and hasattr(request.params, "name"): - span.set_attribute("tool.name", request.params.name) - - def _add_server_attributes(self, span, operation, request): - span.set_attribute("server_side", True) - span.set_attribute("aws.span.kind", "SERVER") - if hasattr(request, "params") and hasattr(request.params, "name"): - span.set_attribute("tool.name", request.params.name) - - def _inject_trace_context(self, request_data, span_ctx): - if "params" not in request_data: - request_data["params"] = {} - if "_meta" not in request_data["params"]: - request_data["params"]["_meta"] = {} - request_data["params"]["_meta"]["trace_context"] = {"trace_id": span_ctx.trace_id, "span_id": span_ctx.span_id} - # Send Request Wrapper def _wrap_send_request(self, wrapped, instance, args, kwargs): """ @@ -106,11 +38,7 @@ def _wrap_send_request(self, wrapped, instance, args, kwargs): """ async def async_wrapper(): - if self.tracer_provider is None: - tracer = trace.get_tracer("mcp.client") - else: - tracer = self.tracer_provider.get_tracer("mcp.client") - with tracer.start_as_current_span("client.send_request", kind=trace.SpanKind.CLIENT) as span: + with self.tracer.start_as_current_span("client.send_request", kind=trace.SpanKind.CLIENT) as span: span_ctx = span.get_span_context() request = args[0] if len(args) > 0 else kwargs.get("request") if request: @@ -133,19 +61,6 @@ async def async_wrapper(): return async_wrapper() - def _get_span_name(self, req): - span_name = "unknown" - import mcp.types as types - - if isinstance(req, types.ListToolsRequest): - span_name = "tools/list" - elif isinstance(req, types.CallToolRequest): - if hasattr(req, "params") and hasattr(req.params, "name"): - span_name = f"tools/{req.params.name}" - else: - span_name = "unknown" - return span_name - # Handle Request Wrapper async def _wrap_handle_request(self, wrapped, instance, args, kwargs): """ @@ -154,29 +69,20 @@ async def _wrap_handle_request(self, wrapped, instance, args, kwargs): the request's params._meta field, and creates server-side OpenTelemetry spans linked to the client spans. The wrapper also does not change the original function's behavior by calling it with identical parameters ensuring no breaking changes to the MCP server functionality. + + request (args[1]) is typically an instance of CallToolRequest or ListToolsRequest + and should have the structure: + request.params.meta.traceparent -> "00---01" """ req = args[1] if len(args) > 1 else None - trace_context = None + traceparent = None if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: - trace_context = req.params.meta.trace_context - if trace_context: - - if self.tracer_provider is None: - tracer = trace.get_tracer("mcp.server") - else: - tracer = self.tracer_provider.get_tracer("mcp.server") - trace_id = trace_context.get("trace_id") - span_id = trace_context.get("span_id") - span_context = trace.SpanContext( - trace_id=trace_id, - span_id=span_id, - is_remote=True, - trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED), - trace_state=trace.TraceState(), - ) + traceparent = getattr(req.params.meta, "traceparent", None) + span_context = self._extract_span_context_from_traceparent(traceparent) if traceparent else None + if span_context: span_name = self._get_span_name(req) - with tracer.start_as_current_span( + with self.tracer.start_as_current_span( span_name, kind=trace.SpanKind.SERVER, context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)), @@ -186,3 +92,101 @@ async def _wrap_handle_request(self, wrapped, instance, args, kwargs): return result else: return await wrapped(*args, **kwargs) + + def _inject_trace_context(self, request_data, span_ctx): + if "params" not in request_data: + request_data["params"] = {} + if "_meta" not in request_data["params"]: + request_data["params"]["_meta"] = {} + trace_id_hex = f"{span_ctx.trace_id:032x}" + span_id_hex = f"{span_ctx.span_id:016x}" + trace_flags = "01" + traceparent = f"00-{trace_id_hex}-{span_id_hex}-{trace_flags}" + request_data["params"]["_meta"]["traceparent"] = traceparent + + def _extract_span_context_from_traceparent(self, traceparent): + parts = traceparent.split("-") + if len(parts) == 4: + try: + trace_id = int(parts[1], 16) + span_id = int(parts[2], 16) + return trace.SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=True, + trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED), + trace_state=trace.TraceState(), + ) + except ValueError: + return None + return None + + def _get_span_name(self, req): + span_name = "unknown" + import mcp.types as types + + if isinstance(req, types.ListToolsRequest): + span_name = "tools/list" + elif isinstance(req, types.CallToolRequest): + if hasattr(req, "params") and hasattr(req.params, "name"): + span_name = f"tools/{req.params.name}" + else: + span_name = "unknown" + return span_name + + def handle_attributes(self, span, request, is_client=True): + import mcp.types as types + + operation = self._get_span_name(request) + if isinstance(request, types.ListToolsRequest): + operation = "ListTool" + span.set_attribute("mcp.list_tools", True) + elif isinstance(request, types.CallToolRequest): + if hasattr(request, "params") and hasattr(request.params, "name"): + operation = request.params.name + span.set_attribute("mcp.call_tool", True) + if is_client: + self._add_client_attributes(span, operation, request) + else: + self._add_server_attributes(span, operation, request) + + def _add_client_attributes(self, span, operation, request): + span.set_attribute("aws.remote.service", "Appsignals MCP Server") + span.set_attribute("aws.remote.operation", operation) + if hasattr(request, "params") and hasattr(request.params, "name"): + span.set_attribute("tool.name", request.params.name) + + def _add_server_attributes(self, span, operation, request): + span.set_attribute("server_side", True) + if hasattr(request, "params") and hasattr(request.params, "name"): + span.set_attribute("tool.name", request.params.name) + + def instrumentation_dependencies(self) -> Collection[str]: + return _instruments + + def _instrument(self, **kwargs: Any) -> None: + tracer_provider = kwargs.get("tracer_provider") + if tracer_provider: + self.tracer = tracer_provider.get_tracer("mcp") + else: + self.tracer = trace.get_tracer("mcp") + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.shared.session", + "BaseSession.send_request", + self._wrap_send_request, + ), + "mcp.shared.session", + ) + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.server.lowlevel.server", + "Server._handle_request", + self._wrap_handle_request, + ), + "mcp.server.lowlevel.server", + ) + + def _uninstrument(self, **kwargs: Any) -> None: + unwrap("mcp.shared.session", "BaseSession.send_request") + unwrap("mcp.server.lowlevel.server", "Server._handle_request") diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py new file mode 100644 index 000000000..e43aa8566 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py @@ -0,0 +1 @@ +_instruments = ("mcp >= 1.6.0",) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml index 0ca34dd39..4d237fa4c 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml @@ -8,7 +8,7 @@ version = "0.1.0" description = "OpenTelemetry MCP instrumentation for AWS Distro" readme = "README.md" license = "Apache-2.0" -requires-python = ">=3.9" +requires-python = ">=3.10" authors = [ { name = "Johnny Lin", email = "jonzilin@amazon.com" }, ] @@ -18,17 +18,17 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", ] dependencies = [ - "openinference-instrumentation-mcp", + "mcp", "opentelemetry-api", "opentelemetry-instrumentation", "opentelemetry-sdk", + "opentelemetry-semantic-conventions", "wrapt" ] diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/version.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/version.py new file mode 100644 index 000000000..3dc1f76bc --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/version.py @@ -0,0 +1 @@ +__version__ = "0.1.0" From 0e6e3391645b59718787ea194facd16427921b77 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Thu, 24 Jul 2025 16:00:58 -0700 Subject: [PATCH 06/41] Added copywrite headers --- .../opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py | 2 ++ .../src/amazon/opentelemetry/distro/mcpinstrumentor/package.py | 2 ++ .../src/amazon/opentelemetry/distro/mcpinstrumentor/version.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index 4c77b95ec..00468b259 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -1,3 +1,5 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 import logging from typing import Any, Collection diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py index e43aa8566..2694dbbb5 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py @@ -1 +1,3 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 _instruments = ("mcp >= 1.6.0",) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/version.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/version.py index 3dc1f76bc..4aab890bb 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/version.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/version.py @@ -1 +1,3 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 __version__ = "0.1.0" From 6a5f9d6f047817eb973933b4b67ac413c9d90729 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Thu, 24 Jul 2025 16:35:21 -0700 Subject: [PATCH 07/41] Testing lint checker --- .../distro/mcpinstrumentor/mcpinstrumentor.py | 16 ++++++++-------- .../opentelemetry/distro/test_mcpinstrumentor.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index 00468b259..7a120470c 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -24,7 +24,7 @@ def setup_logger_two(): return logger -class MCPInstrumentor(BaseInstrumentor): +class MCPInstrumentor(BaseInstrumentor): # pylint: disable=attribute-defined-outside-init """ An instrumenter for MCP. """ @@ -95,7 +95,7 @@ async def _wrap_handle_request(self, wrapped, instance, args, kwargs): else: return await wrapped(*args, **kwargs) - def _inject_trace_context(self, request_data, span_ctx): + def _inject_trace_context(self, request_data, span_ctx): # pylint: disable=no-self-use if "params" not in request_data: request_data["params"] = {} if "_meta" not in request_data["params"]: @@ -106,7 +106,7 @@ def _inject_trace_context(self, request_data, span_ctx): traceparent = f"00-{trace_id_hex}-{span_id_hex}-{trace_flags}" request_data["params"]["_meta"]["traceparent"] = traceparent - def _extract_span_context_from_traceparent(self, traceparent): + def _extract_span_context_from_traceparent(self, traceparent): # pylint: disable=no-self-use parts = traceparent.split("-") if len(parts) == 4: try: @@ -123,9 +123,9 @@ def _extract_span_context_from_traceparent(self, traceparent): return None return None - def _get_span_name(self, req): + def _get_span_name(self, req): # pylint: disable=no-self-use span_name = "unknown" - import mcp.types as types + import mcp.types as types # pylint: disable=import-outside-toplevel if isinstance(req, types.ListToolsRequest): span_name = "tools/list" @@ -137,7 +137,7 @@ def _get_span_name(self, req): return span_name def handle_attributes(self, span, request, is_client=True): - import mcp.types as types + import mcp.types as types # pylint: disable=import-outside-toplevel operation = self._get_span_name(request) if isinstance(request, types.ListToolsRequest): @@ -152,13 +152,13 @@ def handle_attributes(self, span, request, is_client=True): else: self._add_server_attributes(span, operation, request) - def _add_client_attributes(self, span, operation, request): + def _add_client_attributes(self, span, operation, request): # pylint: disable=no-self-use span.set_attribute("aws.remote.service", "Appsignals MCP Server") span.set_attribute("aws.remote.operation", operation) if hasattr(request, "params") and hasattr(request.params, "name"): span.set_attribute("tool.name", request.params.name) - def _add_server_attributes(self, span, operation, request): + def _add_server_attributes(self, span, operation, request): # pylint: disable=no-self-use span.set_attribute("server_side", True) if hasattr(request, "params") and hasattr(request.params, "name"): span.set_attribute("tool.name", request.params.name) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index a17a73d0a..dcd34b5aa 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -1,3 +1,6 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + """ Unit tests for MCPInstrumentor - testing actual mcpinstrumentor methods """ @@ -8,9 +11,12 @@ import unittest from unittest.mock import MagicMock +# Add src path for imports project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..")) src_path = os.path.join(project_root, "src") sys.path.insert(0, src_path) + +# pylint: disable=wrong-import-position from amazon.opentelemetry.distro.mcpinstrumentor.mcpinstrumentor import MCPInstrumentor # noqa: E402 @@ -214,7 +220,7 @@ def setUp(self): mock_tracer = MagicMock() self.instrumentor.tracer = mock_tracer - def test_no_trace_context_fallback(self): + def test_no_trace_context_fallback(self): # pylint: disable=no-self-use """Test graceful handling when no trace context is present on server side""" class MockServerNoTrace: @@ -258,7 +264,7 @@ def __init__(self, name): # Should not create traced spans when no trace context is present mock_tracer.start_as_current_span.assert_not_called() - def test_end_to_end_client_server_communication(self): + def test_end_to_end_client_server_communication(self): # pylint: disable=too-many-locals,too-many-statements """Test where server actually receives what client sends (including injected trace context)""" # Create realistic request/response classes From db1af61642fef7b05aa6889d4f481bfeb1af3edc Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Thu, 24 Jul 2025 16:42:18 -0700 Subject: [PATCH 08/41] Testing lint checker 2 --- .../distro/mcpinstrumentor/mcpinstrumentor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index 7a120470c..d213eaeae 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -24,11 +24,15 @@ def setup_logger_two(): return logger -class MCPInstrumentor(BaseInstrumentor): # pylint: disable=attribute-defined-outside-init +class MCPInstrumentor(BaseInstrumentor): """ An instrumenter for MCP. """ + def __init__(self): + super().__init__() + self.tracer = None + # Send Request Wrapper def _wrap_send_request(self, wrapped, instance, args, kwargs): """ @@ -125,7 +129,7 @@ def _extract_span_context_from_traceparent(self, traceparent): # pylint: disabl def _get_span_name(self, req): # pylint: disable=no-self-use span_name = "unknown" - import mcp.types as types # pylint: disable=import-outside-toplevel + import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import if isinstance(req, types.ListToolsRequest): span_name = "tools/list" @@ -137,7 +141,7 @@ def _get_span_name(self, req): # pylint: disable=no-self-use return span_name def handle_attributes(self, span, request, is_client=True): - import mcp.types as types # pylint: disable=import-outside-toplevel + import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import operation = self._get_span_name(request) if isinstance(request, types.ListToolsRequest): From a6ba4c451b2b31912c967df54e6a2f747026a54f Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Thu, 24 Jul 2025 17:00:02 -0700 Subject: [PATCH 09/41] Testing lint checker3 --- .../distro/mcpinstrumentor/mcpinstrumentor.py | 16 ++--- .../distro/test_mcpinstrumentor.py | 62 +++++++++---------- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index d213eaeae..46f65c897 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -29,12 +29,12 @@ class MCPInstrumentor(BaseInstrumentor): An instrumenter for MCP. """ - def __init__(self): + def __init__(self): # pylint: disable=no-self-use super().__init__() self.tracer = None # Send Request Wrapper - def _wrap_send_request(self, wrapped, instance, args, kwargs): + def _wrap_send_request(self, wrapped, instance, args, kwargs): # pylint: disable=no-self-use """ Changes made: The wrapper intercepts the request before sending, injects distributed tracing context into the @@ -43,7 +43,7 @@ def _wrap_send_request(self, wrapped, instance, args, kwargs): type and calling the original function with identical parameters. """ - async def async_wrapper(): + async def async_wrapper(): # pylint: disable=no-self-use with self.tracer.start_as_current_span("client.send_request", kind=trace.SpanKind.CLIENT) as span: span_ctx = span.get_span_context() request = args[0] if len(args) > 0 else kwargs.get("request") @@ -68,7 +68,7 @@ async def async_wrapper(): return async_wrapper() # Handle Request Wrapper - async def _wrap_handle_request(self, wrapped, instance, args, kwargs): + async def _wrap_handle_request(self, wrapped, instance, args, kwargs): # pylint: disable=no-self-use """ Changes made: This wrapper intercepts requests before processing, extracts distributed tracing context from @@ -140,7 +140,7 @@ def _get_span_name(self, req): # pylint: disable=no-self-use span_name = "unknown" return span_name - def handle_attributes(self, span, request, is_client=True): + def handle_attributes(self, span, request, is_client=True): # pylint: disable=no-self-use import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import operation = self._get_span_name(request) @@ -167,10 +167,10 @@ def _add_server_attributes(self, span, operation, request): # pylint: disable=n if hasattr(request, "params") and hasattr(request.params, "name"): span.set_attribute("tool.name", request.params.name) - def instrumentation_dependencies(self) -> Collection[str]: + def instrumentation_dependencies(self) -> Collection[str]: # pylint: disable=no-self-use return _instruments - def _instrument(self, **kwargs: Any) -> None: + def _instrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use tracer_provider = kwargs.get("tracer_provider") if tracer_provider: self.tracer = tracer_provider.get_tracer("mcp") @@ -193,6 +193,6 @@ def _instrument(self, **kwargs: Any) -> None: "mcp.server.lowlevel.server", ) - def _uninstrument(self, **kwargs: Any) -> None: + def _uninstrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use unwrap("mcp.shared.session", "BaseSession.send_request") unwrap("mcp.server.lowlevel.server", "Server._handle_request") diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index dcd34b5aa..774717620 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -23,7 +23,7 @@ class SimpleSpanContext: """Simple mock span context without using MagicMock""" - def __init__(self, trace_id, span_id): + def __init__(self, trace_id, span_id): # pylint: disable=no-self-use self.trace_id = trace_id self.span_id = span_id @@ -31,11 +31,11 @@ def __init__(self, trace_id, span_id): class SimpleTracerProvider: """Simple mock tracer provider without using MagicMock""" - def __init__(self): + def __init__(self): # pylint: disable=no-self-use self.get_tracer_called = False self.tracer_name = None - def get_tracer(self, name): + def get_tracer(self, name): # pylint: disable=no-self-use self.get_tracer_called = True self.tracer_name = name return "mock_tracer_from_provider" @@ -44,10 +44,10 @@ def get_tracer(self, name): class TestInjectTraceContext(unittest.TestCase): """Test the _inject_trace_context method""" - def setUp(self): + def setUp(self): # pylint: disable=no-self-use self.instrumentor = MCPInstrumentor() - def test_inject_trace_context_empty_dict(self): + def test_inject_trace_context_empty_dict(self): # pylint: disable=no-self-use """Test injecting trace context into empty dictionary""" # Setup request_data = {} @@ -70,7 +70,7 @@ def test_inject_trace_context_empty_dict(self): self.assertEqual(int(parts[1], 16), 12345) # trace_id self.assertEqual(int(parts[2], 16), 67890) # span_id - def test_inject_trace_context_existing_params(self): + def test_inject_trace_context_existing_params(self): # pylint: disable=no-self-use """Test injecting trace context when params already exist""" # Setup request_data = {"params": {"existing_field": "test_value"}} @@ -94,13 +94,13 @@ def test_inject_trace_context_existing_params(self): class TestTracerProvider(unittest.TestCase): """Test the tracer provider kwargs logic in _instrument method""" - def setUp(self): + def setUp(self): # pylint: disable=no-self-use self.instrumentor = MCPInstrumentor() # Reset tracer to ensure test isolation if hasattr(self.instrumentor, "tracer"): delattr(self.instrumentor, "tracer") - def test_instrument_without_tracer_provider_kwargs(self): + def test_instrument_without_tracer_provider_kwargs(self): # pylint: disable=no-self-use """Test _instrument method when no tracer_provider in kwargs - should use default tracer""" # Execute - Actually test the mcpinstrumentor method with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer: @@ -112,7 +112,7 @@ def test_instrument_without_tracer_provider_kwargs(self): self.assertEqual(self.instrumentor.tracer, "default_tracer") mock_get_tracer.assert_called_with("mcp") - def test_instrument_with_tracer_provider_kwargs(self): + def test_instrument_with_tracer_provider_kwargs(self): # pylint: disable=no-self-use """Test _instrument method when tracer_provider is in kwargs - should use provider's tracer""" # Setup provider = SimpleTracerProvider() @@ -130,10 +130,10 @@ def test_instrument_with_tracer_provider_kwargs(self): class TestInstrumentationDependencies(unittest.TestCase): """Test the instrumentation_dependencies method""" - def setUp(self): + def setUp(self): # pylint: disable=no-self-use self.instrumentor = MCPInstrumentor() - def test_instrumentation_dependencies(self): + def test_instrumentation_dependencies(self): # pylint: disable=no-self-use """Test that instrumentation_dependencies method returns the expected dependencies""" # Execute - Actually test the mcpinstrumentor method dependencies = self.instrumentor.instrumentation_dependencies() @@ -147,19 +147,19 @@ def test_instrumentation_dependencies(self): class TestTraceContextInjection(unittest.TestCase): """Test trace context injection using actual mcpinstrumentor methods""" - def setUp(self): + def setUp(self): # pylint: disable=no-self-use self.instrumentor = MCPInstrumentor() - def test_trace_context_injection_with_realistic_request(self): + def test_trace_context_injection_with_realistic_request(self): # pylint: disable=no-self-use """Test actual trace context injection using mcpinstrumentor._inject_trace_context with realistic MCP request""" # Create a realistic MCP request structure class CallToolRequest: - def __init__(self, tool_name, arguments=None): + def __init__(self, tool_name, arguments=None): # pylint: disable=no-self-use self.root = self self.params = CallToolParams(tool_name, arguments) - def model_dump(self, by_alias=True, mode="json", exclude_none=True): + def model_dump(self, by_alias=True, mode="json", exclude_none=True): # pylint: disable=no-self-use result = {"method": "call_tool", "params": {"name": self.params.name}} if self.params.arguments: result["params"]["arguments"] = self.params.arguments @@ -169,7 +169,7 @@ def model_dump(self, by_alias=True, mode="json", exclude_none=True): return result @classmethod - def model_validate(cls, data): + def model_validate(cls, data): # pylint: disable=no-self-use instance = cls(data["params"]["name"], data["params"].get("arguments")) # Restore _meta field if present if "_meta" in data["params"]: @@ -177,7 +177,7 @@ def model_validate(cls, data): return instance class CallToolParams: - def __init__(self, name, arguments=None): + def __init__(self, name, arguments=None): # pylint: disable=no-self-use self.name = name self.arguments = arguments self._meta = None # Will hold trace context @@ -214,7 +214,7 @@ def __init__(self, name, arguments=None): class TestInstrumentedMCPServer(unittest.TestCase): """Test mcpinstrumentor with a mock MCP server to verify end-to-end functionality""" - def setUp(self): + def setUp(self): # pylint: disable=no-self-use self.instrumentor = MCPInstrumentor() # Initialize tracer so the instrumentor can work mock_tracer = MagicMock() @@ -224,15 +224,15 @@ def test_no_trace_context_fallback(self): # pylint: disable=no-self-use """Test graceful handling when no trace context is present on server side""" class MockServerNoTrace: - async def _handle_request(self, session, request): + async def _handle_request(self, session, request): # pylint: disable=no-self-use return {"success": True, "handled_without_trace": True} class MockServerRequestNoTrace: - def __init__(self, tool_name): + def __init__(self, tool_name): # pylint: disable=no-self-use self.params = MockServerRequestParamsNoTrace(tool_name) class MockServerRequestParamsNoTrace: - def __init__(self, name): + def __init__(self, name): # pylint: disable=no-self-use self.name = name self.meta = None # No trace context @@ -269,12 +269,12 @@ def test_end_to_end_client_server_communication(self): # pylint: disable=too-ma # Create realistic request/response classes class MCPRequest: - def __init__(self, tool_name, arguments=None, method="call_tool"): + def __init__(self, tool_name, arguments=None, method="call_tool"): # pylint: disable=no-self-use self.root = self self.params = MCPRequestParams(tool_name, arguments) self.method = method - def model_dump(self, by_alias=True, mode="json", exclude_none=True): + def model_dump(self, by_alias=True, mode="json", exclude_none=True): # pylint: disable=no-self-use result = {"method": self.method, "params": {"name": self.params.name}} if self.params.arguments: result["params"]["arguments"] = self.params.arguments @@ -284,7 +284,7 @@ def model_dump(self, by_alias=True, mode="json", exclude_none=True): return result @classmethod - def model_validate(cls, data): + def model_validate(cls, data): # pylint: disable=no-self-use method = data.get("method", "call_tool") instance = cls(data["params"]["name"], data["params"].get("arguments"), method) # Restore _meta field if present @@ -293,19 +293,19 @@ def model_validate(cls, data): return instance class MCPRequestParams: - def __init__(self, name, arguments=None): + def __init__(self, name, arguments=None): # pylint: disable=no-self-use self.name = name self.arguments = arguments self._meta = None class MCPServerRequest: - def __init__(self, client_request_data): + def __init__(self, client_request_data): # pylint: disable=no-self-use """Server request created from client's serialized data""" self.method = client_request_data.get("method", "call_tool") self.params = MCPServerRequestParams(client_request_data["params"]) class MCPServerRequestParams: - def __init__(self, params_data): + def __init__(self, params_data): # pylint: disable=no-self-use self.name = params_data["name"] self.arguments = params_data.get("arguments") # Extract traceparent from _meta if present @@ -315,16 +315,16 @@ def __init__(self, params_data): self.meta = None class MCPServerRequestMeta: - def __init__(self, traceparent): + def __init__(self, traceparent): # pylint: disable=no-self-use self.traceparent = traceparent # Mock client and server that actually communicate class EndToEndMCPSystem: - def __init__(self): + def __init__(self): # pylint: disable=no-self-use self.communication_log = [] self.last_sent_request = None - async def client_send_request(self, request): + async def client_send_request(self, request): # pylint: disable=no-self-use """Client sends request - captures what gets sent""" self.communication_log.append("CLIENT: Preparing to send request") self.last_sent_request = request # Capture the modified request @@ -336,7 +336,7 @@ async def client_send_request(self, request): # Return client response return {"success": True, "client_response": "Request sent successfully"} - async def server_handle_request(self, session, server_request): + async def server_handle_request(self, session, server_request): # pylint: disable=no-self-use """Server handles the request it received""" self.communication_log.append(f"SERVER: Received request for {server_request.params.name}") From 4ea1757bd9bea20eba66395cbd263da25f051529 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Thu, 24 Jul 2025 17:13:47 -0700 Subject: [PATCH 10/41] Test lint check4 --- .../distro/mcpinstrumentor/mcpinstrumentor.py | 8 ++-- .../distro/test_mcpinstrumentor.py | 48 +++++++++---------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index 46f65c897..5e21e6578 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -29,7 +29,7 @@ class MCPInstrumentor(BaseInstrumentor): An instrumenter for MCP. """ - def __init__(self): # pylint: disable=no-self-use + def __init__(self): # pylint: disable=no-self-use super().__init__() self.tracer = None @@ -140,7 +140,7 @@ def _get_span_name(self, req): # pylint: disable=no-self-use span_name = "unknown" return span_name - def handle_attributes(self, span, request, is_client=True): # pylint: disable=no-self-use + def handle_attributes(self, span, request, is_client=True): # pylint: disable=no-self-use import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import operation = self._get_span_name(request) @@ -170,7 +170,7 @@ def _add_server_attributes(self, span, operation, request): # pylint: disable=n def instrumentation_dependencies(self) -> Collection[str]: # pylint: disable=no-self-use return _instruments - def _instrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use + def _instrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use tracer_provider = kwargs.get("tracer_provider") if tracer_provider: self.tracer = tracer_provider.get_tracer("mcp") @@ -193,6 +193,6 @@ def _instrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use "mcp.server.lowlevel.server", ) - def _uninstrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use + def _uninstrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use unwrap("mcp.shared.session", "BaseSession.send_request") unwrap("mcp.server.lowlevel.server", "Server._handle_request") diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index 774717620..517c672cb 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -31,11 +31,11 @@ def __init__(self, trace_id, span_id): # pylint: disable=no-self-use class SimpleTracerProvider: """Simple mock tracer provider without using MagicMock""" - def __init__(self): # pylint: disable=no-self-use + def __init__(self): # pylint: disable=no-self-use self.get_tracer_called = False self.tracer_name = None - def get_tracer(self, name): # pylint: disable=no-self-use + def get_tracer(self, name): # pylint: disable=no-self-use self.get_tracer_called = True self.tracer_name = name return "mock_tracer_from_provider" @@ -44,10 +44,10 @@ def get_tracer(self, name): # pylint: disable=no-self-use class TestInjectTraceContext(unittest.TestCase): """Test the _inject_trace_context method""" - def setUp(self): # pylint: disable=no-self-use + def setUp(self): # pylint: disable=no-self-use self.instrumentor = MCPInstrumentor() - def test_inject_trace_context_empty_dict(self): # pylint: disable=no-self-use + def test_inject_trace_context_empty_dict(self): # pylint: disable=no-self-use """Test injecting trace context into empty dictionary""" # Setup request_data = {} @@ -70,7 +70,7 @@ def test_inject_trace_context_empty_dict(self): # pylint: disable=no-self-use self.assertEqual(int(parts[1], 16), 12345) # trace_id self.assertEqual(int(parts[2], 16), 67890) # span_id - def test_inject_trace_context_existing_params(self): # pylint: disable=no-self-use + def test_inject_trace_context_existing_params(self): # pylint: disable=no-self-use """Test injecting trace context when params already exist""" # Setup request_data = {"params": {"existing_field": "test_value"}} @@ -94,13 +94,13 @@ def test_inject_trace_context_existing_params(self): # pylint: disable=no-sel class TestTracerProvider(unittest.TestCase): """Test the tracer provider kwargs logic in _instrument method""" - def setUp(self): # pylint: disable=no-self-use + def setUp(self): # pylint: disable=no-self-use self.instrumentor = MCPInstrumentor() # Reset tracer to ensure test isolation if hasattr(self.instrumentor, "tracer"): delattr(self.instrumentor, "tracer") - def test_instrument_without_tracer_provider_kwargs(self): # pylint: disable=no-self-use + def test_instrument_without_tracer_provider_kwargs(self): # pylint: disable=no-self-use """Test _instrument method when no tracer_provider in kwargs - should use default tracer""" # Execute - Actually test the mcpinstrumentor method with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer: @@ -130,10 +130,10 @@ def test_instrument_with_tracer_provider_kwargs(self): # pylint: disable=no-sel class TestInstrumentationDependencies(unittest.TestCase): """Test the instrumentation_dependencies method""" - def setUp(self): # pylint: disable=no-self-use + def setUp(self): # pylint: disable=no-self-use self.instrumentor = MCPInstrumentor() - def test_instrumentation_dependencies(self): # pylint: disable=no-self-use + def test_instrumentation_dependencies(self): # pylint: disable=no-self-use """Test that instrumentation_dependencies method returns the expected dependencies""" # Execute - Actually test the mcpinstrumentor method dependencies = self.instrumentor.instrumentation_dependencies() @@ -147,7 +147,7 @@ def test_instrumentation_dependencies(self): # pylint: disable=no-self-use class TestTraceContextInjection(unittest.TestCase): """Test trace context injection using actual mcpinstrumentor methods""" - def setUp(self): # pylint: disable=no-self-use + def setUp(self): # pylint: disable=no-self-use self.instrumentor = MCPInstrumentor() def test_trace_context_injection_with_realistic_request(self): # pylint: disable=no-self-use @@ -159,7 +159,7 @@ def __init__(self, tool_name, arguments=None): # pylint: disable=no-self-use self.root = self self.params = CallToolParams(tool_name, arguments) - def model_dump(self, by_alias=True, mode="json", exclude_none=True): # pylint: disable=no-self-use + def model_dump(self, by_alias=True, mode="json", exclude_none=True): # pylint: disable=no-self-use result = {"method": "call_tool", "params": {"name": self.params.name}} if self.params.arguments: result["params"]["arguments"] = self.params.arguments @@ -177,7 +177,7 @@ def model_validate(cls, data): # pylint: disable=no-self-use return instance class CallToolParams: - def __init__(self, name, arguments=None): # pylint: disable=no-self-use + def __init__(self, name, arguments=None): # pylint: disable=no-self-use self.name = name self.arguments = arguments self._meta = None # Will hold trace context @@ -214,7 +214,7 @@ def __init__(self, name, arguments=None): # pylint: disable=no-self-use class TestInstrumentedMCPServer(unittest.TestCase): """Test mcpinstrumentor with a mock MCP server to verify end-to-end functionality""" - def setUp(self): # pylint: disable=no-self-use + def setUp(self): # pylint: disable=no-self-use self.instrumentor = MCPInstrumentor() # Initialize tracer so the instrumentor can work mock_tracer = MagicMock() @@ -224,15 +224,15 @@ def test_no_trace_context_fallback(self): # pylint: disable=no-self-use """Test graceful handling when no trace context is present on server side""" class MockServerNoTrace: - async def _handle_request(self, session, request): # pylint: disable=no-self-use + async def _handle_request(self, session, request): # pylint: disable=no-self-use return {"success": True, "handled_without_trace": True} class MockServerRequestNoTrace: - def __init__(self, tool_name): # pylint: disable=no-self-use + def __init__(self, tool_name): # pylint: disable=no-self-use self.params = MockServerRequestParamsNoTrace(tool_name) class MockServerRequestParamsNoTrace: - def __init__(self, name): # pylint: disable=no-self-use + def __init__(self, name): # pylint: disable=no-self-use self.name = name self.meta = None # No trace context @@ -274,7 +274,7 @@ def __init__(self, tool_name, arguments=None, method="call_tool"): # pylint: di self.params = MCPRequestParams(tool_name, arguments) self.method = method - def model_dump(self, by_alias=True, mode="json", exclude_none=True): # pylint: disable=no-self-use + def model_dump(self, by_alias=True, mode="json", exclude_none=True): # pylint: disable=no-self-use result = {"method": self.method, "params": {"name": self.params.name}} if self.params.arguments: result["params"]["arguments"] = self.params.arguments @@ -293,19 +293,19 @@ def model_validate(cls, data): # pylint: disable=no-self-use return instance class MCPRequestParams: - def __init__(self, name, arguments=None): # pylint: disable=no-self-use + def __init__(self, name, arguments=None): # pylint: disable=no-self-use self.name = name self.arguments = arguments self._meta = None class MCPServerRequest: - def __init__(self, client_request_data): # pylint: disable=no-self-use + def __init__(self, client_request_data): # pylint: disable=no-self-use """Server request created from client's serialized data""" self.method = client_request_data.get("method", "call_tool") self.params = MCPServerRequestParams(client_request_data["params"]) class MCPServerRequestParams: - def __init__(self, params_data): # pylint: disable=no-self-use + def __init__(self, params_data): # pylint: disable=no-self-use self.name = params_data["name"] self.arguments = params_data.get("arguments") # Extract traceparent from _meta if present @@ -315,16 +315,16 @@ def __init__(self, params_data): # pylint: disable=no-self-use self.meta = None class MCPServerRequestMeta: - def __init__(self, traceparent): # pylint: disable=no-self-use + def __init__(self, traceparent): # pylint: disable=no-self-use self.traceparent = traceparent # Mock client and server that actually communicate class EndToEndMCPSystem: - def __init__(self): # pylint: disable=no-self-use + def __init__(self): # pylint: disable=no-self-use self.communication_log = [] self.last_sent_request = None - async def client_send_request(self, request): # pylint: disable=no-self-use + async def client_send_request(self, request): # pylint: disable=no-self-use """Client sends request - captures what gets sent""" self.communication_log.append("CLIENT: Preparing to send request") self.last_sent_request = request # Capture the modified request @@ -336,7 +336,7 @@ async def client_send_request(self, request): # pylint: disable=no-self-use # Return client response return {"success": True, "client_response": "Request sent successfully"} - async def server_handle_request(self, session, server_request): # pylint: disable=no-self-use + async def server_handle_request(self, session, server_request): # pylint: disable=no-self-use """Server handles the request it received""" self.communication_log.append(f"SERVER: Received request for {server_request.params.name}") From 3f576c29b25081a2b936a19e829a5b57bc6d8c4a Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Fri, 25 Jul 2025 00:49:34 -0700 Subject: [PATCH 11/41] Added input/output types, further code cleanup, got rid of pylint disable self use --- .../distro/mcpinstrumentor/README.md | 22 ++- .../distro/mcpinstrumentor/mcpinstrumentor.py | 145 ++++++++++-------- .../distro/test_mcpinstrumentor.py | 100 ++++++------ 3 files changed, 152 insertions(+), 115 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md index 0e7b72007..53628380b 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md @@ -1,11 +1,11 @@ # MCP Instrumentor -OpenTelemetry MCP instrumentation package. +OpenTelemetry MCP instrumentation package for AWS Distro. ## Installation ```bash -pip install mcpinstrumentor +pip install amazon-opentelemetry-distro-mcpinstrumentor ``` ## Usage @@ -14,4 +14,20 @@ pip install mcpinstrumentor from mcpinstrumentor import MCPInstrumentor MCPInstrumentor().instrument() -``` \ No newline at end of file +``` + +## Configuration + +### Environment Variables + +- `MCP_SERVICE_NAME`: Sets the service name for MCP client spans. Defaults to "Generic MCP Server" if not set. + +```bash +export MCP_SERVICE_NAME="My Custom MCP Server" +``` + +## Features + +- Automatic instrumentation of MCP client and server requests +- Distributed tracing support with trace context propagation +- Configurable service naming via environment variables \ No newline at end of file diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index 5e21e6578..b3a85d894 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -1,8 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import logging -from typing import Any, Collection +from typing import Any, Callable, Collection, Dict, Tuple +from mcp import ClientRequest from wrapt import register_post_import_hook, wrap_function_wrapper from opentelemetry import trace @@ -29,12 +30,46 @@ class MCPInstrumentor(BaseInstrumentor): An instrumenter for MCP. """ - def __init__(self): # pylint: disable=no-self-use + def __init__(self): super().__init__() self.tracer = None + @staticmethod + def instrumentation_dependencies() -> Collection[str]: + return _instruments + + def _instrument(self, **kwargs: Any) -> None: + tracer_provider = kwargs.get("tracer_provider") + if tracer_provider: + self.tracer = tracer_provider.get_tracer("mcp") + else: + self.tracer = trace.get_tracer("mcp") + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.shared.session", + "BaseSession.send_request", + self._wrap_send_request, + ), + "mcp.shared.session", + ) + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.server.lowlevel.server", + "Server._handle_request", + self._wrap_handle_request, + ), + "mcp.server.lowlevel.server", + ) + + @staticmethod + def _uninstrument(**kwargs: Any) -> None: + unwrap("mcp.shared.session", "BaseSession.send_request") + unwrap("mcp.server.lowlevel.server", "Server._handle_request") + # Send Request Wrapper - def _wrap_send_request(self, wrapped, instance, args, kwargs): # pylint: disable=no-self-use + def _wrap_send_request( + self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Callable: """ Changes made: The wrapper intercepts the request before sending, injects distributed tracing context into the @@ -43,14 +78,14 @@ def _wrap_send_request(self, wrapped, instance, args, kwargs): # pylint: disabl type and calling the original function with identical parameters. """ - async def async_wrapper(): # pylint: disable=no-self-use + async def async_wrapper(): with self.tracer.start_as_current_span("client.send_request", kind=trace.SpanKind.CLIENT) as span: span_ctx = span.get_span_context() request = args[0] if len(args) > 0 else kwargs.get("request") if request: req_root = request.root if hasattr(request, "root") else request - self.handle_attributes(span, req_root, True) + self._generate_mcp_attributes(span, req_root, is_client=True) request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) self._inject_trace_context(request_data, span_ctx) # Reconstruct request object with injected trace context @@ -68,7 +103,9 @@ async def async_wrapper(): # pylint: disable=no-self-use return async_wrapper() # Handle Request Wrapper - async def _wrap_handle_request(self, wrapped, instance, args, kwargs): # pylint: disable=no-self-use + async def _wrap_handle_request( + self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: """ Changes made: This wrapper intercepts requests before processing, extracts distributed tracing context from @@ -87,19 +124,35 @@ async def _wrap_handle_request(self, wrapped, instance, args, kwargs): # pylint traceparent = getattr(req.params.meta, "traceparent", None) span_context = self._extract_span_context_from_traceparent(traceparent) if traceparent else None if span_context: - span_name = self._get_span_name(req) + span_name = self._get_mcp_operation(req) with self.tracer.start_as_current_span( span_name, kind=trace.SpanKind.SERVER, context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)), ) as span: - self.handle_attributes(span, req, False) + self._generate_mcp_attributes(span, req, False) result = await wrapped(*args, **kwargs) return result else: return await wrapped(*args, **kwargs) - def _inject_trace_context(self, request_data, span_ctx): # pylint: disable=no-self-use + def _generate_mcp_attributes(self, span: trace.Span, request: ClientRequest, is_client: bool) -> None: + import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import + + operation = "UnknownOperation" + if isinstance(request, types.ListToolsRequest): + operation = "ListTool" + span.set_attribute("mcp.list_tools", True) + elif isinstance(request, types.CallToolRequest): + operation = request.params.name + span.set_attribute("mcp.call_tool", True) + if is_client: + self._add_client_attributes(span, operation, request) + else: + self._add_server_attributes(span, operation, request) + + @staticmethod + def _inject_trace_context(request_data: Dict[str, Any], span_ctx) -> None: if "params" not in request_data: request_data["params"] = {} if "_meta" not in request_data["params"]: @@ -110,7 +163,8 @@ def _inject_trace_context(self, request_data, span_ctx): # pylint: disable=no-s traceparent = f"00-{trace_id_hex}-{span_id_hex}-{trace_flags}" request_data["params"]["_meta"]["traceparent"] = traceparent - def _extract_span_context_from_traceparent(self, traceparent): # pylint: disable=no-self-use + @staticmethod + def _extract_span_context_from_traceparent(traceparent: str): parts = traceparent.split("-") if len(parts) == 4: try: @@ -127,72 +181,29 @@ def _extract_span_context_from_traceparent(self, traceparent): # pylint: disabl return None return None - def _get_span_name(self, req): # pylint: disable=no-self-use - span_name = "unknown" + @staticmethod + def _get_mcp_operation(req: ClientRequest) -> str: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import + span_name = "unknown" + if isinstance(req, types.ListToolsRequest): span_name = "tools/list" elif isinstance(req, types.CallToolRequest): - if hasattr(req, "params") and hasattr(req.params, "name"): - span_name = f"tools/{req.params.name}" - else: - span_name = "unknown" + span_name = f"tools/{req.params.name}" return span_name - def handle_attributes(self, span, request, is_client=True): # pylint: disable=no-self-use - import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import - - operation = self._get_span_name(request) - if isinstance(request, types.ListToolsRequest): - operation = "ListTool" - span.set_attribute("mcp.list_tools", True) - elif isinstance(request, types.CallToolRequest): - if hasattr(request, "params") and hasattr(request.params, "name"): - operation = request.params.name - span.set_attribute("mcp.call_tool", True) - if is_client: - self._add_client_attributes(span, operation, request) - else: - self._add_server_attributes(span, operation, request) + @staticmethod + def _add_client_attributes(span: trace.Span, operation: str, request: ClientRequest) -> None: + import os # pylint: disable=import-outside-toplevel - def _add_client_attributes(self, span, operation, request): # pylint: disable=no-self-use - span.set_attribute("aws.remote.service", "Appsignals MCP Server") + service_name = os.environ.get("MCP_SERVICE_NAME", "Generic MCP Server") + span.set_attribute("aws.remote.service", service_name) span.set_attribute("aws.remote.operation", operation) - if hasattr(request, "params") and hasattr(request.params, "name"): + if hasattr(request, "params") and request.params and hasattr(request.params, "name"): span.set_attribute("tool.name", request.params.name) - def _add_server_attributes(self, span, operation, request): # pylint: disable=no-self-use - span.set_attribute("server_side", True) - if hasattr(request, "params") and hasattr(request.params, "name"): + @staticmethod + def _add_server_attributes(span: trace.Span, operation: str, request: ClientRequest) -> None: + if hasattr(request, "params") and request.params and hasattr(request.params, "name"): span.set_attribute("tool.name", request.params.name) - - def instrumentation_dependencies(self) -> Collection[str]: # pylint: disable=no-self-use - return _instruments - - def _instrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use - tracer_provider = kwargs.get("tracer_provider") - if tracer_provider: - self.tracer = tracer_provider.get_tracer("mcp") - else: - self.tracer = trace.get_tracer("mcp") - register_post_import_hook( - lambda _: wrap_function_wrapper( - "mcp.shared.session", - "BaseSession.send_request", - self._wrap_send_request, - ), - "mcp.shared.session", - ) - register_post_import_hook( - lambda _: wrap_function_wrapper( - "mcp.server.lowlevel.server", - "Server._handle_request", - self._wrap_handle_request, - ), - "mcp.server.lowlevel.server", - ) - - def _uninstrument(self, **kwargs: Any) -> None: # pylint: disable=no-self-use - unwrap("mcp.shared.session", "BaseSession.send_request") - unwrap("mcp.server.lowlevel.server", "Server._handle_request") diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index 517c672cb..0bfab1f00 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -9,6 +9,7 @@ import os import sys import unittest +from typing import Any, Dict, List, Optional from unittest.mock import MagicMock # Add src path for imports @@ -16,14 +17,13 @@ src_path = os.path.join(project_root, "src") sys.path.insert(0, src_path) -# pylint: disable=wrong-import-position from amazon.opentelemetry.distro.mcpinstrumentor.mcpinstrumentor import MCPInstrumentor # noqa: E402 class SimpleSpanContext: """Simple mock span context without using MagicMock""" - def __init__(self, trace_id, span_id): # pylint: disable=no-self-use + def __init__(self, trace_id: int, span_id: int) -> None: self.trace_id = trace_id self.span_id = span_id @@ -31,11 +31,11 @@ def __init__(self, trace_id, span_id): # pylint: disable=no-self-use class SimpleTracerProvider: """Simple mock tracer provider without using MagicMock""" - def __init__(self): # pylint: disable=no-self-use + def __init__(self) -> None: self.get_tracer_called = False - self.tracer_name = None + self.tracer_name: Optional[str] = None - def get_tracer(self, name): # pylint: disable=no-self-use + def get_tracer(self, name: str) -> str: self.get_tracer_called = True self.tracer_name = name return "mock_tracer_from_provider" @@ -44,10 +44,10 @@ def get_tracer(self, name): # pylint: disable=no-self-use class TestInjectTraceContext(unittest.TestCase): """Test the _inject_trace_context method""" - def setUp(self): # pylint: disable=no-self-use + def setUp(self) -> None: self.instrumentor = MCPInstrumentor() - def test_inject_trace_context_empty_dict(self): # pylint: disable=no-self-use + def test_inject_trace_context_empty_dict(self) -> None: """Test injecting trace context into empty dictionary""" # Setup request_data = {} @@ -70,7 +70,7 @@ def test_inject_trace_context_empty_dict(self): # pylint: disable=no-self-use self.assertEqual(int(parts[1], 16), 12345) # trace_id self.assertEqual(int(parts[2], 16), 67890) # span_id - def test_inject_trace_context_existing_params(self): # pylint: disable=no-self-use + def test_inject_trace_context_existing_params(self) -> None: """Test injecting trace context when params already exist""" # Setup request_data = {"params": {"existing_field": "test_value"}} @@ -94,13 +94,13 @@ def test_inject_trace_context_existing_params(self): # pylint: disable=no-self- class TestTracerProvider(unittest.TestCase): """Test the tracer provider kwargs logic in _instrument method""" - def setUp(self): # pylint: disable=no-self-use + def setUp(self) -> None: self.instrumentor = MCPInstrumentor() # Reset tracer to ensure test isolation if hasattr(self.instrumentor, "tracer"): delattr(self.instrumentor, "tracer") - def test_instrument_without_tracer_provider_kwargs(self): # pylint: disable=no-self-use + def test_instrument_without_tracer_provider_kwargs(self) -> None: """Test _instrument method when no tracer_provider in kwargs - should use default tracer""" # Execute - Actually test the mcpinstrumentor method with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer: @@ -112,7 +112,7 @@ def test_instrument_without_tracer_provider_kwargs(self): # pylint: disable=no- self.assertEqual(self.instrumentor.tracer, "default_tracer") mock_get_tracer.assert_called_with("mcp") - def test_instrument_with_tracer_provider_kwargs(self): # pylint: disable=no-self-use + def test_instrument_with_tracer_provider_kwargs(self) -> None: """Test _instrument method when tracer_provider is in kwargs - should use provider's tracer""" # Setup provider = SimpleTracerProvider() @@ -130,10 +130,10 @@ def test_instrument_with_tracer_provider_kwargs(self): # pylint: disable=no-sel class TestInstrumentationDependencies(unittest.TestCase): """Test the instrumentation_dependencies method""" - def setUp(self): # pylint: disable=no-self-use + def setUp(self) -> None: self.instrumentor = MCPInstrumentor() - def test_instrumentation_dependencies(self): # pylint: disable=no-self-use + def test_instrumentation_dependencies(self) -> None: """Test that instrumentation_dependencies method returns the expected dependencies""" # Execute - Actually test the mcpinstrumentor method dependencies = self.instrumentor.instrumentation_dependencies() @@ -147,19 +147,21 @@ def test_instrumentation_dependencies(self): # pylint: disable=no-self-use class TestTraceContextInjection(unittest.TestCase): """Test trace context injection using actual mcpinstrumentor methods""" - def setUp(self): # pylint: disable=no-self-use + def setUp(self) -> None: self.instrumentor = MCPInstrumentor() - def test_trace_context_injection_with_realistic_request(self): # pylint: disable=no-self-use + def test_trace_context_injection_with_realistic_request(self) -> None: """Test actual trace context injection using mcpinstrumentor._inject_trace_context with realistic MCP request""" # Create a realistic MCP request structure class CallToolRequest: - def __init__(self, tool_name, arguments=None): # pylint: disable=no-self-use + def __init__(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None) -> None: self.root = self self.params = CallToolParams(tool_name, arguments) - def model_dump(self, by_alias=True, mode="json", exclude_none=True): # pylint: disable=no-self-use + def model_dump( + self, by_alias: bool = True, mode: str = "json", exclude_none: bool = True + ) -> Dict[str, Any]: result = {"method": "call_tool", "params": {"name": self.params.name}} if self.params.arguments: result["params"]["arguments"] = self.params.arguments @@ -168,8 +170,9 @@ def model_dump(self, by_alias=True, mode="json", exclude_none=True): # pylint: result["params"]["_meta"] = self.params._meta return result + # converting raw dictionary data back into an instance of this class @classmethod - def model_validate(cls, data): # pylint: disable=no-self-use + def model_validate(cls, data: Dict[str, Any]) -> "CallToolRequest": instance = cls(data["params"]["name"], data["params"].get("arguments")) # Restore _meta field if present if "_meta" in data["params"]: @@ -177,10 +180,10 @@ def model_validate(cls, data): # pylint: disable=no-self-use return instance class CallToolParams: - def __init__(self, name, arguments=None): # pylint: disable=no-self-use + def __init__(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> None: self.name = name self.arguments = arguments - self._meta = None # Will hold trace context + self._meta: Optional[Dict[str, Any]] = None # Will hold trace context # Client creates original request client_request = CallToolRequest("create_metric", {"metric_name": "response_time", "value": 250}) @@ -214,27 +217,28 @@ def __init__(self, name, arguments=None): # pylint: disable=no-self-use class TestInstrumentedMCPServer(unittest.TestCase): """Test mcpinstrumentor with a mock MCP server to verify end-to-end functionality""" - def setUp(self): # pylint: disable=no-self-use + def setUp(self) -> None: self.instrumentor = MCPInstrumentor() # Initialize tracer so the instrumentor can work mock_tracer = MagicMock() self.instrumentor.tracer = mock_tracer - def test_no_trace_context_fallback(self): # pylint: disable=no-self-use + def test_no_trace_context_fallback(self) -> None: """Test graceful handling when no trace context is present on server side""" class MockServerNoTrace: - async def _handle_request(self, session, request): # pylint: disable=no-self-use + @staticmethod + async def _handle_request(session: Any, request: Any) -> Dict[str, Any]: return {"success": True, "handled_without_trace": True} class MockServerRequestNoTrace: - def __init__(self, tool_name): # pylint: disable=no-self-use + def __init__(self, tool_name: str) -> None: self.params = MockServerRequestParamsNoTrace(tool_name) class MockServerRequestParamsNoTrace: - def __init__(self, name): # pylint: disable=no-self-use + def __init__(self, name: str) -> None: self.name = name - self.meta = None # No trace context + self.meta: Optional[Any] = None # No trace context mock_server = MockServerNoTrace() server_request = MockServerRequestNoTrace("create_metric") @@ -248,8 +252,8 @@ def __init__(self, name): # pylint: disable=no-self-use # Test server handling without trace context (fallback scenario) with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( "sys.modules", {"mcp.types": MagicMock()} - ), unittest.mock.patch.object(self.instrumentor, "handle_attributes"), unittest.mock.patch.object( - self.instrumentor, "_get_span_name", return_value="tools/create_metric" + ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( + self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" ): result = asyncio.run( @@ -264,17 +268,23 @@ def __init__(self, name): # pylint: disable=no-self-use # Should not create traced spans when no trace context is present mock_tracer.start_as_current_span.assert_not_called() - def test_end_to_end_client_server_communication(self): # pylint: disable=too-many-locals,too-many-statements + def test_end_to_end_client_server_communication( + self, + ) -> None: """Test where server actually receives what client sends (including injected trace context)""" # Create realistic request/response classes class MCPRequest: - def __init__(self, tool_name, arguments=None, method="call_tool"): # pylint: disable=no-self-use + def __init__( + self, tool_name: str, arguments: Optional[Dict[str, Any]] = None, method: str = "call_tool" + ) -> None: self.root = self self.params = MCPRequestParams(tool_name, arguments) self.method = method - def model_dump(self, by_alias=True, mode="json", exclude_none=True): # pylint: disable=no-self-use + def model_dump( + self, by_alias: bool = True, mode: str = "json", exclude_none: bool = True + ) -> Dict[str, Any]: result = {"method": self.method, "params": {"name": self.params.name}} if self.params.arguments: result["params"]["arguments"] = self.params.arguments @@ -284,7 +294,7 @@ def model_dump(self, by_alias=True, mode="json", exclude_none=True): # pylint: return result @classmethod - def model_validate(cls, data): # pylint: disable=no-self-use + def model_validate(cls, data: Dict[str, Any]) -> "MCPRequest": method = data.get("method", "call_tool") instance = cls(data["params"]["name"], data["params"].get("arguments"), method) # Restore _meta field if present @@ -293,19 +303,19 @@ def model_validate(cls, data): # pylint: disable=no-self-use return instance class MCPRequestParams: - def __init__(self, name, arguments=None): # pylint: disable=no-self-use + def __init__(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> None: self.name = name self.arguments = arguments - self._meta = None + self._meta: Optional[Dict[str, Any]] = None class MCPServerRequest: - def __init__(self, client_request_data): # pylint: disable=no-self-use + def __init__(self, client_request_data: Dict[str, Any]) -> None: """Server request created from client's serialized data""" self.method = client_request_data.get("method", "call_tool") self.params = MCPServerRequestParams(client_request_data["params"]) class MCPServerRequestParams: - def __init__(self, params_data): # pylint: disable=no-self-use + def __init__(self, params_data: Dict[str, Any]) -> None: self.name = params_data["name"] self.arguments = params_data.get("arguments") # Extract traceparent from _meta if present @@ -315,16 +325,16 @@ def __init__(self, params_data): # pylint: disable=no-self-use self.meta = None class MCPServerRequestMeta: - def __init__(self, traceparent): # pylint: disable=no-self-use + def __init__(self, traceparent: str) -> None: self.traceparent = traceparent # Mock client and server that actually communicate class EndToEndMCPSystem: - def __init__(self): # pylint: disable=no-self-use - self.communication_log = [] - self.last_sent_request = None + def __init__(self) -> None: + self.communication_log: List[str] = [] + self.last_sent_request: Optional[Any] = None - async def client_send_request(self, request): # pylint: disable=no-self-use + async def client_send_request(self, request: Any) -> Dict[str, Any]: """Client sends request - captures what gets sent""" self.communication_log.append("CLIENT: Preparing to send request") self.last_sent_request = request # Capture the modified request @@ -336,7 +346,7 @@ async def client_send_request(self, request): # pylint: disable=no-self-use # Return client response return {"success": True, "client_response": "Request sent successfully"} - async def server_handle_request(self, session, server_request): # pylint: disable=no-self-use + async def server_handle_request(self, session: Any, server_request: Any) -> Dict[str, Any]: """Server handles the request it received""" self.communication_log.append(f"SERVER: Received request for {server_request.params.name}") @@ -377,7 +387,7 @@ async def server_handle_request(self, session, server_request): # pylint: disab # STEP 1: Client sends request through instrumentation with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( "sys.modules", {"mcp.types": MagicMock()} - ), unittest.mock.patch.object(self.instrumentor, "handle_attributes"): + ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): # Override the setup tracer with the properly mocked one self.instrumentor.tracer = mock_tracer @@ -413,8 +423,8 @@ async def server_handle_request(self, session, server_request): # pylint: disab # Server processes the request it received with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( "sys.modules", {"mcp.types": MagicMock()} - ), unittest.mock.patch.object(self.instrumentor, "handle_attributes"), unittest.mock.patch.object( - self.instrumentor, "_get_span_name", return_value="tools/create_metric" + ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( + self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" ): server_result = asyncio.run( From f54609e3d6edf5608f619814d1d3df2e0714e108 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Fri, 25 Jul 2025 15:48:50 -0700 Subject: [PATCH 12/41] Working Version of Cleaned Up Code --- .../distro/mcpinstrumentor/mcpinstrumentor.py | 15 ++++++++++----- .../distro/mcpinstrumentor/package.py | 3 --- 2 files changed, 10 insertions(+), 8 deletions(-) delete mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index b3a85d894..d6df3e1f2 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -10,13 +10,13 @@ from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap -from .package import _instruments +_instruments = ("mcp >= 1.6.0",) -def setup_logger_two(): - logger = logging.getLogger("loggertwo") +def setup_logger(): + logger = logging.getLogger("logger") logger.setLevel(logging.DEBUG) - handler = logging.FileHandler("loggertwo.log", mode="w") + handler = logging.FileHandler("logger.log", mode="w") handler.setLevel(logging.DEBUG) formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) @@ -146,6 +146,9 @@ def _generate_mcp_attributes(self, span: trace.Span, request: ClientRequest, is_ elif isinstance(request, types.CallToolRequest): operation = request.params.name span.set_attribute("mcp.call_tool", True) + elif isinstance(request, types.InitializeRequest): + operation = "Initialize" + span.set_attribute("mcp.initialize", True) if is_client: self._add_client_attributes(span, operation, request) else: @@ -191,13 +194,15 @@ def _get_mcp_operation(req: ClientRequest) -> str: span_name = "tools/list" elif isinstance(req, types.CallToolRequest): span_name = f"tools/{req.params.name}" + elif isinstance(req, types.InitializeRequest): + span_name = "tools/initialize" return span_name @staticmethod def _add_client_attributes(span: trace.Span, operation: str, request: ClientRequest) -> None: import os # pylint: disable=import-outside-toplevel - service_name = os.environ.get("MCP_SERVICE_NAME", "Generic MCP Server") + service_name = os.environ.get("MCP_INSTRUMENTATION_SERVER_NAME", "mcp server") span.set_attribute("aws.remote.service", service_name) span.set_attribute("aws.remote.operation", operation) if hasattr(request, "params") and request.params and hasattr(request.params, "name"): diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py deleted file mode 100644 index 2694dbbb5..000000000 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/package.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 -_instruments = ("mcp >= 1.6.0",) From f8e71720c8dbc21e0a4923b75aa027c878fee446 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Tue, 29 Jul 2025 11:44:54 -0700 Subject: [PATCH 13/41] Contract Test --- aws-opentelemetry-distro/pyproject.toml | 114 +++++++++--------- .../amazon/opentelemetry/distro/loggertwo.log | 0 .../images/applications/mcp/Dockerfile | 21 ++++ .../images/applications/mcp/client.py | 44 +++++++ .../images/applications/mcp/mcp_server.py | 16 +++ .../images/applications/mcp/pyproject.toml | 10 ++ .../images/applications/mcp/requirements.txt | 2 + .../tests/test/amazon/mcp/mcp_test.py | 112 +++++++++++++++++ 8 files changed, 261 insertions(+), 58 deletions(-) delete mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/loggertwo.log create mode 100644 contract-tests/images/applications/mcp/Dockerfile create mode 100644 contract-tests/images/applications/mcp/client.py create mode 100644 contract-tests/images/applications/mcp/mcp_server.py create mode 100644 contract-tests/images/applications/mcp/pyproject.toml create mode 100644 contract-tests/images/applications/mcp/requirements.txt create mode 100644 contract-tests/tests/test/amazon/mcp/mcp_test.py diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index 947d0e22f..d4f9ca204 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "AWS OpenTelemetry Python Distro" readme = "README.rst" license = "Apache-2.0" -requires-python = ">=3.9" +requires-python = ">=3.8" authors = [ { name = "Amazon Web Services" }, ] @@ -18,70 +18,68 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", ] dependencies = [ - "opentelemetry-api == 1.33.1", - "opentelemetry-sdk == 1.33.1", - "opentelemetry-exporter-otlp-proto-grpc == 1.33.1", - "opentelemetry-exporter-otlp-proto-http == 1.33.1", - "opentelemetry-propagator-b3 == 1.33.1", - "opentelemetry-propagator-jaeger == 1.33.1", - "opentelemetry-exporter-otlp-proto-common == 1.33.1", + "opentelemetry-api == 1.27.0", + "opentelemetry-sdk == 1.27.0", + "opentelemetry-exporter-otlp-proto-grpc == 1.27.0", + "opentelemetry-exporter-otlp-proto-http == 1.27.0", + "opentelemetry-propagator-b3 == 1.27.0", + "opentelemetry-propagator-jaeger == 1.27.0", + "opentelemetry-exporter-otlp-proto-common == 1.27.0", "opentelemetry-sdk-extension-aws == 2.0.2", "opentelemetry-propagator-aws-xray == 1.0.1", - "opentelemetry-distro == 0.54b1", - "opentelemetry-processor-baggage == 0.54b1", - "opentelemetry-propagator-ot-trace == 0.54b1", - "opentelemetry-instrumentation == 0.54b1", - "opentelemetry-instrumentation-aws-lambda == 0.54b1", - "opentelemetry-instrumentation-aio-pika == 0.54b1", - "opentelemetry-instrumentation-aiohttp-client == 0.54b1", - "opentelemetry-instrumentation-aiopg == 0.54b1", - "opentelemetry-instrumentation-asgi == 0.54b1", - "opentelemetry-instrumentation-asyncpg == 0.54b1", - "opentelemetry-instrumentation-boto == 0.54b1", - "opentelemetry-instrumentation-boto3sqs == 0.54b1", - "opentelemetry-instrumentation-botocore == 0.54b1", - "opentelemetry-instrumentation-celery == 0.54b1", - "opentelemetry-instrumentation-confluent-kafka == 0.54b1", - "opentelemetry-instrumentation-dbapi == 0.54b1", - "opentelemetry-instrumentation-django == 0.54b1", - "opentelemetry-instrumentation-elasticsearch == 0.54b1", - "opentelemetry-instrumentation-falcon == 0.54b1", - "opentelemetry-instrumentation-fastapi == 0.54b1", - "opentelemetry-instrumentation-flask == 0.54b1", - "opentelemetry-instrumentation-grpc == 0.54b1", - "opentelemetry-instrumentation-httpx == 0.54b1", - "opentelemetry-instrumentation-jinja2 == 0.54b1", - "opentelemetry-instrumentation-kafka-python == 0.54b1", - "opentelemetry-instrumentation-logging == 0.54b1", - "opentelemetry-instrumentation-mysql == 0.54b1", - "opentelemetry-instrumentation-mysqlclient == 0.54b1", - "opentelemetry-instrumentation-pika == 0.54b1", - "opentelemetry-instrumentation-psycopg2 == 0.54b1", - "opentelemetry-instrumentation-pymemcache == 0.54b1", - "opentelemetry-instrumentation-pymongo == 0.54b1", - "opentelemetry-instrumentation-pymysql == 0.54b1", - "opentelemetry-instrumentation-pyramid == 0.54b1", - "opentelemetry-instrumentation-redis == 0.54b1", - "opentelemetry-instrumentation-remoulade == 0.54b1", - "opentelemetry-instrumentation-requests == 0.54b1", - "opentelemetry-instrumentation-sqlalchemy == 0.54b1", - "opentelemetry-instrumentation-sqlite3 == 0.54b1", - "opentelemetry-instrumentation-starlette == 0.54b1", - "opentelemetry-instrumentation-system-metrics == 0.54b1", - "opentelemetry-instrumentation-tornado == 0.54b1", - "opentelemetry-instrumentation-tortoiseorm == 0.54b1", - "opentelemetry-instrumentation-urllib == 0.54b1", - "opentelemetry-instrumentation-urllib3 == 0.54b1", - "opentelemetry-instrumentation-wsgi == 0.54b1", - "opentelemetry-instrumentation-cassandra == 0.54b1", + "opentelemetry-distro == 0.48b0", + "opentelemetry-propagator-ot-trace == 0.48b0", + "opentelemetry-instrumentation == 0.48b0", + "opentelemetry-instrumentation-aws-lambda == 0.48b0", + "opentelemetry-instrumentation-aio-pika == 0.48b0", + "opentelemetry-instrumentation-aiohttp-client == 0.48b0", + "opentelemetry-instrumentation-aiopg == 0.48b0", + "opentelemetry-instrumentation-asgi == 0.48b0", + "opentelemetry-instrumentation-asyncpg == 0.48b0", + "opentelemetry-instrumentation-boto == 0.48b0", + "opentelemetry-instrumentation-boto3sqs == 0.48b0", + "opentelemetry-instrumentation-botocore == 0.48b0", + "opentelemetry-instrumentation-celery == 0.48b0", + "opentelemetry-instrumentation-confluent-kafka == 0.48b0", + "opentelemetry-instrumentation-dbapi == 0.48b0", + "opentelemetry-instrumentation-django == 0.48b0", + "opentelemetry-instrumentation-elasticsearch == 0.48b0", + "opentelemetry-instrumentation-falcon == 0.48b0", + "opentelemetry-instrumentation-fastapi == 0.48b0", + "opentelemetry-instrumentation-flask == 0.48b0", + "opentelemetry-instrumentation-grpc == 0.48b0", + "opentelemetry-instrumentation-httpx == 0.48b0", + "opentelemetry-instrumentation-jinja2 == 0.48b0", + "opentelemetry-instrumentation-kafka-python == 0.48b0", + "opentelemetry-instrumentation-logging == 0.48b0", + "opentelemetry-instrumentation-mysql == 0.48b0", + "opentelemetry-instrumentation-mysqlclient == 0.48b0", + "opentelemetry-instrumentation-pika == 0.48b0", + "opentelemetry-instrumentation-psycopg2 == 0.48b0", + "opentelemetry-instrumentation-pymemcache == 0.48b0", + "opentelemetry-instrumentation-pymongo == 0.48b0", + "opentelemetry-instrumentation-pymysql == 0.48b0", + "opentelemetry-instrumentation-pyramid == 0.48b0", + "opentelemetry-instrumentation-redis == 0.48b0", + "opentelemetry-instrumentation-remoulade == 0.48b0", + "opentelemetry-instrumentation-requests == 0.48b0", + "opentelemetry-instrumentation-sqlalchemy == 0.48b0", + "opentelemetry-instrumentation-sqlite3 == 0.48b0", + "opentelemetry-instrumentation-starlette == 0.48b0", + "opentelemetry-instrumentation-system-metrics == 0.48b0", + "opentelemetry-instrumentation-tornado == 0.48b0", + "opentelemetry-instrumentation-tortoiseorm == 0.48b0", + "opentelemetry-instrumentation-urllib == 0.48b0", + "opentelemetry-instrumentation-urllib3 == 0.48b0", + "opentelemetry-instrumentation-wsgi == 0.48b0", + "opentelemetry-instrumentation-cassandra == 0.48b0", ] [project.optional-dependencies] @@ -111,4 +109,4 @@ include = [ ] [tool.hatch.build.targets.wheel] -packages = ["src/amazon"] \ No newline at end of file +packages = ["src/amazon"] diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/loggertwo.log b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/loggertwo.log deleted file mode 100644 index e69de29bb..000000000 diff --git a/contract-tests/images/applications/mcp/Dockerfile b/contract-tests/images/applications/mcp/Dockerfile new file mode 100644 index 000000000..8c1792b2d --- /dev/null +++ b/contract-tests/images/applications/mcp/Dockerfile @@ -0,0 +1,21 @@ +# Meant to be run from aws-otel-python-instrumentation/contract-tests. +# Assumes existence of dist/aws_opentelemetry_distro--py3-none-any.whl. +# Assumes filename of aws_opentelemetry_distro--py3-none-any.whl is passed in as "DISTRO" arg. +FROM python:3.10 +WORKDIR /mcp +COPY ./dist/$DISTRO /mcp +COPY ./contract-tests/images/applications/mcp /mcp +COPY ./aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor /mcp/mcpinstrumentor + +ENV PIP_ROOT_USER_ACTION=ignore +ARG DISTRO + +# Install your MCP instrumentor first (use -e for editable install) +RUN pip install --upgrade pip && pip install -e ./mcpinstrumentor/ + +# Then install other requirements and the main distro +RUN pip install -r requirements.txt && pip install ${DISTRO} --force-reinstall +RUN opentelemetry-bootstrap -a install + +# Without `-u`, logs will be buffered and `wait_for_logs` will never return. +CMD ["opentelemetry-instrument", "python3", "-u", "./simple_client.py"] \ No newline at end of file diff --git a/contract-tests/images/applications/mcp/client.py b/contract-tests/images/applications/mcp/client.py new file mode 100644 index 000000000..b67b04757 --- /dev/null +++ b/contract-tests/images/applications/mcp/client.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import os +from http.server import BaseHTTPRequestHandler, HTTPServer + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + + +class MCPHandler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == "/mcp/echo": + asyncio.run(self._call_mcp_tool("echo", {"text": "Hello from HTTP request!"})) + self.send_response(200) + self.end_headers() + else: + self.send_response(404) + self.end_headers() + + async def _call_mcp_tool(self, tool_name, arguments): + server_env = { + "OTEL_PYTHON_DISTRO": "aws_distro", + "OTEL_PYTHON_CONFIGURATOR": "aws_configurator", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", ""), + "OTEL_EXPORTER_OTLP_PROTOCOL": "grpc", + "OTEL_TRACES_SAMPLER": "always_on", + "OTEL_METRICS_EXPORTER": "none", + "OTEL_LOGS_EXPORTER": "none", + } + server_params = StdioServerParameters( + command="opentelemetry-instrument", args=["python3", "simple_mcp_server.py"], env=server_env + ) + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + result = await session.call_tool(tool_name, arguments) + return result + + +if __name__ == "__main__": + print("Ready") + server = HTTPServer(("0.0.0.0", 8080), MCPHandler) + server.serve_forever() diff --git a/contract-tests/images/applications/mcp/mcp_server.py b/contract-tests/images/applications/mcp/mcp_server.py new file mode 100644 index 000000000..24e48bca5 --- /dev/null +++ b/contract-tests/images/applications/mcp/mcp_server.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from fastmcp import FastMCP + +# Create FastMCP server instance +mcp = FastMCP("Simple MCP Server") + + +@mcp.tool(name="echo", description="Echo the provided text") +def echo(text: str) -> str: + """Echo the provided text""" + return f"Echo: {text}" + + +if __name__ == "__main__": + mcp.run() diff --git a/contract-tests/images/applications/mcp/pyproject.toml b/contract-tests/images/applications/mcp/pyproject.toml new file mode 100644 index 000000000..d2b4ed047 --- /dev/null +++ b/contract-tests/images/applications/mcp/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "mcp-client-app" +description = "Simple MCP client that calls echo tool for testing MCP instrumentation" +version = "1.0.0" +license = "Apache-2.0" +requires-python = ">=3.10" +dependencies = [ + "mcp>=1.1.0", + "fastmcp>=0.1.0" +] \ No newline at end of file diff --git a/contract-tests/images/applications/mcp/requirements.txt b/contract-tests/images/applications/mcp/requirements.txt new file mode 100644 index 000000000..f8a56ec2e --- /dev/null +++ b/contract-tests/images/applications/mcp/requirements.txt @@ -0,0 +1,2 @@ +mcp>=1.1.0 +fastmcp>=0.1.0 \ No newline at end of file diff --git a/contract-tests/tests/test/amazon/mcp/mcp_test.py b/contract-tests/tests/test/amazon/mcp/mcp_test.py new file mode 100644 index 000000000..83642f093 --- /dev/null +++ b/contract-tests/tests/test/amazon/mcp/mcp_test.py @@ -0,0 +1,112 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing_extensions import override + +from amazon.base.contract_test_base import ContractTestBase +from opentelemetry.proto.trace.v1.trace_pb2 import Span + + +class MCPTest(ContractTestBase): + + @override + @staticmethod + def get_application_image_name() -> str: + return "aws-application-signals-tests-mcp-app" + + def test_mcp_echo_tool(self): + """Test MCP echo tool call creates proper spans""" + self.do_test_requests("mcp/echo", "GET", 200, 0, 0, tool_name="echo") + + @override + def _assert_aws_span_attributes(self, resource_scope_spans, path: str, **kwargs) -> None: + pass + + @override + def _assert_semantic_conventions_span_attributes( + self, resource_scope_spans, method: str, path: str, status_code: int, **kwargs + ) -> None: + + tool_name = kwargs.get("tool_name", "echo") + initialize_client_span = None + list_tools_client_span = None + list_tools_server_span = None + call_tool_client_span = None + call_tool_server_span = None + + for resource_scope_span in resource_scope_spans: + span = resource_scope_span.span + + if span.name == "client.send_request" and span.kind == Span.SPAN_KIND_CLIENT: + for attr in span.attributes: + if attr.key == "mcp.initialize" and attr.value.bool_value: + initialize_client_span = span + break + elif attr.key == "mcp.list_tools" and attr.value.bool_value: + list_tools_client_span = span + break + elif attr.key == "mcp.call_tool" and attr.value.bool_value: + call_tool_client_span = span + break + + elif span.name == "tools/list" and span.kind == Span.SPAN_KIND_SERVER: + list_tools_server_span = span + elif span.name == f"tools/{tool_name}" and span.kind == Span.SPAN_KIND_SERVER: + call_tool_server_span = span + + # Validate initialize client span (no server span expected) + self.assertIsNotNone(initialize_client_span, "Initialize client span not found") + self.assertEqual(initialize_client_span.kind, Span.SPAN_KIND_CLIENT) + + init_attributes = {attr.key: attr.value for attr in initialize_client_span.attributes} + self.assertIn("mcp.initialize", init_attributes) + self.assertTrue(init_attributes["mcp.initialize"].bool_value) + + # Validate list tools client span + self.assertIsNotNone(list_tools_client_span, "List tools client span not found") + self.assertEqual(list_tools_client_span.kind, Span.SPAN_KIND_CLIENT) + + list_client_attributes = {attr.key: attr.value for attr in list_tools_client_span.attributes} + self.assertIn("mcp.list_tools", list_client_attributes) + self.assertTrue(list_client_attributes["mcp.list_tools"].bool_value) + + # Validate list tools server span + self.assertIsNotNone(list_tools_server_span, "List tools server span not found") + self.assertEqual(list_tools_server_span.kind, Span.SPAN_KIND_SERVER) + + list_server_attributes = {attr.key: attr.value for attr in list_tools_server_span.attributes} + self.assertIn("mcp.list_tools", list_server_attributes) + self.assertTrue(list_server_attributes["mcp.list_tools"].bool_value) + + # Validate call tool client span + self.assertIsNotNone(call_tool_client_span, f"Call tool client span for {tool_name} not found") + self.assertEqual(call_tool_client_span.kind, Span.SPAN_KIND_CLIENT) + + call_client_attributes = {attr.key: attr.value for attr in call_tool_client_span.attributes} + self.assertIn("mcp.call_tool", call_client_attributes) + self.assertTrue(call_client_attributes["mcp.call_tool"].bool_value) + self.assertIn("aws.remote.operation", call_client_attributes) + self.assertEqual(call_client_attributes["aws.remote.operation"].string_value, tool_name) + + # Validate call tool server span + self.assertIsNotNone(call_tool_server_span, f"Call tool server span for {tool_name} not found") + self.assertEqual(call_tool_server_span.kind, Span.SPAN_KIND_SERVER) + + call_server_attributes = {attr.key: attr.value for attr in call_tool_server_span.attributes} + self.assertIn("mcp.call_tool", call_server_attributes) + self.assertTrue(call_server_attributes["mcp.call_tool"].bool_value) + + # Validate distributed tracing for paired spans + self.assertEqual( + list_tools_server_span.trace_id, + list_tools_client_span.trace_id, + "List tools client and server spans should have the same trace ID", + ) + self.assertEqual( + call_tool_server_span.trace_id, + call_tool_client_span.trace_id, + "Call tool client and server spans should have the same trace ID", + ) + + @override + def _assert_metric_attributes(self, resource_scope_metrics, metric_name: str, expected_sum: int, **kwargs) -> None: + pass From 1a9ce26bbb48c3aad156e80b14a32d831594b142 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Tue, 29 Jul 2025 13:31:01 -0700 Subject: [PATCH 14/41] Restore pyproject.toml file --- .../distro/mcpinstrumentor/mcpinstrumentor.py | 14 --- .../distro/mcpinstrumentor/pyproject.toml | 104 ++++++++++++++---- 2 files changed, 84 insertions(+), 34 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index d6df3e1f2..32c55efd2 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -11,20 +11,6 @@ from opentelemetry.instrumentation.utils import unwrap _instruments = ("mcp >= 1.6.0",) - - -def setup_logger(): - logger = logging.getLogger("logger") - logger.setLevel(logging.DEBUG) - handler = logging.FileHandler("logger.log", mode="w") - handler.setLevel(logging.DEBUG) - formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - if not logger.handlers: - logger.addHandler(handler) - return logger - - class MCPInstrumentor(BaseInstrumentor): """ An instrumenter for MCP. diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml index 4d237fa4c..2fa1f36dd 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml @@ -3,14 +3,14 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "amazon-opentelemetry-distro-mcpinstrumentor" -version = "0.1.0" -description = "OpenTelemetry MCP instrumentation for AWS Distro" -readme = "README.md" +name = "aws-opentelemetry-distro" +dynamic = ["version"] +description = "AWS OpenTelemetry Python Distro" +readme = "README.rst" license = "Apache-2.0" -requires-python = ">=3.10" +requires-python = ">=3.8" authors = [ - { name = "Johnny Lin", email = "jonzilin@amazon.com" }, + { name = "Amazon Web Services" }, ] classifiers = [ "Development Status :: 4 - Beta", @@ -18,31 +18,95 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", ] + dependencies = [ - "mcp", - "opentelemetry-api", - "opentelemetry-instrumentation", - "opentelemetry-sdk", - "opentelemetry-semantic-conventions", - "wrapt" + "opentelemetry-api == 1.27.0", + "opentelemetry-sdk == 1.27.0", + "opentelemetry-exporter-otlp-proto-grpc == 1.27.0", + "opentelemetry-exporter-otlp-proto-http == 1.27.0", + "opentelemetry-propagator-b3 == 1.27.0", + "opentelemetry-propagator-jaeger == 1.27.0", + "opentelemetry-exporter-otlp-proto-common == 1.27.0", + "opentelemetry-sdk-extension-aws == 2.0.2", + "opentelemetry-propagator-aws-xray == 1.0.1", + "opentelemetry-distro == 0.48b0", + "opentelemetry-propagator-ot-trace == 0.48b0", + "opentelemetry-instrumentation == 0.48b0", + "opentelemetry-instrumentation-aws-lambda == 0.48b0", + "opentelemetry-instrumentation-aio-pika == 0.48b0", + "opentelemetry-instrumentation-aiohttp-client == 0.48b0", + "opentelemetry-instrumentation-aiopg == 0.48b0", + "opentelemetry-instrumentation-asgi == 0.48b0", + "opentelemetry-instrumentation-asyncpg == 0.48b0", + "opentelemetry-instrumentation-boto == 0.48b0", + "opentelemetry-instrumentation-boto3sqs == 0.48b0", + "opentelemetry-instrumentation-botocore == 0.48b0", + "opentelemetry-instrumentation-celery == 0.48b0", + "opentelemetry-instrumentation-confluent-kafka == 0.48b0", + "opentelemetry-instrumentation-dbapi == 0.48b0", + "opentelemetry-instrumentation-django == 0.48b0", + "opentelemetry-instrumentation-elasticsearch == 0.48b0", + "opentelemetry-instrumentation-falcon == 0.48b0", + "opentelemetry-instrumentation-fastapi == 0.48b0", + "opentelemetry-instrumentation-flask == 0.48b0", + "opentelemetry-instrumentation-grpc == 0.48b0", + "opentelemetry-instrumentation-httpx == 0.48b0", + "opentelemetry-instrumentation-jinja2 == 0.48b0", + "opentelemetry-instrumentation-kafka-python == 0.48b0", + "opentelemetry-instrumentation-logging == 0.48b0", + "opentelemetry-instrumentation-mysql == 0.48b0", + "opentelemetry-instrumentation-mysqlclient == 0.48b0", + "opentelemetry-instrumentation-pika == 0.48b0", + "opentelemetry-instrumentation-psycopg2 == 0.48b0", + "opentelemetry-instrumentation-pymemcache == 0.48b0", + "opentelemetry-instrumentation-pymongo == 0.48b0", + "opentelemetry-instrumentation-pymysql == 0.48b0", + "opentelemetry-instrumentation-pyramid == 0.48b0", + "opentelemetry-instrumentation-redis == 0.48b0", + "opentelemetry-instrumentation-remoulade == 0.48b0", + "opentelemetry-instrumentation-requests == 0.48b0", + "opentelemetry-instrumentation-sqlalchemy == 0.48b0", + "opentelemetry-instrumentation-sqlite3 == 0.48b0", + "opentelemetry-instrumentation-starlette == 0.48b0", + "opentelemetry-instrumentation-system-metrics == 0.48b0", + "opentelemetry-instrumentation-tornado == 0.48b0", + "opentelemetry-instrumentation-tortoiseorm == 0.48b0", + "opentelemetry-instrumentation-urllib == 0.48b0", + "opentelemetry-instrumentation-urllib3 == 0.48b0", + "opentelemetry-instrumentation-wsgi == 0.48b0", + "opentelemetry-instrumentation-cassandra == 0.48b0", ] [project.optional-dependencies] -instruments = ["mcp"] +# The 'patch' optional dependency is used for applying patches to specific libraries. +# If a new patch is added into the list, it must also be added into tox.ini, dev-requirements.txt and _instrumentation_patch +patch = [ + "botocore ~= 1.0", +] +test = [] + +[project.entry-points.opentelemetry_configurator] +aws_configurator = "amazon.opentelemetry.distro.aws_opentelemetry_configurator:AwsOpenTelemetryConfigurator" + +[project.entry-points.opentelemetry_distro] +aws_distro = "amazon.opentelemetry.distro.aws_opentelemetry_distro:AwsOpenTelemetryDistro" + +[project.urls] +Homepage = "https://github.com/aws-observability/aws-otel-python-instrumentation/tree/main/aws-opentelemetry-distro" -[project.entry-points.opentelemetry_instrumentor] -mcp = "mcpinstrumentor:MCPInstrumentor" +[tool.hatch.version] +path = "src/amazon/opentelemetry/distro/version.py" [tool.hatch.build.targets.sdist] include = [ - "mcpinstrumentor.py", - "README.md" + "/src", + "/tests", ] [tool.hatch.build.targets.wheel] -packages = ["."] \ No newline at end of file +packages = ["src/amazon"] \ No newline at end of file From fb9d1c70d94c323559eb44cebf4767a8870605b1 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Tue, 29 Jul 2025 13:41:48 -0700 Subject: [PATCH 15/41] pyproject.toml --- aws-opentelemetry-distro/pyproject.toml | 114 ++++++++++++------------ 1 file changed, 58 insertions(+), 56 deletions(-) diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index d4f9ca204..947d0e22f 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "AWS OpenTelemetry Python Distro" readme = "README.rst" license = "Apache-2.0" -requires-python = ">=3.8" +requires-python = ">=3.9" authors = [ { name = "Amazon Web Services" }, ] @@ -18,68 +18,70 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] dependencies = [ - "opentelemetry-api == 1.27.0", - "opentelemetry-sdk == 1.27.0", - "opentelemetry-exporter-otlp-proto-grpc == 1.27.0", - "opentelemetry-exporter-otlp-proto-http == 1.27.0", - "opentelemetry-propagator-b3 == 1.27.0", - "opentelemetry-propagator-jaeger == 1.27.0", - "opentelemetry-exporter-otlp-proto-common == 1.27.0", + "opentelemetry-api == 1.33.1", + "opentelemetry-sdk == 1.33.1", + "opentelemetry-exporter-otlp-proto-grpc == 1.33.1", + "opentelemetry-exporter-otlp-proto-http == 1.33.1", + "opentelemetry-propagator-b3 == 1.33.1", + "opentelemetry-propagator-jaeger == 1.33.1", + "opentelemetry-exporter-otlp-proto-common == 1.33.1", "opentelemetry-sdk-extension-aws == 2.0.2", "opentelemetry-propagator-aws-xray == 1.0.1", - "opentelemetry-distro == 0.48b0", - "opentelemetry-propagator-ot-trace == 0.48b0", - "opentelemetry-instrumentation == 0.48b0", - "opentelemetry-instrumentation-aws-lambda == 0.48b0", - "opentelemetry-instrumentation-aio-pika == 0.48b0", - "opentelemetry-instrumentation-aiohttp-client == 0.48b0", - "opentelemetry-instrumentation-aiopg == 0.48b0", - "opentelemetry-instrumentation-asgi == 0.48b0", - "opentelemetry-instrumentation-asyncpg == 0.48b0", - "opentelemetry-instrumentation-boto == 0.48b0", - "opentelemetry-instrumentation-boto3sqs == 0.48b0", - "opentelemetry-instrumentation-botocore == 0.48b0", - "opentelemetry-instrumentation-celery == 0.48b0", - "opentelemetry-instrumentation-confluent-kafka == 0.48b0", - "opentelemetry-instrumentation-dbapi == 0.48b0", - "opentelemetry-instrumentation-django == 0.48b0", - "opentelemetry-instrumentation-elasticsearch == 0.48b0", - "opentelemetry-instrumentation-falcon == 0.48b0", - "opentelemetry-instrumentation-fastapi == 0.48b0", - "opentelemetry-instrumentation-flask == 0.48b0", - "opentelemetry-instrumentation-grpc == 0.48b0", - "opentelemetry-instrumentation-httpx == 0.48b0", - "opentelemetry-instrumentation-jinja2 == 0.48b0", - "opentelemetry-instrumentation-kafka-python == 0.48b0", - "opentelemetry-instrumentation-logging == 0.48b0", - "opentelemetry-instrumentation-mysql == 0.48b0", - "opentelemetry-instrumentation-mysqlclient == 0.48b0", - "opentelemetry-instrumentation-pika == 0.48b0", - "opentelemetry-instrumentation-psycopg2 == 0.48b0", - "opentelemetry-instrumentation-pymemcache == 0.48b0", - "opentelemetry-instrumentation-pymongo == 0.48b0", - "opentelemetry-instrumentation-pymysql == 0.48b0", - "opentelemetry-instrumentation-pyramid == 0.48b0", - "opentelemetry-instrumentation-redis == 0.48b0", - "opentelemetry-instrumentation-remoulade == 0.48b0", - "opentelemetry-instrumentation-requests == 0.48b0", - "opentelemetry-instrumentation-sqlalchemy == 0.48b0", - "opentelemetry-instrumentation-sqlite3 == 0.48b0", - "opentelemetry-instrumentation-starlette == 0.48b0", - "opentelemetry-instrumentation-system-metrics == 0.48b0", - "opentelemetry-instrumentation-tornado == 0.48b0", - "opentelemetry-instrumentation-tortoiseorm == 0.48b0", - "opentelemetry-instrumentation-urllib == 0.48b0", - "opentelemetry-instrumentation-urllib3 == 0.48b0", - "opentelemetry-instrumentation-wsgi == 0.48b0", - "opentelemetry-instrumentation-cassandra == 0.48b0", + "opentelemetry-distro == 0.54b1", + "opentelemetry-processor-baggage == 0.54b1", + "opentelemetry-propagator-ot-trace == 0.54b1", + "opentelemetry-instrumentation == 0.54b1", + "opentelemetry-instrumentation-aws-lambda == 0.54b1", + "opentelemetry-instrumentation-aio-pika == 0.54b1", + "opentelemetry-instrumentation-aiohttp-client == 0.54b1", + "opentelemetry-instrumentation-aiopg == 0.54b1", + "opentelemetry-instrumentation-asgi == 0.54b1", + "opentelemetry-instrumentation-asyncpg == 0.54b1", + "opentelemetry-instrumentation-boto == 0.54b1", + "opentelemetry-instrumentation-boto3sqs == 0.54b1", + "opentelemetry-instrumentation-botocore == 0.54b1", + "opentelemetry-instrumentation-celery == 0.54b1", + "opentelemetry-instrumentation-confluent-kafka == 0.54b1", + "opentelemetry-instrumentation-dbapi == 0.54b1", + "opentelemetry-instrumentation-django == 0.54b1", + "opentelemetry-instrumentation-elasticsearch == 0.54b1", + "opentelemetry-instrumentation-falcon == 0.54b1", + "opentelemetry-instrumentation-fastapi == 0.54b1", + "opentelemetry-instrumentation-flask == 0.54b1", + "opentelemetry-instrumentation-grpc == 0.54b1", + "opentelemetry-instrumentation-httpx == 0.54b1", + "opentelemetry-instrumentation-jinja2 == 0.54b1", + "opentelemetry-instrumentation-kafka-python == 0.54b1", + "opentelemetry-instrumentation-logging == 0.54b1", + "opentelemetry-instrumentation-mysql == 0.54b1", + "opentelemetry-instrumentation-mysqlclient == 0.54b1", + "opentelemetry-instrumentation-pika == 0.54b1", + "opentelemetry-instrumentation-psycopg2 == 0.54b1", + "opentelemetry-instrumentation-pymemcache == 0.54b1", + "opentelemetry-instrumentation-pymongo == 0.54b1", + "opentelemetry-instrumentation-pymysql == 0.54b1", + "opentelemetry-instrumentation-pyramid == 0.54b1", + "opentelemetry-instrumentation-redis == 0.54b1", + "opentelemetry-instrumentation-remoulade == 0.54b1", + "opentelemetry-instrumentation-requests == 0.54b1", + "opentelemetry-instrumentation-sqlalchemy == 0.54b1", + "opentelemetry-instrumentation-sqlite3 == 0.54b1", + "opentelemetry-instrumentation-starlette == 0.54b1", + "opentelemetry-instrumentation-system-metrics == 0.54b1", + "opentelemetry-instrumentation-tornado == 0.54b1", + "opentelemetry-instrumentation-tortoiseorm == 0.54b1", + "opentelemetry-instrumentation-urllib == 0.54b1", + "opentelemetry-instrumentation-urllib3 == 0.54b1", + "opentelemetry-instrumentation-wsgi == 0.54b1", + "opentelemetry-instrumentation-cassandra == 0.54b1", ] [project.optional-dependencies] @@ -109,4 +111,4 @@ include = [ ] [tool.hatch.build.targets.wheel] -packages = ["src/amazon"] +packages = ["src/amazon"] \ No newline at end of file From 7b9a8f4493d7535743b3bdfe7efa573c6686f2cc Mon Sep 17 00:00:00 2001 From: Johnnyl202 <143136129+Johnnyl202@users.noreply.github.com> Date: Tue, 29 Jul 2025 13:44:59 -0700 Subject: [PATCH 16/41] Update pyproject.toml --- aws-opentelemetry-distro/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index 947d0e22f..414b09221 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -111,4 +111,4 @@ include = [ ] [tool.hatch.build.targets.wheel] -packages = ["src/amazon"] \ No newline at end of file +packages = ["src/amazon"] From de0ec1de9ed87f7e2b5f19d045698d6e97e1984d Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Tue, 29 Jul 2025 14:49:13 -0700 Subject: [PATCH 17/41] Fixed pyproject.toml for mcpinstrumentor --- .../distro/mcpinstrumentor/mcpinstrumentor.py | 3 +- .../distro/mcpinstrumentor/pyproject.toml | 103 ++++-------------- 2 files changed, 22 insertions(+), 84 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index 32c55efd2..747bd7ff2 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -1,6 +1,5 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import logging from typing import Any, Callable, Collection, Dict, Tuple from mcp import ClientRequest @@ -11,6 +10,8 @@ from opentelemetry.instrumentation.utils import unwrap _instruments = ("mcp >= 1.6.0",) + + class MCPInstrumentor(BaseInstrumentor): """ An instrumenter for MCP. diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml index 2fa1f36dd..0ad7a8877 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml @@ -3,14 +3,14 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "aws-opentelemetry-distro" -dynamic = ["version"] -description = "AWS OpenTelemetry Python Distro" -readme = "README.rst" +name = "amazon-opentelemetry-distro-mcpinstrumentor" +version = "0.1.0" +description = "OpenTelemetry MCP instrumentation for AWS Distro" +readme = "README.md" license = "Apache-2.0" -requires-python = ">=3.8" +requires-python = ">=3.9" authors = [ - { name = "Amazon Web Services" }, + { name = "Johnny Lin", email = "jonzilin@amazon.com" }, ] classifiers = [ "Development Status :: 4 - Beta", @@ -18,95 +18,32 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] - dependencies = [ - "opentelemetry-api == 1.27.0", - "opentelemetry-sdk == 1.27.0", - "opentelemetry-exporter-otlp-proto-grpc == 1.27.0", - "opentelemetry-exporter-otlp-proto-http == 1.27.0", - "opentelemetry-propagator-b3 == 1.27.0", - "opentelemetry-propagator-jaeger == 1.27.0", - "opentelemetry-exporter-otlp-proto-common == 1.27.0", - "opentelemetry-sdk-extension-aws == 2.0.2", - "opentelemetry-propagator-aws-xray == 1.0.1", - "opentelemetry-distro == 0.48b0", - "opentelemetry-propagator-ot-trace == 0.48b0", - "opentelemetry-instrumentation == 0.48b0", - "opentelemetry-instrumentation-aws-lambda == 0.48b0", - "opentelemetry-instrumentation-aio-pika == 0.48b0", - "opentelemetry-instrumentation-aiohttp-client == 0.48b0", - "opentelemetry-instrumentation-aiopg == 0.48b0", - "opentelemetry-instrumentation-asgi == 0.48b0", - "opentelemetry-instrumentation-asyncpg == 0.48b0", - "opentelemetry-instrumentation-boto == 0.48b0", - "opentelemetry-instrumentation-boto3sqs == 0.48b0", - "opentelemetry-instrumentation-botocore == 0.48b0", - "opentelemetry-instrumentation-celery == 0.48b0", - "opentelemetry-instrumentation-confluent-kafka == 0.48b0", - "opentelemetry-instrumentation-dbapi == 0.48b0", - "opentelemetry-instrumentation-django == 0.48b0", - "opentelemetry-instrumentation-elasticsearch == 0.48b0", - "opentelemetry-instrumentation-falcon == 0.48b0", - "opentelemetry-instrumentation-fastapi == 0.48b0", - "opentelemetry-instrumentation-flask == 0.48b0", - "opentelemetry-instrumentation-grpc == 0.48b0", - "opentelemetry-instrumentation-httpx == 0.48b0", - "opentelemetry-instrumentation-jinja2 == 0.48b0", - "opentelemetry-instrumentation-kafka-python == 0.48b0", - "opentelemetry-instrumentation-logging == 0.48b0", - "opentelemetry-instrumentation-mysql == 0.48b0", - "opentelemetry-instrumentation-mysqlclient == 0.48b0", - "opentelemetry-instrumentation-pika == 0.48b0", - "opentelemetry-instrumentation-psycopg2 == 0.48b0", - "opentelemetry-instrumentation-pymemcache == 0.48b0", - "opentelemetry-instrumentation-pymongo == 0.48b0", - "opentelemetry-instrumentation-pymysql == 0.48b0", - "opentelemetry-instrumentation-pyramid == 0.48b0", - "opentelemetry-instrumentation-redis == 0.48b0", - "opentelemetry-instrumentation-remoulade == 0.48b0", - "opentelemetry-instrumentation-requests == 0.48b0", - "opentelemetry-instrumentation-sqlalchemy == 0.48b0", - "opentelemetry-instrumentation-sqlite3 == 0.48b0", - "opentelemetry-instrumentation-starlette == 0.48b0", - "opentelemetry-instrumentation-system-metrics == 0.48b0", - "opentelemetry-instrumentation-tornado == 0.48b0", - "opentelemetry-instrumentation-tortoiseorm == 0.48b0", - "opentelemetry-instrumentation-urllib == 0.48b0", - "opentelemetry-instrumentation-urllib3 == 0.48b0", - "opentelemetry-instrumentation-wsgi == 0.48b0", - "opentelemetry-instrumentation-cassandra == 0.48b0", + "mcp", + "opentelemetry-api", + "opentelemetry-instrumentation", + "opentelemetry-semantic-conventions", + "wrapt", + "opentelemetry-sdk", ] [project.optional-dependencies] -# The 'patch' optional dependency is used for applying patches to specific libraries. -# If a new patch is added into the list, it must also be added into tox.ini, dev-requirements.txt and _instrumentation_patch -patch = [ - "botocore ~= 1.0", -] -test = [] - -[project.entry-points.opentelemetry_configurator] -aws_configurator = "amazon.opentelemetry.distro.aws_opentelemetry_configurator:AwsOpenTelemetryConfigurator" - -[project.entry-points.opentelemetry_distro] -aws_distro = "amazon.opentelemetry.distro.aws_opentelemetry_distro:AwsOpenTelemetryDistro" - -[project.urls] -Homepage = "https://github.com/aws-observability/aws-otel-python-instrumentation/tree/main/aws-opentelemetry-distro" +instruments = ["mcp"] -[tool.hatch.version] -path = "src/amazon/opentelemetry/distro/version.py" +[project.entry-points.opentelemetry_instrumentor] +mcp = "mcpinstrumentor:MCPInstrumentor" [tool.hatch.build.targets.sdist] include = [ - "/src", - "/tests", + "mcpinstrumentor.py", + "README.md" ] [tool.hatch.build.targets.wheel] -packages = ["src/amazon"] \ No newline at end of file +packages = ["."] \ No newline at end of file From 43526f66b334f045f2e9b1a90c71e3b47db6ead9 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Tue, 29 Jul 2025 14:54:07 -0700 Subject: [PATCH 18/41] Fixed pyproject.toml for mcpinstrumentor --- .../opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py index 747bd7ff2..f1e5fd5bb 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py @@ -11,7 +11,6 @@ _instruments = ("mcp >= 1.6.0",) - class MCPInstrumentor(BaseInstrumentor): """ An instrumenter for MCP. From bd28ccc7a77c7de700c563440160c966f7732f8d Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Thu, 31 Jul 2025 13:59:11 -0700 Subject: [PATCH 19/41] adde semconv file, client span name changed, new folder name, mcpinstrumentor as an entry point --- aws-opentelemetry-distro/pyproject.toml | 8 +- .../distro/instrumentation/mcp/README.md | 28 ++++ .../distro/instrumentation/mcp/__init__.py | 2 + .../mcp/mcp_instrumentor.py} | 52 +++--- .../distro/instrumentation/mcp/semconv.py | 155 ++++++++++++++++++ .../mcp}/version.py | 0 .../distro/mcpinstrumentor/README.md | 33 ---- .../distro/mcpinstrumentor/pyproject.toml | 49 ------ .../distro/test_mcpinstrumentor.py | 22 ++- .../images/applications/mcp/Dockerfile | 10 +- .../images/applications/mcp/client.py | 2 +- 11 files changed, 234 insertions(+), 127 deletions(-) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py rename aws-opentelemetry-distro/src/amazon/opentelemetry/distro/{mcpinstrumentor/mcpinstrumentor.py => instrumentation/mcp/mcp_instrumentor.py} (79%) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py rename aws-opentelemetry-distro/src/amazon/opentelemetry/distro/{mcpinstrumentor => instrumentation/mcp}/version.py (100%) delete mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md delete mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index 414b09221..df47fda19 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -21,8 +21,6 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", ] dependencies = [ @@ -89,12 +87,16 @@ dependencies = [ # If a new patch is added into the list, it must also be added into tox.ini, dev-requirements.txt and _instrumentation_patch patch = [ "botocore ~= 1.0", + "mcp >= 1.1.0" ] test = [] [project.entry-points.opentelemetry_configurator] aws_configurator = "amazon.opentelemetry.distro.aws_opentelemetry_configurator:AwsOpenTelemetryConfigurator" +[project.entry-points.opentelemetry_instrumentor] +mcp = "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor:MCPInstrumentor" + [project.entry-points.opentelemetry_distro] aws_distro = "amazon.opentelemetry.distro.aws_opentelemetry_distro:AwsOpenTelemetryDistro" @@ -111,4 +113,4 @@ include = [ ] [tool.hatch.build.targets.wheel] -packages = ["src/amazon"] +packages = ["src/amazon"] \ No newline at end of file diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md new file mode 100644 index 000000000..893dc2409 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md @@ -0,0 +1,28 @@ +# MCP Instrumentor + +OpenTelemetry instrumentation for Model Context Protocol (MCP). + +## Installation + +Included in AWS OpenTelemetry Distro: + +```bash +pip install aws-opentelemetry-distro +``` + +## Usage + +Automatically enabled with: + +```bash +opentelemetry-instrument python your_mcp_app.py +``` + +## Configuration + +- `MCP_INSTRUMENTATION_SERVER_NAME`: Override default server name (default: "mcp server") + +## Spans Created + +- **Client**: `client.send_request` +- **Server**: `tools/initialize`, `tools/list`, `tools/{tool_name}` \ No newline at end of file diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py new file mode 100644 index 000000000..04f8b7b76 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py similarity index 79% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py rename to aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py index f1e5fd5bb..0c40968c0 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/mcpinstrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py @@ -8,7 +8,7 @@ from opentelemetry import trace from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap - +from .semconv import MCPAttributes, MCPSpanNames, MCPOperations, MCPTraceContext, MCPEnvironmentVariables _instruments = ("mcp >= 1.6.0",) class MCPInstrumentor(BaseInstrumentor): @@ -27,9 +27,9 @@ def instrumentation_dependencies() -> Collection[str]: def _instrument(self, **kwargs: Any) -> None: tracer_provider = kwargs.get("tracer_provider") if tracer_provider: - self.tracer = tracer_provider.get_tracer("mcp") + self.tracer = tracer_provider.get_tracer("instrumentation.mcp") else: - self.tracer = trace.get_tracer("mcp") + self.tracer = trace.get_tracer("instrumentation.mcp") register_post_import_hook( lambda _: wrap_function_wrapper( "mcp.shared.session", @@ -65,7 +65,7 @@ def _wrap_send_request( """ async def async_wrapper(): - with self.tracer.start_as_current_span("client.send_request", kind=trace.SpanKind.CLIENT) as span: + with self.tracer.start_as_current_span(MCPSpanNames.CLIENT_SEND_REQUEST, kind=trace.SpanKind.CLIENT) as span: span_ctx = span.get_span_context() request = args[0] if len(args) > 0 else kwargs.get("request") if request: @@ -107,7 +107,7 @@ async def _wrap_handle_request( traceparent = None if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: - traceparent = getattr(req.params.meta, "traceparent", None) + traceparent = getattr(req.params.meta, MCPTraceContext.TRACEPARENT_HEADER, None) span_context = self._extract_span_context_from_traceparent(traceparent) if traceparent else None if span_context: span_name = self._get_mcp_operation(req) @@ -125,16 +125,24 @@ async def _wrap_handle_request( def _generate_mcp_attributes(self, span: trace.Span, request: ClientRequest, is_client: bool) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import - operation = "UnknownOperation" + operation = MCPOperations.UNKNOWN_OPERATION + if isinstance(request, types.ListToolsRequest): - operation = "ListTool" - span.set_attribute("mcp.list_tools", True) + operation = MCPOperations.LIST_TOOL + span.set_attribute(MCPAttributes.MCP_LIST_TOOLS, True) + if is_client: + span.update_name(MCPSpanNames.CLIENT_LIST_TOOLS) elif isinstance(request, types.CallToolRequest): operation = request.params.name - span.set_attribute("mcp.call_tool", True) + span.set_attribute(MCPAttributes.MCP_CALL_TOOL, True) + if is_client: + span.update_name(MCPSpanNames.client_call_tool(request.params.name)) elif isinstance(request, types.InitializeRequest): - operation = "Initialize" - span.set_attribute("mcp.initialize", True) + operation = MCPOperations.INITIALIZE + span.set_attribute(MCPAttributes.MCP_INITIALIZE, True) + if is_client: + span.update_name(MCPSpanNames.CLIENT_INITIALIZE) + if is_client: self._add_client_attributes(span, operation, request) else: @@ -148,9 +156,9 @@ def _inject_trace_context(request_data: Dict[str, Any], span_ctx) -> None: request_data["params"]["_meta"] = {} trace_id_hex = f"{span_ctx.trace_id:032x}" span_id_hex = f"{span_ctx.span_id:016x}" - trace_flags = "01" - traceparent = f"00-{trace_id_hex}-{span_id_hex}-{trace_flags}" - request_data["params"]["_meta"]["traceparent"] = traceparent + trace_flags = MCPTraceContext.TRACE_FLAGS_SAMPLED + traceparent = f"{MCPTraceContext.TRACEPARENT_VERSION}-{trace_id_hex}-{span_id_hex}-{trace_flags}" + request_data["params"]["_meta"][MCPTraceContext.TRACEPARENT_HEADER] = traceparent @staticmethod def _extract_span_context_from_traceparent(traceparent: str): @@ -177,24 +185,24 @@ def _get_mcp_operation(req: ClientRequest) -> str: span_name = "unknown" if isinstance(req, types.ListToolsRequest): - span_name = "tools/list" + span_name = MCPSpanNames.TOOLS_LIST elif isinstance(req, types.CallToolRequest): - span_name = f"tools/{req.params.name}" + span_name = MCPSpanNames.tools_call(req.params.name) elif isinstance(req, types.InitializeRequest): - span_name = "tools/initialize" + span_name = MCPSpanNames.TOOLS_INITIALIZE return span_name @staticmethod def _add_client_attributes(span: trace.Span, operation: str, request: ClientRequest) -> None: import os # pylint: disable=import-outside-toplevel - service_name = os.environ.get("MCP_INSTRUMENTATION_SERVER_NAME", "mcp server") - span.set_attribute("aws.remote.service", service_name) - span.set_attribute("aws.remote.operation", operation) + service_name = os.environ.get(MCPEnvironmentVariables.SERVER_NAME, "mcp server") + span.set_attribute(MCPAttributes.AWS_REMOTE_SERVICE, service_name) + span.set_attribute(MCPAttributes.AWS_REMOTE_OPERATION, operation) if hasattr(request, "params") and request.params and hasattr(request.params, "name"): - span.set_attribute("tool.name", request.params.name) + span.set_attribute(MCPAttributes.MCP_TOOL_NAME, request.params.name) @staticmethod def _add_server_attributes(span: trace.Span, operation: str, request: ClientRequest) -> None: if hasattr(request, "params") and request.params and hasattr(request.params, "name"): - span.set_attribute("tool.name", request.params.name) + span.set_attribute(MCPAttributes.MCP_TOOL_NAME, request.params.name) \ No newline at end of file diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py new file mode 100644 index 000000000..a7291c558 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py @@ -0,0 +1,155 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +MCP (Model Context Protocol) Semantic Conventions for OpenTelemetry. + +This module defines semantic conventions for MCP instrumentation following +OpenTelemetry standards for consistent telemetry data. +""" + + +class MCPAttributes: + """MCP-specific span attributes for OpenTelemetry instrumentation.""" + + # MCP Operation Type Attributes + MCP_INITIALIZE = "mcp.initialize" + """ + Boolean attribute indicating this span represents an MCP initialize operation. + Set to True when the span tracks session initialization between client and server. + """ + + MCP_LIST_TOOLS = "mcp.list_tools" + """ + Boolean attribute indicating this span represents an MCP list tools operation. + Set to True when the span tracks discovery of available tools on the server. + """ + + MCP_CALL_TOOL = "mcp.call_tool" + """ + Boolean attribute indicating this span represents an MCP call tool operation. + Set to True when the span tracks execution of a specific tool. + """ + + # MCP Tool Information + MCP_TOOL_NAME = "mcp.tool.name" + """ + The name of the MCP tool being called. + Example: "echo", "search", "calculator" + """ + + # AWS-specific Remote Service Attributes + AWS_REMOTE_SERVICE = "aws.remote.service" + """ + The name of the remote MCP service being called. + Default: "mcp server" (can be overridden via MCP_INSTRUMENTATION_SERVER_NAME env var) + """ + + AWS_REMOTE_OPERATION = "aws.remote.operation" + """ + The specific MCP operation being performed. + Values: "Initialize", "ListTool", or the specific tool name for call operations + """ + + +class MCPSpanNames: + """Standard span names for MCP operations.""" + + # Client-side span names + CLIENT_SEND_REQUEST = "client.send_request" + """ + Span name for client-side MCP request operations. + Used for all outgoing MCP requests (initialize, list tools, call tool). + """ + + CLIENT_INITIALIZE = "mcp.initialize" + """ + Span name for client-side MCP initialization requests. + """ + + CLIENT_LIST_TOOLS = "mcp.list_tools" + """ + Span name for client-side MCP list tools requests. + """ + + @staticmethod + def client_call_tool(tool_name: str) -> str: + """ + Generate span name for client-side MCP tool call requests. + + Args: + tool_name: Name of the tool being called + + Returns: + Formatted span name like "mcp.call_tool.echo", "mcp.call_tool.search" + """ + return f"mcp.call_tool.{tool_name}" + + # Server-side span names + TOOLS_INITIALIZE = "tools/initialize" + """ + Span name for server-side MCP initialization handling. + Tracks server processing of client initialization requests. + """ + + TOOLS_LIST = "tools/list" + """ + Span name for server-side MCP list tools handling. + Tracks server processing of tool discovery requests. + """ + + @staticmethod + def tools_call(tool_name: str) -> str: + """ + Generate span name for server-side MCP tool call handling. + + Args: + tool_name: Name of the tool being called + + Returns: + Formatted span name like "tools/echo", "tools/search" + """ + return f"tools/{tool_name}" + + +class MCPOperations: + """Standard operation names for MCP semantic conventions.""" + + INITIALIZE = "Initialize" + """Operation name for MCP session initialization.""" + + LIST_TOOL = "ListTool" + """Operation name for MCP tool discovery.""" + + UNKNOWN_OPERATION = "UnknownOperation" + """Fallback operation name for unrecognized MCP operations.""" + + +class MCPTraceContext: + """Constants for MCP distributed tracing context propagation.""" + + TRACEPARENT_HEADER = "traceparent" + """ + W3C Trace Context traceparent header name. + Used for propagating trace context in MCP request metadata. + """ + + TRACE_FLAGS_SAMPLED = "01" + """ + W3C Trace Context flags indicating the trace is sampled. + """ + + TRACEPARENT_VERSION = "00" + """ + W3C Trace Context version identifier. + """ + + +class MCPEnvironmentVariables: + """Environment variable names for MCP instrumentation configuration.""" + + SERVER_NAME = "MCP_INSTRUMENTATION_SERVER_NAME" + """ + Environment variable to override the default MCP server name. + Default value: "mcp server" + """ diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/version.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/version.py similarity index 100% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/version.py rename to aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/version.py diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md deleted file mode 100644 index 53628380b..000000000 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# MCP Instrumentor - -OpenTelemetry MCP instrumentation package for AWS Distro. - -## Installation - -```bash -pip install amazon-opentelemetry-distro-mcpinstrumentor -``` - -## Usage - -```python -from mcpinstrumentor import MCPInstrumentor - -MCPInstrumentor().instrument() -``` - -## Configuration - -### Environment Variables - -- `MCP_SERVICE_NAME`: Sets the service name for MCP client spans. Defaults to "Generic MCP Server" if not set. - -```bash -export MCP_SERVICE_NAME="My Custom MCP Server" -``` - -## Features - -- Automatic instrumentation of MCP client and server requests -- Distributed tracing support with trace context propagation -- Configurable service naming via environment variables \ No newline at end of file diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml deleted file mode 100644 index 0ad7a8877..000000000 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor/pyproject.toml +++ /dev/null @@ -1,49 +0,0 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[project] -name = "amazon-opentelemetry-distro-mcpinstrumentor" -version = "0.1.0" -description = "OpenTelemetry MCP instrumentation for AWS Distro" -readme = "README.md" -license = "Apache-2.0" -requires-python = ">=3.9" -authors = [ - { name = "Johnny Lin", email = "jonzilin@amazon.com" }, -] -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", -] -dependencies = [ - "mcp", - "opentelemetry-api", - "opentelemetry-instrumentation", - "opentelemetry-semantic-conventions", - "wrapt", - "opentelemetry-sdk", -] - -[project.optional-dependencies] -instruments = ["mcp"] - -[project.entry-points.opentelemetry_instrumentor] -mcp = "mcpinstrumentor:MCPInstrumentor" - -[tool.hatch.build.targets.sdist] -include = [ - "mcpinstrumentor.py", - "README.md" -] - -[tool.hatch.build.targets.wheel] -packages = ["."] \ No newline at end of file diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index 0bfab1f00..105817248 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -6,18 +6,11 @@ """ import asyncio -import os -import sys import unittest from typing import Any, Dict, List, Optional from unittest.mock import MagicMock -# Add src path for imports -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..")) -src_path = os.path.join(project_root, "src") -sys.path.insert(0, src_path) - -from amazon.opentelemetry.distro.mcpinstrumentor.mcpinstrumentor import MCPInstrumentor # noqa: E402 +from amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor import MCPInstrumentor class SimpleSpanContext: @@ -103,7 +96,9 @@ def setUp(self) -> None: def test_instrument_without_tracer_provider_kwargs(self) -> None: """Test _instrument method when no tracer_provider in kwargs - should use default tracer""" # Execute - Actually test the mcpinstrumentor method - with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer: + with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer, unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" + ): mock_get_tracer.return_value = "default_tracer" self.instrumentor._instrument() @@ -118,7 +113,10 @@ def test_instrument_with_tracer_provider_kwargs(self) -> None: provider = SimpleTracerProvider() # Execute - Actually test the mcpinstrumentor method - self.instrumentor._instrument(tracer_provider=provider) + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" + ): + self.instrumentor._instrument(tracer_provider=provider) # Verify - tracer should be set from the provided tracer_provider self.assertTrue(hasattr(self.instrumentor, "tracer")) @@ -140,8 +138,8 @@ def test_instrumentation_dependencies(self) -> None: # Verify - should return the _instruments collection self.assertIsNotNone(dependencies) - # The dependencies come from openinference.instrumentation.mcp.package._instruments - # which should be a collection + # Should contain mcp dependency + self.assertIn("mcp >= 1.6.0", dependencies) class TestTraceContextInjection(unittest.TestCase): diff --git a/contract-tests/images/applications/mcp/Dockerfile b/contract-tests/images/applications/mcp/Dockerfile index 8c1792b2d..717846edc 100644 --- a/contract-tests/images/applications/mcp/Dockerfile +++ b/contract-tests/images/applications/mcp/Dockerfile @@ -5,17 +5,13 @@ FROM python:3.10 WORKDIR /mcp COPY ./dist/$DISTRO /mcp COPY ./contract-tests/images/applications/mcp /mcp -COPY ./aws-opentelemetry-distro/src/amazon/opentelemetry/distro/mcpinstrumentor /mcp/mcpinstrumentor ENV PIP_ROOT_USER_ACTION=ignore ARG DISTRO -# Install your MCP instrumentor first (use -e for editable install) -RUN pip install --upgrade pip && pip install -e ./mcpinstrumentor/ - -# Then install other requirements and the main distro -RUN pip install -r requirements.txt && pip install ${DISTRO} --force-reinstall +# Install requirements and the main distro (MCP instrumentor is included) +RUN pip install --upgrade pip && pip install -r requirements.txt && pip install ${DISTRO} --force-reinstall RUN opentelemetry-bootstrap -a install # Without `-u`, logs will be buffered and `wait_for_logs` will never return. -CMD ["opentelemetry-instrument", "python3", "-u", "./simple_client.py"] \ No newline at end of file +CMD ["opentelemetry-instrument", "python3", "-u", "./client.py"] \ No newline at end of file diff --git a/contract-tests/images/applications/mcp/client.py b/contract-tests/images/applications/mcp/client.py index b67b04757..9903cdbd7 100644 --- a/contract-tests/images/applications/mcp/client.py +++ b/contract-tests/images/applications/mcp/client.py @@ -29,7 +29,7 @@ async def _call_mcp_tool(self, tool_name, arguments): "OTEL_LOGS_EXPORTER": "none", } server_params = StdioServerParameters( - command="opentelemetry-instrument", args=["python3", "simple_mcp_server.py"], env=server_env + command="opentelemetry-instrument", args=["python3", "mcp_server.py"], env=server_env ) async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: From 70b0da3dea9fabebccdf3e6320b1df934f491273 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Thu, 31 Jul 2025 14:06:57 -0700 Subject: [PATCH 20/41] updated instrumentation README.md,revert some changes in distro/pyproject.toml --- aws-opentelemetry-distro/pyproject.toml | 4 +++- .../opentelemetry/distro/instrumentation/mcp/README.md | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index df47fda19..272c68167 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -21,6 +21,8 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] dependencies = [ @@ -113,4 +115,4 @@ include = [ ] [tool.hatch.build.targets.wheel] -packages = ["src/amazon"] \ No newline at end of file +packages = ["src/amazon"] diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md index 893dc2409..06cf96152 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md @@ -24,5 +24,8 @@ opentelemetry-instrument python your_mcp_app.py ## Spans Created -- **Client**: `client.send_request` +- **Client**: + - Initialize: `mcp.initialize` + - List Tools: `mcp.list_tools` + - Call Tool: `mcp.call_tool.{tool_name}` - **Server**: `tools/initialize`, `tools/list`, `tools/{tool_name}` \ No newline at end of file From 0af52d93fd3c8e17a02d9deaae59899bc703f42d Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Fri, 1 Aug 2025 15:23:30 -0700 Subject: [PATCH 21/41] Fixed span attributes, updated mcp model, lint, semconv changes --- aws-opentelemetry-distro/pyproject.toml | 2 +- .../distro/instrumentation/mcp/constants.py | 38 +++++++++++++ .../instrumentation/mcp/mcp_instrumentor.py | 40 +++++++------ .../distro/instrumentation/mcp/semconv.py | 56 +------------------ .../distro/test_mcpinstrumentor.py | 14 ++--- 5 files changed, 69 insertions(+), 81 deletions(-) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/constants.py diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index 272c68167..f36c66e4b 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -89,7 +89,7 @@ dependencies = [ # If a new patch is added into the list, it must also be added into tox.ini, dev-requirements.txt and _instrumentation_patch patch = [ "botocore ~= 1.0", - "mcp >= 1.1.0" + "mcp >= 1.6.0" ] test = [] diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/constants.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/constants.py new file mode 100644 index 000000000..3ed60baa1 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/constants.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +MCP (Model Context Protocol) Constants for OpenTelemetry instrumentation. + +This module defines constants and configuration variables used by the MCP instrumentor. +""" + + +class MCPTraceContext: + """Constants for MCP distributed tracing context propagation.""" + + TRACEPARENT_HEADER = "traceparent" + """ + W3C Trace Context traceparent header name. + Used for propagating trace context in MCP request metadata. + """ + + TRACE_FLAGS_SAMPLED = "01" + """ + W3C Trace Context flags indicating the trace is sampled. + """ + + TRACEPARENT_VERSION = "00" + """ + W3C Trace Context version identifier. + """ + + +class MCPEnvironmentVariables: + """Environment variable names for MCP instrumentation configuration.""" + + SERVER_NAME = "MCP_INSTRUMENTATION_SERVER_NAME" + """ + Environment variable to override the default MCP server name. + Default value: "mcp server" + """ diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py index 0c40968c0..b12961064 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py @@ -2,14 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any, Callable, Collection, Dict, Tuple -from mcp import ClientRequest from wrapt import register_post_import_hook, wrap_function_wrapper from opentelemetry import trace from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap -from .semconv import MCPAttributes, MCPSpanNames, MCPOperations, MCPTraceContext, MCPEnvironmentVariables -_instruments = ("mcp >= 1.6.0",) +from opentelemetry.semconv.trace import SpanAttributes + +from .constants import MCPEnvironmentVariables, MCPTraceContext +from .semconv import MCPAttributes, MCPOperations, MCPSpanNames + class MCPInstrumentor(BaseInstrumentor): """ @@ -22,7 +24,7 @@ def __init__(self): @staticmethod def instrumentation_dependencies() -> Collection[str]: - return _instruments + return ("mcp >= 1.6.0",) def _instrument(self, **kwargs: Any) -> None: tracer_provider = kwargs.get("tracer_provider") @@ -65,7 +67,9 @@ def _wrap_send_request( """ async def async_wrapper(): - with self.tracer.start_as_current_span(MCPSpanNames.CLIENT_SEND_REQUEST, kind=trace.SpanKind.CLIENT) as span: + with self.tracer.start_as_current_span( + MCPSpanNames.CLIENT_SEND_REQUEST, kind=trace.SpanKind.CLIENT + ) as span: span_ctx = span.get_span_context() request = args[0] if len(args) > 0 else kwargs.get("request") if request: @@ -122,11 +126,12 @@ async def _wrap_handle_request( else: return await wrapped(*args, **kwargs) - def _generate_mcp_attributes(self, span: trace.Span, request: ClientRequest, is_client: bool) -> None: + @staticmethod + def _generate_mcp_attributes(span: trace.Span, request: Any, is_client: bool) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import operation = MCPOperations.UNKNOWN_OPERATION - + if isinstance(request, types.ListToolsRequest): operation = MCPOperations.LIST_TOOL span.set_attribute(MCPAttributes.MCP_LIST_TOOLS, True) @@ -142,11 +147,11 @@ def _generate_mcp_attributes(self, span: trace.Span, request: ClientRequest, is_ span.set_attribute(MCPAttributes.MCP_INITIALIZE, True) if is_client: span.update_name(MCPSpanNames.CLIENT_INITIALIZE) - + if is_client: - self._add_client_attributes(span, operation, request) + MCPInstrumentor._add_client_attributes(span, operation, request) else: - self._add_server_attributes(span, operation, request) + MCPInstrumentor._add_server_attributes(span, operation, request) @staticmethod def _inject_trace_context(request_data: Dict[str, Any], span_ctx) -> None: @@ -179,7 +184,8 @@ def _extract_span_context_from_traceparent(traceparent: str): return None @staticmethod - def _get_mcp_operation(req: ClientRequest) -> str: + def _get_mcp_operation(req: Any) -> str: + import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import span_name = "unknown" @@ -188,21 +194,19 @@ def _get_mcp_operation(req: ClientRequest) -> str: span_name = MCPSpanNames.TOOLS_LIST elif isinstance(req, types.CallToolRequest): span_name = MCPSpanNames.tools_call(req.params.name) - elif isinstance(req, types.InitializeRequest): - span_name = MCPSpanNames.TOOLS_INITIALIZE return span_name @staticmethod - def _add_client_attributes(span: trace.Span, operation: str, request: ClientRequest) -> None: + def _add_client_attributes(span: trace.Span, operation: str, request: Any) -> None: import os # pylint: disable=import-outside-toplevel service_name = os.environ.get(MCPEnvironmentVariables.SERVER_NAME, "mcp server") - span.set_attribute(MCPAttributes.AWS_REMOTE_SERVICE, service_name) - span.set_attribute(MCPAttributes.AWS_REMOTE_OPERATION, operation) + span.set_attribute(SpanAttributes.RPC_SERVICE, service_name) + span.set_attribute(SpanAttributes.RPC_METHOD, operation) if hasattr(request, "params") and request.params and hasattr(request.params, "name"): span.set_attribute(MCPAttributes.MCP_TOOL_NAME, request.params.name) @staticmethod - def _add_server_attributes(span: trace.Span, operation: str, request: ClientRequest) -> None: + def _add_server_attributes(span: trace.Span, operation: str, request: Any) -> None: if hasattr(request, "params") and request.params and hasattr(request.params, "name"): - span.set_attribute(MCPAttributes.MCP_TOOL_NAME, request.params.name) \ No newline at end of file + span.set_attribute(MCPAttributes.MCP_TOOL_NAME, request.params.name) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py index a7291c558..409092c2a 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py @@ -13,7 +13,7 @@ class MCPAttributes: """MCP-specific span attributes for OpenTelemetry instrumentation.""" # MCP Operation Type Attributes - MCP_INITIALIZE = "mcp.initialize" + MCP_INITIALIZE = "notifications/initialize" """ Boolean attribute indicating this span represents an MCP initialize operation. Set to True when the span tracks session initialization between client and server. @@ -38,19 +38,6 @@ class MCPAttributes: Example: "echo", "search", "calculator" """ - # AWS-specific Remote Service Attributes - AWS_REMOTE_SERVICE = "aws.remote.service" - """ - The name of the remote MCP service being called. - Default: "mcp server" (can be overridden via MCP_INSTRUMENTATION_SERVER_NAME env var) - """ - - AWS_REMOTE_OPERATION = "aws.remote.operation" - """ - The specific MCP operation being performed. - Values: "Initialize", "ListTool", or the specific tool name for call operations - """ - class MCPSpanNames: """Standard span names for MCP operations.""" @@ -62,7 +49,7 @@ class MCPSpanNames: Used for all outgoing MCP requests (initialize, list tools, call tool). """ - CLIENT_INITIALIZE = "mcp.initialize" + CLIENT_INITIALIZE = "notifications/initialize" """ Span name for client-side MCP initialization requests. """ @@ -85,13 +72,6 @@ def client_call_tool(tool_name: str) -> str: """ return f"mcp.call_tool.{tool_name}" - # Server-side span names - TOOLS_INITIALIZE = "tools/initialize" - """ - Span name for server-side MCP initialization handling. - Tracks server processing of client initialization requests. - """ - TOOLS_LIST = "tools/list" """ Span name for server-side MCP list tools handling. @@ -115,7 +95,7 @@ def tools_call(tool_name: str) -> str: class MCPOperations: """Standard operation names for MCP semantic conventions.""" - INITIALIZE = "Initialize" + INITIALIZE = "Notifications/Initialize" """Operation name for MCP session initialization.""" LIST_TOOL = "ListTool" @@ -123,33 +103,3 @@ class MCPOperations: UNKNOWN_OPERATION = "UnknownOperation" """Fallback operation name for unrecognized MCP operations.""" - - -class MCPTraceContext: - """Constants for MCP distributed tracing context propagation.""" - - TRACEPARENT_HEADER = "traceparent" - """ - W3C Trace Context traceparent header name. - Used for propagating trace context in MCP request metadata. - """ - - TRACE_FLAGS_SAMPLED = "01" - """ - W3C Trace Context flags indicating the trace is sampled. - """ - - TRACEPARENT_VERSION = "00" - """ - W3C Trace Context version identifier. - """ - - -class MCPEnvironmentVariables: - """Environment variable names for MCP instrumentation configuration.""" - - SERVER_NAME = "MCP_INSTRUMENTATION_SERVER_NAME" - """ - Environment variable to override the default MCP server name. - Default value: "mcp server" - """ diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index 105817248..9b6d1ae5d 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -105,7 +105,7 @@ def test_instrument_without_tracer_provider_kwargs(self) -> None: # Verify - tracer should be set from trace.get_tracer self.assertTrue(hasattr(self.instrumentor, "tracer")) self.assertEqual(self.instrumentor.tracer, "default_tracer") - mock_get_tracer.assert_called_with("mcp") + mock_get_tracer.assert_called_with("instrumentation.mcp") def test_instrument_with_tracer_provider_kwargs(self) -> None: """Test _instrument method when tracer_provider is in kwargs - should use provider's tracer""" @@ -122,7 +122,7 @@ def test_instrument_with_tracer_provider_kwargs(self) -> None: self.assertTrue(hasattr(self.instrumentor, "tracer")) self.assertEqual(self.instrumentor.tracer, "mock_tracer_from_provider") self.assertTrue(provider.get_tracer_called) - self.assertEqual(provider.tracer_name, "mcp") + self.assertEqual(provider.tracer_name, "instrumentation.mcp") class TestInstrumentationDependencies(unittest.TestCase): @@ -249,7 +249,7 @@ def __init__(self, name: str) -> None: # Test server handling without trace context (fallback scenario) with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( - "sys.modules", {"mcp.types": MagicMock()} + "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" ): @@ -384,7 +384,7 @@ async def server_handle_request(self, session: Any, server_request: Any) -> Dict # STEP 1: Client sends request through instrumentation with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( - "sys.modules", {"mcp.types": MagicMock()} + "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): # Override the setup tracer with the properly mocked one self.instrumentor.tracer = mock_tracer @@ -420,7 +420,7 @@ async def server_handle_request(self, session: Any, server_request: Any) -> Dict # Server processes the request it received with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( - "sys.modules", {"mcp.types": MagicMock()} + "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" ): @@ -465,7 +465,3 @@ async def server_handle_request(self, session: Any, server_request: Any) -> Dict any(expected_entry in log_entry for log_entry in e2e_system.communication_log), f"Expected log entry '{expected_entry}' not found in: {e2e_system.communication_log}", ) - - -if __name__ == "__main__": - unittest.main() From 775f0b1cbfd9252e3661e713ee348c04dc09ffd0 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Fri, 1 Aug 2025 18:54:51 -0700 Subject: [PATCH 22/41] lint and test coverage --- .../distro/test_mcpinstrumentor.py | 143 ++++++++++++++++++ .../images/applications/mcp/Dockerfile | 4 +- .../images/applications/mcp/client.py | 5 +- .../tests/test/amazon/mcp/mcp_test.py | 27 ++-- 4 files changed, 161 insertions(+), 18 deletions(-) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index 9b6d1ae5d..de61c517e 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -266,6 +266,7 @@ def __init__(self, name: str) -> None: # Should not create traced spans when no trace context is present mock_tracer.start_as_current_span.assert_not_called() + # pylint: disable=too-many-locals,too-many-statements def test_end_to_end_client_server_communication( self, ) -> None: @@ -465,3 +466,145 @@ async def server_handle_request(self, session: Any, server_request: Any) -> Dict any(expected_entry in log_entry for log_entry in e2e_system.communication_log), f"Expected log entry '{expected_entry}' not found in: {e2e_system.communication_log}", ) + + +class TestMCPInstrumentorEdgeCases(unittest.TestCase): + """Test edge cases and error conditions for MCP instrumentor""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + + def test_invalid_traceparent_format(self) -> None: + """Test handling of malformed traceparent headers""" + invalid_formats = [ + "invalid-format", + "00-invalid-hex-01", + "00-12345-67890", # Missing part + "00-12345-67890-01-extra", # Too many parts + "", # Empty string + ] + + for invalid_format in invalid_formats: + with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): + result = self.instrumentor._extract_span_context_from_traceparent(invalid_format) + self.assertIsNone(result, f"Should return None for invalid format: {invalid_format}") + + def test_version_import(self) -> None: + """Test that version can be imported""" + from amazon.opentelemetry.distro.instrumentation.mcp import version + + self.assertIsNotNone(version) + + def test_constants_import(self) -> None: + """Test that constants can be imported""" + from amazon.opentelemetry.distro.instrumentation.mcp.constants import MCPEnvironmentVariables + + self.assertIsNotNone(MCPEnvironmentVariables.SERVER_NAME) + + def test_add_client_attributes_default_server_name(self) -> None: + """Test _add_client_attributes uses default server name""" + mock_span = MagicMock() + + class MockRequest: + def __init__(self) -> None: + self.params = MockParams() + + class MockParams: + def __init__(self) -> None: + self.name = "test_tool" + + request = MockRequest() + self.instrumentor._add_client_attributes(mock_span, "test_operation", request) + + # Verify default server name is used + mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") + mock_span.set_attribute.assert_any_call("rpc.method", "test_operation") + mock_span.set_attribute.assert_any_call("mcp.tool.name", "test_tool") + + def test_add_client_attributes_without_tool_name(self) -> None: + """Test _add_client_attributes when request has no tool name""" + mock_span = MagicMock() + + class MockRequestNoTool: + def __init__(self) -> None: + self.params = None + + request = MockRequestNoTool() + self.instrumentor._add_client_attributes(mock_span, "test_operation", request) + + # Should still set service and method, but not tool name + mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") + mock_span.set_attribute.assert_any_call("rpc.method", "test_operation") + + def test_add_server_attributes_without_tool_name(self) -> None: + """Test _add_server_attributes when request has no tool name""" + mock_span = MagicMock() + + class MockRequestNoTool: + def __init__(self) -> None: + self.params = None + + request = MockRequestNoTool() + self.instrumentor._add_server_attributes(mock_span, "test_operation", request) + + # Should not set any attributes for server when no tool name + mock_span.set_attribute.assert_not_called() + + def test_inject_trace_context_empty_request(self) -> None: + """Test trace context injection with minimal request data""" + request_data = {} + span_ctx = SimpleSpanContext(trace_id=111, span_id=222) + + self.instrumentor._inject_trace_context(request_data, span_ctx) + + # Should create params and _meta structure + self.assertIn("params", request_data) + self.assertIn("_meta", request_data["params"]) + self.assertIn("traceparent", request_data["params"]["_meta"]) + + # Verify traceparent format + traceparent = request_data["params"]["_meta"]["traceparent"] + parts = traceparent.split("-") + self.assertEqual(len(parts), 4) + self.assertEqual(int(parts[1], 16), 111) # trace_id + self.assertEqual(int(parts[2], 16), 222) # span_id + + def test_uninstrument(self) -> None: + """Test _uninstrument method removes instrumentation""" + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap" + ) as mock_unwrap: + self.instrumentor._uninstrument() + + # Verify both unwrap calls are made + self.assertEqual(mock_unwrap.call_count, 2) + mock_unwrap.assert_any_call("mcp.shared.session", "BaseSession.send_request") + mock_unwrap.assert_any_call("mcp.server.lowlevel.server", "Server._handle_request") + + def test_extract_span_context_valid_traceparent(self) -> None: + """Test _extract_span_context_from_traceparent with valid format""" + # Use correct hex values: 12345 = 0x3039, 67890 = 0x10932 + valid_traceparent = "00-0000000000003039-0000000000010932-01" + result = self.instrumentor._extract_span_context_from_traceparent(valid_traceparent) + + self.assertIsNotNone(result) + self.assertEqual(result.trace_id, 12345) + self.assertEqual(result.span_id, 67890) + self.assertTrue(result.is_remote) + + def test_extract_span_context_value_error(self) -> None: + """Test _extract_span_context_from_traceparent with invalid hex values""" + invalid_hex_traceparent = "00-invalid-hex-values-01" + result = self.instrumentor._extract_span_context_from_traceparent(invalid_hex_traceparent) + + self.assertIsNone(result) + + def test_instrument_method_coverage(self) -> None: + """Test _instrument method registers hooks""" + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" + ) as mock_register: + self.instrumentor._instrument() + + # Should register two hooks + self.assertEqual(mock_register.call_count, 2) diff --git a/contract-tests/images/applications/mcp/Dockerfile b/contract-tests/images/applications/mcp/Dockerfile index 717846edc..e5d77b593 100644 --- a/contract-tests/images/applications/mcp/Dockerfile +++ b/contract-tests/images/applications/mcp/Dockerfile @@ -9,8 +9,8 @@ COPY ./contract-tests/images/applications/mcp /mcp ENV PIP_ROOT_USER_ACTION=ignore ARG DISTRO -# Install requirements and the main distro (MCP instrumentor is included) -RUN pip install --upgrade pip && pip install -r requirements.txt && pip install ${DISTRO} --force-reinstall +# Install requirements and the main distro with patch dependencies (MCP instrumentor is included) +RUN pip install --upgrade pip && pip install -r requirements.txt && pip install ${DISTRO}[patch] --force-reinstall RUN opentelemetry-bootstrap -a install # Without `-u`, logs will be buffered and `wait_for_logs` will never return. diff --git a/contract-tests/images/applications/mcp/client.py b/contract-tests/images/applications/mcp/client.py index 9903cdbd7..cfc0a7f01 100644 --- a/contract-tests/images/applications/mcp/client.py +++ b/contract-tests/images/applications/mcp/client.py @@ -9,7 +9,7 @@ class MCPHandler(BaseHTTPRequestHandler): - def do_GET(self): + def do_GET(self): # pylint: disable=invalid-name if self.path == "/mcp/echo": asyncio.run(self._call_mcp_tool("echo", {"text": "Hello from HTTP request!"})) self.send_response(200) @@ -18,7 +18,8 @@ def do_GET(self): self.send_response(404) self.end_headers() - async def _call_mcp_tool(self, tool_name, arguments): + @staticmethod + async def _call_mcp_tool(tool_name, arguments): server_env = { "OTEL_PYTHON_DISTRO": "aws_distro", "OTEL_PYTHON_CONFIGURATOR": "aws_configurator", diff --git a/contract-tests/tests/test/amazon/mcp/mcp_test.py b/contract-tests/tests/test/amazon/mcp/mcp_test.py index 83642f093..1533c9335 100644 --- a/contract-tests/tests/test/amazon/mcp/mcp_test.py +++ b/contract-tests/tests/test/amazon/mcp/mcp_test.py @@ -22,6 +22,7 @@ def _assert_aws_span_attributes(self, resource_scope_spans, path: str, **kwargs) pass @override + # pylint: disable=too-many-locals,too-many-statements def _assert_semantic_conventions_span_attributes( self, resource_scope_spans, method: str, path: str, status_code: int, **kwargs ) -> None: @@ -36,30 +37,28 @@ def _assert_semantic_conventions_span_attributes( for resource_scope_span in resource_scope_spans: span = resource_scope_span.span - if span.name == "client.send_request" and span.kind == Span.SPAN_KIND_CLIENT: - for attr in span.attributes: - if attr.key == "mcp.initialize" and attr.value.bool_value: - initialize_client_span = span - break - elif attr.key == "mcp.list_tools" and attr.value.bool_value: - list_tools_client_span = span - break - elif attr.key == "mcp.call_tool" and attr.value.bool_value: - call_tool_client_span = span - break - + if span.name == "notifications/initialize" and span.kind == Span.SPAN_KIND_CLIENT: + initialize_client_span = span + elif span.name == "mcp.list_tools" and span.kind == Span.SPAN_KIND_CLIENT: + list_tools_client_span = span + elif span.name == f"mcp.call_tool.{tool_name}" and span.kind == Span.SPAN_KIND_CLIENT: + call_tool_client_span = span elif span.name == "tools/list" and span.kind == Span.SPAN_KIND_SERVER: list_tools_server_span = span elif span.name == f"tools/{tool_name}" and span.kind == Span.SPAN_KIND_SERVER: call_tool_server_span = span + # Validate list tools client span + self.assertIsNotNone(list_tools_client_span, "List tools client span not found") + self.assertEqual(list_tools_client_span.kind, Span.SPAN_KIND_CLIENT) + # Validate initialize client span (no server span expected) self.assertIsNotNone(initialize_client_span, "Initialize client span not found") self.assertEqual(initialize_client_span.kind, Span.SPAN_KIND_CLIENT) init_attributes = {attr.key: attr.value for attr in initialize_client_span.attributes} - self.assertIn("mcp.initialize", init_attributes) - self.assertTrue(init_attributes["mcp.initialize"].bool_value) + self.assertIn("notifications/initialize", init_attributes) + self.assertTrue(init_attributes["notifications/initialize"].bool_value) # Validate list tools client span self.assertIsNotNone(list_tools_client_span, "List tools client span not found") From 176fd90f304286ea14daf9809c6dc6674136a3b2 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Sun, 3 Aug 2025 14:39:11 -0700 Subject: [PATCH 23/41] lint and test coverage --- .../distro/test_mcpinstrumentor.py | 623 +++++++++++++++++- 1 file changed, 619 insertions(+), 4 deletions(-) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index de61c517e..024e25cae 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -10,6 +10,8 @@ from typing import Any, Dict, List, Optional from unittest.mock import MagicMock +from amazon.opentelemetry.distro.instrumentation.mcp import version +from amazon.opentelemetry.distro.instrumentation.mcp.constants import MCPEnvironmentVariables from amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor import MCPInstrumentor @@ -491,14 +493,10 @@ def test_invalid_traceparent_format(self) -> None: def test_version_import(self) -> None: """Test that version can be imported""" - from amazon.opentelemetry.distro.instrumentation.mcp import version - self.assertIsNotNone(version) def test_constants_import(self) -> None: """Test that constants can be imported""" - from amazon.opentelemetry.distro.instrumentation.mcp.constants import MCPEnvironmentVariables - self.assertIsNotNone(MCPEnvironmentVariables.SERVER_NAME) def test_add_client_attributes_default_server_name(self) -> None: @@ -608,3 +606,620 @@ def test_instrument_method_coverage(self) -> None: # Should register two hooks self.assertEqual(mock_register.call_count, 2) + + +class TestGenerateMCPAttributes(unittest.TestCase): + """Test _generate_mcp_attributes method with mocked imports""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + + def test_generate_attributes_with_mock_types(self) -> None: + """Test _generate_mcp_attributes with mocked MCP types""" + mock_span = MagicMock() + + class MockRequest: + def __init__(self): + self.params = MockParams() + + class MockParams: + def __init__(self): + self.name = "test_tool" + + request = MockRequest() + + # Mock the isinstance checks to avoid importing mcp.types + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance" + ) as mock_isinstance: + mock_isinstance.side_effect = lambda obj, cls: False # No matches + + self.instrumentor._generate_mcp_attributes(mock_span, request, True) + + # Should call _add_client_attributes since is_client=True + self.assertTrue(mock_isinstance.called) + + +class TestGetMCPOperation(unittest.TestCase): + """Test _get_mcp_operation method with mocked imports""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + + def test_get_operation_with_mock_types(self) -> None: + """Test _get_mcp_operation with mocked MCP types""" + + class MockRequest: + def __init__(self): + self.params = MockParams() + + class MockParams: + def __init__(self): + self.name = "test_tool" + + request = MockRequest() + + # Mock the isinstance checks to return unknown + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance" + ) as mock_isinstance: + mock_isinstance.side_effect = lambda obj, cls: False # No matches + + result = self.instrumentor._get_mcp_operation(request) + + self.assertEqual(result, "unknown") + + +class TestAddClientAttributes(unittest.TestCase): + """Test _add_client_attributes method""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + + def test_add_client_attributes_with_env_var(self) -> None: + """Test _add_client_attributes with custom server name from environment""" + mock_span = MagicMock() + + class MockRequest: + def __init__(self): + self.params = MockParams() + + class MockParams: + def __init__(self): + self.name = "env_tool" + + with unittest.mock.patch.dict("os.environ", {"MCP_INSTRUMENTATION_SERVER_NAME": "custom server"}): + request = MockRequest() + self.instrumentor._add_client_attributes(mock_span, "test_op", request) + + mock_span.set_attribute.assert_any_call("rpc.service", "custom server") + mock_span.set_attribute.assert_any_call("rpc.method", "test_op") + mock_span.set_attribute.assert_any_call("mcp.tool.name", "env_tool") + + def test_add_client_attributes_no_params(self) -> None: + """Test _add_client_attributes when request has no params""" + mock_span = MagicMock() + + class MockRequestNoParams: + def __init__(self): + self.params = None + + request = MockRequestNoParams() + self.instrumentor._add_client_attributes(mock_span, "test_op", request) + + mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") + mock_span.set_attribute.assert_any_call("rpc.method", "test_op") + # Should not set tool name when no params + + +class TestAddServerAttributes(unittest.TestCase): + """Test _add_server_attributes method""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + + def test_add_server_attributes_with_tool_name(self) -> None: + """Test _add_server_attributes when request has tool name""" + mock_span = MagicMock() + + class MockRequest: + def __init__(self): + self.params = MockParams() + + class MockParams: + def __init__(self): + self.name = "server_tool" + + request = MockRequest() + self.instrumentor._add_server_attributes(mock_span, "test_op", request) + + mock_span.set_attribute.assert_called_once_with("mcp.tool.name", "server_tool") + + def test_add_server_attributes_no_params(self) -> None: + """Test _add_server_attributes when request has no params""" + mock_span = MagicMock() + + class MockRequestNoParams: + def __init__(self): + self.params = None + + request = MockRequestNoParams() + self.instrumentor._add_server_attributes(mock_span, "test_op", request) + + mock_span.set_attribute.assert_not_called() + + +class TestWrapSendRequestEdgeCases(unittest.TestCase): + """Test _wrap_send_request edge cases""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_span_context = MagicMock() + mock_span_context.trace_id = 12345 + mock_span_context.span_id = 67890 + mock_span.get_span_context.return_value = mock_span_context + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + self.instrumentor.tracer = mock_tracer + + def test_wrap_send_request_no_request_in_args(self) -> None: + """Test _wrap_send_request when no request in args""" + + async def mock_wrapped(): + return {"result": "no_request"} + + result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (), {})) + + self.assertEqual(result["result"], "no_request") + + def test_wrap_send_request_request_in_kwargs(self) -> None: + """Test _wrap_send_request when request is in kwargs""" + + class MockRequest: + def __init__(self): + self.root = self + self.params = MockParams() + + def model_dump(self, **kwargs): + return {"method": "test", "params": {"name": "test_tool"}} + + @classmethod + def model_validate(cls, data): + return cls() + + class MockParams: + def __init__(self): + self.name = "test_tool" + + async def mock_wrapped(*args, **kwargs): + return {"result": "kwargs_request"} + + request = MockRequest() + + with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): + result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (), {"request": request})) + + self.assertEqual(result["result"], "kwargs_request") + + def test_wrap_send_request_no_root_attribute(self) -> None: + """Test _wrap_send_request when request has no root attribute""" + + class MockRequestNoRoot: + def __init__(self): + self.params = MockParams() + + def model_dump(self, **kwargs): + return {"method": "test", "params": {"name": "test_tool"}} + + @classmethod + def model_validate(cls, data): + return cls() + + class MockParams: + def __init__(self): + self.name = "test_tool" + + async def mock_wrapped(*args, **kwargs): + return {"result": "no_root"} + + request = MockRequestNoRoot() + + with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): + result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (request,), {})) + + self.assertEqual(result["result"], "no_root") + + +class TestWrapHandleRequestEdgeCases(unittest.TestCase): + """Test _wrap_handle_request edge cases""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + self.instrumentor.tracer = mock_tracer + + def test_wrap_handle_request_no_request(self) -> None: + """Test _wrap_handle_request when no request in args""" + + async def mock_wrapped(*args, **kwargs): + return {"result": "no_request"} + + result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session",), {})) + + self.assertEqual(result["result"], "no_request") + + def test_wrap_handle_request_no_params(self) -> None: + """Test _wrap_handle_request when request has no params""" + + class MockRequestNoParams: + def __init__(self): + self.params = None + + async def mock_wrapped(*args, **kwargs): + return {"result": "no_params"} + + request = MockRequestNoParams() + result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) + + self.assertEqual(result["result"], "no_params") + + def test_wrap_handle_request_no_meta(self) -> None: + """Test _wrap_handle_request when request params has no meta""" + + class MockRequestNoMeta: + def __init__(self): + self.params = MockParamsNoMeta() + + class MockParamsNoMeta: + def __init__(self): + self.meta = None + + async def mock_wrapped(*args, **kwargs): + return {"result": "no_meta"} + + request = MockRequestNoMeta() + result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) + + self.assertEqual(result["result"], "no_meta") + + def test_wrap_handle_request_with_valid_traceparent(self) -> None: + """Test _wrap_handle_request with valid traceparent""" + + class MockRequestWithTrace: + def __init__(self): + self.params = MockParamsWithTrace() + + class MockParamsWithTrace: + def __init__(self): + self.meta = MockMeta() + + class MockMeta: + def __init__(self): + self.traceparent = "00-0000000000003039-0000000000010932-01" + + async def mock_wrapped(*args, **kwargs): + return {"result": "with_trace"} + + request = MockRequestWithTrace() + + with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( + self.instrumentor, "_get_mcp_operation", return_value="tools/test" + ): + result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) + + self.assertEqual(result["result"], "with_trace") + + +class TestInstrumentorStaticMethods(unittest.TestCase): + """Test static methods of MCPInstrumentor""" + + def test_instrumentation_dependencies_static(self) -> None: + """Test instrumentation_dependencies as static method""" + deps = MCPInstrumentor.instrumentation_dependencies() + self.assertIn("mcp >= 1.6.0", deps) + + def test_uninstrument_static(self) -> None: + """Test _uninstrument as static method""" + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap" + ) as mock_unwrap: + MCPInstrumentor._uninstrument() + + self.assertEqual(mock_unwrap.call_count, 2) + mock_unwrap.assert_any_call("mcp.shared.session", "BaseSession.send_request") + mock_unwrap.assert_any_call("mcp.server.lowlevel.server", "Server._handle_request") + + +class TestEnvironmentVariableHandling(unittest.TestCase): + """Test environment variable handling""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + + def test_server_name_environment_variable(self) -> None: + """Test that MCP_INSTRUMENTATION_SERVER_NAME environment variable is used""" + mock_span = MagicMock() + + class MockRequest: + def __init__(self): + self.params = MockParams() + + class MockParams: + def __init__(self): + self.name = "test_tool" + + # Test with environment variable set + with unittest.mock.patch.dict("os.environ", {"MCP_INSTRUMENTATION_SERVER_NAME": "my-custom-server"}): + request = MockRequest() + self.instrumentor._add_client_attributes(mock_span, "test_operation", request) + + mock_span.set_attribute.assert_any_call("rpc.service", "my-custom-server") + + def test_server_name_default_value(self) -> None: + """Test that default server name is used when environment variable is not set""" + mock_span = MagicMock() + + class MockRequest: + def __init__(self): + self.params = MockParams() + + class MockParams: + def __init__(self): + self.name = "test_tool" + + # Test without environment variable + with unittest.mock.patch.dict("os.environ", {}, clear=True): + request = MockRequest() + self.instrumentor._add_client_attributes(mock_span, "test_operation", request) + + mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") + + +class TestTraceContextFormats(unittest.TestCase): + """Test trace context format handling""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + + def test_inject_trace_context_format(self) -> None: + """Test that injected trace context follows W3C format""" + request_data = {} + span_ctx = SimpleSpanContext(trace_id=0x12345678901234567890123456789012, span_id=0x1234567890123456) + + self.instrumentor._inject_trace_context(request_data, span_ctx) + + traceparent = request_data["params"]["_meta"]["traceparent"] + self.assertEqual(traceparent, "00-12345678901234567890123456789012-1234567890123456-01") + + def test_extract_span_context_edge_cases(self) -> None: + """Test _extract_span_context_from_traceparent with various edge cases""" + # Test with valid non-zero values + result = self.instrumentor._extract_span_context_from_traceparent( + "00-12345678901234567890123456789012-1234567890123456-01" + ) + self.assertIsNotNone(result) + self.assertEqual(result.trace_id, 0x12345678901234567890123456789012) + self.assertEqual(result.span_id, 0x1234567890123456) + self.assertTrue(result.is_remote) + + # Test with invalid format + result = self.instrumentor._extract_span_context_from_traceparent("invalid-format") + self.assertIsNone(result) + + # Test with invalid hex values + result = self.instrumentor._extract_span_context_from_traceparent("00-invalid-hex-01") + self.assertIsNone(result) + + +class TestComplexRequestStructures(unittest.TestCase): + """Test handling of complex request structures""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + + def test_inject_trace_context_nested_params(self) -> None: + """Test trace context injection with nested params structure""" + request_data = { + "method": "call_tool", + "params": {"name": "test_tool", "arguments": {"key": "value"}, "existing_meta": "should_be_preserved"}, + } + span_ctx = SimpleSpanContext(trace_id=999, span_id=888) + + self.instrumentor._inject_trace_context(request_data, span_ctx) + + # Verify existing structure is preserved + self.assertEqual(request_data["method"], "call_tool") + self.assertEqual(request_data["params"]["name"], "test_tool") + self.assertEqual(request_data["params"]["arguments"]["key"], "value") + self.assertEqual(request_data["params"]["existing_meta"], "should_be_preserved") + + # Verify trace context is added + self.assertIn("_meta", request_data["params"]) + self.assertIn("traceparent", request_data["params"]["_meta"]) + + traceparent = request_data["params"]["_meta"]["traceparent"] + parts = traceparent.split("-") + self.assertEqual(int(parts[1], 16), 999) + self.assertEqual(int(parts[2], 16), 888) + + +class TestMCPInstrumentorCoverage(unittest.TestCase): + """Additional tests to improve coverage without importing MCP""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + + def test_generate_mcp_attributes_unknown_operation(self) -> None: + """Test _generate_mcp_attributes with unknown request type""" + mock_span = MagicMock() + + class MockUnknownRequest: + pass + + request = MockUnknownRequest() + + # Mock isinstance to return False for all checks + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", return_value=False + ): + self.instrumentor._generate_mcp_attributes(mock_span, request, True) + + # Should call _add_client_attributes with unknown operation + mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") + mock_span.set_attribute.assert_any_call("rpc.method", "UnknownOperation") + + def test_add_client_attributes_no_name_attribute(self) -> None: + """Test _add_client_attributes when request params has no name attribute""" + mock_span = MagicMock() + + class MockRequest: + def __init__(self): + self.params = MockParams() + + class MockParams: + def __init__(self): + pass # No name attribute + + request = MockRequest() + self.instrumentor._add_client_attributes(mock_span, "test_op", request) + + # Should set service and method but not tool name + mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") + mock_span.set_attribute.assert_any_call("rpc.method", "test_op") + # Verify only 2 calls were made (service and method, no tool name) + self.assertEqual(mock_span.set_attribute.call_count, 2) + + def test_add_server_attributes_no_name_attribute(self) -> None: + """Test _add_server_attributes when request params has no name attribute""" + mock_span = MagicMock() + + class MockRequest: + def __init__(self): + self.params = MockParams() + + class MockParams: + def __init__(self): + pass # No name attribute + + request = MockRequest() + self.instrumentor._add_server_attributes(mock_span, "test_op", request) + + # Should not set any attributes when no name + mock_span.set_attribute.assert_not_called() + + def test_extract_span_context_zero_trace_id(self) -> None: + """Test _extract_span_context_from_traceparent with zero trace_id""" + # Zero trace_id should still create a valid span context in OpenTelemetry + result = self.instrumentor._extract_span_context_from_traceparent( + "00-00000000000000000000000000000000-1234567890123456-01" + ) + self.assertIsNotNone(result) + self.assertEqual(result.trace_id, 0) + self.assertEqual(result.span_id, 0x1234567890123456) + + def test_extract_span_context_zero_span_id(self) -> None: + """Test _extract_span_context_from_traceparent with zero span_id""" + # Zero span_id should still create a valid span context in OpenTelemetry + result = self.instrumentor._extract_span_context_from_traceparent( + "00-12345678901234567890123456789012-0000000000000000-01" + ) + self.assertIsNotNone(result) + self.assertEqual(result.trace_id, 0x12345678901234567890123456789012) + self.assertEqual(result.span_id, 0) + + def test_inject_trace_context_existing_meta(self) -> None: + """Test _inject_trace_context when _meta already exists""" + request_data = {"params": {"_meta": {"existing_field": "should_be_preserved"}}} + span_ctx = SimpleSpanContext(trace_id=123, span_id=456) + + self.instrumentor._inject_trace_context(request_data, span_ctx) + + # Should preserve existing _meta fields + self.assertEqual(request_data["params"]["_meta"]["existing_field"], "should_be_preserved") + self.assertIn("traceparent", request_data["params"]["_meta"]) + + def test_wrap_send_request_exception_handling(self) -> None: + """Test _wrap_send_request handles exceptions gracefully""" + + async def mock_wrapped_that_raises(): + raise ValueError("Test exception") + + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_span_context = MagicMock() + mock_span_context.trace_id = 12345 + mock_span_context.span_id = 67890 + mock_span.get_span_context.return_value = mock_span_context + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + self.instrumentor.tracer = mock_tracer + + with self.assertRaises(ValueError): + asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped_that_raises, None, (), {})) + + def test_wrap_handle_request_exception_handling(self) -> None: + """Test _wrap_handle_request handles exceptions gracefully""" + + async def mock_wrapped_that_raises(*args, **kwargs): + raise ValueError("Test exception") + + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + self.instrumentor.tracer = mock_tracer + + with self.assertRaises(ValueError): + asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped_that_raises, None, ("session", None), {})) + + def test_instrumentation_dependencies_return_type(self) -> None: + """Test that instrumentation_dependencies returns a Collection""" + deps = self.instrumentor.instrumentation_dependencies() + self.assertIsInstance(deps, tuple) + self.assertEqual(len(deps), 1) + self.assertEqual(deps[0], "mcp >= 1.6.0") + + def test_instrument_with_none_tracer_provider(self) -> None: + """Test _instrument method when tracer_provider is None""" + with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer, unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" + ): + mock_get_tracer.return_value = "default_tracer" + + self.instrumentor._instrument(tracer_provider=None) + + # Should use default tracer when tracer_provider is None + self.assertEqual(self.instrumentor.tracer, "default_tracer") + mock_get_tracer.assert_called_with("instrumentation.mcp") + + def test_add_client_attributes_missing_params_attribute(self) -> None: + """Test _add_client_attributes when request has no params attribute""" + mock_span = MagicMock() + + class MockRequestNoParams: + pass # No params attribute at all + + request = MockRequestNoParams() + self.instrumentor._add_client_attributes(mock_span, "test_op", request) + + # Should still set service and method + mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") + mock_span.set_attribute.assert_any_call("rpc.method", "test_op") + + def test_add_server_attributes_missing_params_attribute(self) -> None: + """Test _add_server_attributes when request has no params attribute""" + mock_span = MagicMock() + + class MockRequestNoParams: + pass # No params attribute at all + + request = MockRequestNoParams() + self.instrumentor._add_server_attributes(mock_span, "test_op", request) + + # Should not set any attributes + mock_span.set_attribute.assert_not_called() From 0ea969586b30cdb5d63219d1263d58d5460ca183 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Sun, 3 Aug 2025 17:31:57 -0700 Subject: [PATCH 24/41] lint and contract test coverage --- .../distro/test_mcpinstrumentor.py | 901 ++++++------------ 1 file changed, 301 insertions(+), 600 deletions(-) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index 024e25cae..12d95065b 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -7,7 +7,7 @@ import asyncio import unittest -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from unittest.mock import MagicMock from amazon.opentelemetry.distro.instrumentation.mcp import version @@ -214,262 +214,6 @@ def __init__(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> Non self.assertEqual(client_data["params"]["arguments"]["metric_name"], "response_time") -class TestInstrumentedMCPServer(unittest.TestCase): - """Test mcpinstrumentor with a mock MCP server to verify end-to-end functionality""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - # Initialize tracer so the instrumentor can work - mock_tracer = MagicMock() - self.instrumentor.tracer = mock_tracer - - def test_no_trace_context_fallback(self) -> None: - """Test graceful handling when no trace context is present on server side""" - - class MockServerNoTrace: - @staticmethod - async def _handle_request(session: Any, request: Any) -> Dict[str, Any]: - return {"success": True, "handled_without_trace": True} - - class MockServerRequestNoTrace: - def __init__(self, tool_name: str) -> None: - self.params = MockServerRequestParamsNoTrace(tool_name) - - class MockServerRequestParamsNoTrace: - def __init__(self, name: str) -> None: - self.name = name - self.meta: Optional[Any] = None # No trace context - - mock_server = MockServerNoTrace() - server_request = MockServerRequestNoTrace("create_metric") - - # Setup mocks - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span - mock_tracer.start_as_current_span.return_value.__exit__.return_value = None - - # Test server handling without trace context (fallback scenario) - with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( - "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} - ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( - self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" - ): - - result = asyncio.run( - self.instrumentor._wrap_handle_request(mock_server._handle_request, None, (None, server_request), {}) - ) - - # Verify graceful fallback - no tracing spans should be created when no trace context - # The wrapper should call the original function without creating distributed trace spans - self.assertEqual(result["success"], True) - self.assertEqual(result["handled_without_trace"], True) - - # Should not create traced spans when no trace context is present - mock_tracer.start_as_current_span.assert_not_called() - - # pylint: disable=too-many-locals,too-many-statements - def test_end_to_end_client_server_communication( - self, - ) -> None: - """Test where server actually receives what client sends (including injected trace context)""" - - # Create realistic request/response classes - class MCPRequest: - def __init__( - self, tool_name: str, arguments: Optional[Dict[str, Any]] = None, method: str = "call_tool" - ) -> None: - self.root = self - self.params = MCPRequestParams(tool_name, arguments) - self.method = method - - def model_dump( - self, by_alias: bool = True, mode: str = "json", exclude_none: bool = True - ) -> Dict[str, Any]: - result = {"method": self.method, "params": {"name": self.params.name}} - if self.params.arguments: - result["params"]["arguments"] = self.params.arguments - # Include _meta if it exists (for trace context) - if hasattr(self.params, "_meta") and self.params._meta: - result["params"]["_meta"] = self.params._meta - return result - - @classmethod - def model_validate(cls, data: Dict[str, Any]) -> "MCPRequest": - method = data.get("method", "call_tool") - instance = cls(data["params"]["name"], data["params"].get("arguments"), method) - # Restore _meta field if present - if "_meta" in data["params"]: - instance.params._meta = data["params"]["_meta"] - return instance - - class MCPRequestParams: - def __init__(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> None: - self.name = name - self.arguments = arguments - self._meta: Optional[Dict[str, Any]] = None - - class MCPServerRequest: - def __init__(self, client_request_data: Dict[str, Any]) -> None: - """Server request created from client's serialized data""" - self.method = client_request_data.get("method", "call_tool") - self.params = MCPServerRequestParams(client_request_data["params"]) - - class MCPServerRequestParams: - def __init__(self, params_data: Dict[str, Any]) -> None: - self.name = params_data["name"] - self.arguments = params_data.get("arguments") - # Extract traceparent from _meta if present - if "_meta" in params_data and "traceparent" in params_data["_meta"]: - self.meta = MCPServerRequestMeta(params_data["_meta"]["traceparent"]) - else: - self.meta = None - - class MCPServerRequestMeta: - def __init__(self, traceparent: str) -> None: - self.traceparent = traceparent - - # Mock client and server that actually communicate - class EndToEndMCPSystem: - def __init__(self) -> None: - self.communication_log: List[str] = [] - self.last_sent_request: Optional[Any] = None - - async def client_send_request(self, request: Any) -> Dict[str, Any]: - """Client sends request - captures what gets sent""" - self.communication_log.append("CLIENT: Preparing to send request") - self.last_sent_request = request # Capture the modified request - - # Simulate sending over network - serialize the request - serialized_request = request.model_dump() - self.communication_log.append(f"CLIENT: Sent {serialized_request}") - - # Return client response - return {"success": True, "client_response": "Request sent successfully"} - - async def server_handle_request(self, session: Any, server_request: Any) -> Dict[str, Any]: - """Server handles the request it received""" - self.communication_log.append(f"SERVER: Received request for {server_request.params.name}") - - # Check if traceparent was received - if server_request.params.meta and server_request.params.meta.traceparent: - traceparent = server_request.params.meta.traceparent - # Parse traceparent to extract trace_id and span_id - parts = traceparent.split("-") - if len(parts) == 4: - trace_id = int(parts[1], 16) - span_id = int(parts[2], 16) - self.communication_log.append( - f"SERVER: Found trace context - trace_id: {trace_id}, " f"span_id: {span_id}" - ) - else: - self.communication_log.append("SERVER: Invalid traceparent format") - else: - self.communication_log.append("SERVER: No trace context found") - - return {"success": True, "server_response": f"Handled {server_request.params.name}"} - - # Create the end-to-end system - e2e_system = EndToEndMCPSystem() - - # Create original client request - original_request = MCPRequest("create_metric", {"name": "cpu_usage", "value": 85}) - - # Setup OpenTelemetry mocks - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_span_context = MagicMock() - mock_span_context.trace_id = 12345 - mock_span_context.span_id = 67890 - mock_span.get_span_context.return_value = mock_span_context - mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span - mock_tracer.start_as_current_span.return_value.__exit__.return_value = None - - # STEP 1: Client sends request through instrumentation - with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( - "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} - ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): - # Override the setup tracer with the properly mocked one - self.instrumentor.tracer = mock_tracer - - client_result = asyncio.run( - self.instrumentor._wrap_send_request(e2e_system.client_send_request, None, (original_request,), {}) - ) - - # Verify client side worked - self.assertEqual(client_result["success"], True) - self.assertIn("CLIENT: Preparing to send request", e2e_system.communication_log) - - # Get the request that was actually sent (with trace context injected) - sent_request = e2e_system.last_sent_request - sent_request_data = sent_request.model_dump() - - # Verify traceparent was injected by client instrumentation - self.assertIn("_meta", sent_request_data["params"]) - self.assertIn("traceparent", sent_request_data["params"]["_meta"]) - - # Parse and verify traceparent contains correct trace/span IDs - traceparent = sent_request_data["params"]["_meta"]["traceparent"] - parts = traceparent.split("-") - self.assertEqual(int(parts[1], 16), 12345) # trace_id - self.assertEqual(int(parts[2], 16), 67890) # span_id - - # STEP 2: Server receives the EXACT request that client sent - # Create server request from the client's serialized data - server_request = MCPServerRequest(sent_request_data) - - # Reset tracer mock for server side - mock_tracer.reset_mock() - - # Server processes the request it received - with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( - "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} - ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( - self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" - ): - - server_result = asyncio.run( - self.instrumentor._wrap_handle_request( - e2e_system.server_handle_request, None, (None, server_request), {} - ) - ) - - # Verify server side worked - self.assertEqual(server_result["success"], True) - - # Verify end-to-end trace context propagation - self.assertIn("SERVER: Found trace context - trace_id: 12345, span_id: 67890", e2e_system.communication_log) - - # Verify the server received the exact same data the client sent - self.assertEqual(server_request.params.name, "create_metric") - self.assertEqual(server_request.params.arguments["name"], "cpu_usage") - self.assertEqual(server_request.params.arguments["value"], 85) - - # Verify the traceparent made it through end-to-end - self.assertIsNotNone(server_request.params.meta) - self.assertIsNotNone(server_request.params.meta.traceparent) - - # Parse traceparent and verify trace/span IDs - traceparent = server_request.params.meta.traceparent - parts = traceparent.split("-") - self.assertEqual(int(parts[1], 16), 12345) # trace_id - self.assertEqual(int(parts[2], 16), 67890) # span_id - - # Verify complete communication flow - expected_log_entries = [ - "CLIENT: Preparing to send request", - "CLIENT: Sent", # Part of the serialized request log - "SERVER: Received request for create_metric", - "SERVER: Found trace context - trace_id: 12345, span_id: 67890", - ] - - for expected_entry in expected_log_entries: - self.assertTrue( - any(expected_entry in log_entry for log_entry in e2e_system.communication_log), - f"Expected log entry '{expected_entry}' not found in: {e2e_system.communication_log}", - ) - - class TestMCPInstrumentorEdgeCases(unittest.TestCase): """Test edge cases and error conditions for MCP instrumentor""" @@ -670,168 +414,6 @@ def __init__(self): self.assertEqual(result, "unknown") -class TestAddClientAttributes(unittest.TestCase): - """Test _add_client_attributes method""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - - def test_add_client_attributes_with_env_var(self) -> None: - """Test _add_client_attributes with custom server name from environment""" - mock_span = MagicMock() - - class MockRequest: - def __init__(self): - self.params = MockParams() - - class MockParams: - def __init__(self): - self.name = "env_tool" - - with unittest.mock.patch.dict("os.environ", {"MCP_INSTRUMENTATION_SERVER_NAME": "custom server"}): - request = MockRequest() - self.instrumentor._add_client_attributes(mock_span, "test_op", request) - - mock_span.set_attribute.assert_any_call("rpc.service", "custom server") - mock_span.set_attribute.assert_any_call("rpc.method", "test_op") - mock_span.set_attribute.assert_any_call("mcp.tool.name", "env_tool") - - def test_add_client_attributes_no_params(self) -> None: - """Test _add_client_attributes when request has no params""" - mock_span = MagicMock() - - class MockRequestNoParams: - def __init__(self): - self.params = None - - request = MockRequestNoParams() - self.instrumentor._add_client_attributes(mock_span, "test_op", request) - - mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") - mock_span.set_attribute.assert_any_call("rpc.method", "test_op") - # Should not set tool name when no params - - -class TestAddServerAttributes(unittest.TestCase): - """Test _add_server_attributes method""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - - def test_add_server_attributes_with_tool_name(self) -> None: - """Test _add_server_attributes when request has tool name""" - mock_span = MagicMock() - - class MockRequest: - def __init__(self): - self.params = MockParams() - - class MockParams: - def __init__(self): - self.name = "server_tool" - - request = MockRequest() - self.instrumentor._add_server_attributes(mock_span, "test_op", request) - - mock_span.set_attribute.assert_called_once_with("mcp.tool.name", "server_tool") - - def test_add_server_attributes_no_params(self) -> None: - """Test _add_server_attributes when request has no params""" - mock_span = MagicMock() - - class MockRequestNoParams: - def __init__(self): - self.params = None - - request = MockRequestNoParams() - self.instrumentor._add_server_attributes(mock_span, "test_op", request) - - mock_span.set_attribute.assert_not_called() - - -class TestWrapSendRequestEdgeCases(unittest.TestCase): - """Test _wrap_send_request edge cases""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_span_context = MagicMock() - mock_span_context.trace_id = 12345 - mock_span_context.span_id = 67890 - mock_span.get_span_context.return_value = mock_span_context - mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span - mock_tracer.start_as_current_span.return_value.__exit__.return_value = None - self.instrumentor.tracer = mock_tracer - - def test_wrap_send_request_no_request_in_args(self) -> None: - """Test _wrap_send_request when no request in args""" - - async def mock_wrapped(): - return {"result": "no_request"} - - result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (), {})) - - self.assertEqual(result["result"], "no_request") - - def test_wrap_send_request_request_in_kwargs(self) -> None: - """Test _wrap_send_request when request is in kwargs""" - - class MockRequest: - def __init__(self): - self.root = self - self.params = MockParams() - - def model_dump(self, **kwargs): - return {"method": "test", "params": {"name": "test_tool"}} - - @classmethod - def model_validate(cls, data): - return cls() - - class MockParams: - def __init__(self): - self.name = "test_tool" - - async def mock_wrapped(*args, **kwargs): - return {"result": "kwargs_request"} - - request = MockRequest() - - with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): - result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (), {"request": request})) - - self.assertEqual(result["result"], "kwargs_request") - - def test_wrap_send_request_no_root_attribute(self) -> None: - """Test _wrap_send_request when request has no root attribute""" - - class MockRequestNoRoot: - def __init__(self): - self.params = MockParams() - - def model_dump(self, **kwargs): - return {"method": "test", "params": {"name": "test_tool"}} - - @classmethod - def model_validate(cls, data): - return cls() - - class MockParams: - def __init__(self): - self.name = "test_tool" - - async def mock_wrapped(*args, **kwargs): - return {"result": "no_root"} - - request = MockRequestNoRoot() - - with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): - result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (request,), {})) - - self.assertEqual(result["result"], "no_root") - - class TestWrapHandleRequestEdgeCases(unittest.TestCase): """Test _wrap_handle_request edge cases""" @@ -918,203 +500,264 @@ async def mock_wrapped(*args, **kwargs): class TestInstrumentorStaticMethods(unittest.TestCase): """Test static methods of MCPInstrumentor""" - def test_instrumentation_dependencies_static(self) -> None: + @staticmethod + def test_instrumentation_dependencies_static() -> None: """Test instrumentation_dependencies as static method""" deps = MCPInstrumentor.instrumentation_dependencies() - self.assertIn("mcp >= 1.6.0", deps) + assert "mcp >= 1.6.0" in deps - def test_uninstrument_static(self) -> None: + @staticmethod + def test_uninstrument_static() -> None: """Test _uninstrument as static method""" with unittest.mock.patch( "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap" ) as mock_unwrap: MCPInstrumentor._uninstrument() - self.assertEqual(mock_unwrap.call_count, 2) + assert mock_unwrap.call_count == 2 mock_unwrap.assert_any_call("mcp.shared.session", "BaseSession.send_request") mock_unwrap.assert_any_call("mcp.server.lowlevel.server", "Server._handle_request") -class TestEnvironmentVariableHandling(unittest.TestCase): - """Test environment variable handling""" +class TestMCPInstrumentorMissingCoverage(unittest.TestCase): + """Tests targeting specific uncovered lines in MCPInstrumentor""" def setUp(self) -> None: self.instrumentor = MCPInstrumentor() - def test_server_name_environment_variable(self) -> None: - """Test that MCP_INSTRUMENTATION_SERVER_NAME environment variable is used""" + def test_generate_mcp_attributes_list_tools_server_side(self) -> None: + """Test _generate_mcp_attributes for ListToolsRequest on server side""" mock_span = MagicMock() - class MockRequest: - def __init__(self): - self.params = MockParams() + class MockListToolsRequest: + pass - class MockParams: - def __init__(self): - self.name = "test_tool" + request = MockListToolsRequest() - # Test with environment variable set - with unittest.mock.patch.dict("os.environ", {"MCP_INSTRUMENTATION_SERVER_NAME": "my-custom-server"}): - request = MockRequest() - self.instrumentor._add_client_attributes(mock_span, "test_operation", request) + def mock_isinstance(obj, cls): + return cls.__name__ == "ListToolsRequest" - mock_span.set_attribute.assert_any_call("rpc.service", "my-custom-server") + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance + ): + with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): + import sys + + sys.modules["mcp.types"].ListToolsRequest = MockListToolsRequest + + self.instrumentor._generate_mcp_attributes(mock_span, request, False) + + mock_span.set_attribute.assert_called_with("mcp.list_tools", True) + mock_span.update_name.assert_not_called() - def test_server_name_default_value(self) -> None: - """Test that default server name is used when environment variable is not set""" + def test_generate_mcp_attributes_initialize_server_side(self) -> None: + """Test _generate_mcp_attributes for InitializeRequest on server side""" mock_span = MagicMock() - class MockRequest: - def __init__(self): - self.params = MockParams() + class MockInitializeRequest: + pass - class MockParams: - def __init__(self): - self.name = "test_tool" + request = MockInitializeRequest() - # Test without environment variable - with unittest.mock.patch.dict("os.environ", {}, clear=True): - request = MockRequest() - self.instrumentor._add_client_attributes(mock_span, "test_operation", request) + def mock_isinstance(obj, cls): + return cls.__name__ == "InitializeRequest" - mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance + ): + with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): + import sys + sys.modules["mcp.types"].InitializeRequest = MockInitializeRequest -class TestTraceContextFormats(unittest.TestCase): - """Test trace context format handling""" + self.instrumentor._generate_mcp_attributes(mock_span, request, False) - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() + mock_span.set_attribute.assert_called_with("notifications/initialize", True) + mock_span.update_name.assert_not_called() - def test_inject_trace_context_format(self) -> None: - """Test that injected trace context follows W3C format""" - request_data = {} - span_ctx = SimpleSpanContext(trace_id=0x12345678901234567890123456789012, span_id=0x1234567890123456) + def test_generate_mcp_attributes_call_tool_server_side(self) -> None: + """Test _generate_mcp_attributes for CallToolRequest on server side""" + mock_span = MagicMock() - self.instrumentor._inject_trace_context(request_data, span_ctx) + class MockCallToolRequest: + def __init__(self): + self.params = MockParams() - traceparent = request_data["params"]["_meta"]["traceparent"] - self.assertEqual(traceparent, "00-12345678901234567890123456789012-1234567890123456-01") + class MockParams: + def __init__(self): + self.name = "server_tool" - def test_extract_span_context_edge_cases(self) -> None: - """Test _extract_span_context_from_traceparent with various edge cases""" - # Test with valid non-zero values - result = self.instrumentor._extract_span_context_from_traceparent( - "00-12345678901234567890123456789012-1234567890123456-01" - ) - self.assertIsNotNone(result) - self.assertEqual(result.trace_id, 0x12345678901234567890123456789012) - self.assertEqual(result.span_id, 0x1234567890123456) - self.assertTrue(result.is_remote) + request = MockCallToolRequest() - # Test with invalid format - result = self.instrumentor._extract_span_context_from_traceparent("invalid-format") - self.assertIsNone(result) + def mock_isinstance(obj, cls): + return cls.__name__ == "CallToolRequest" - # Test with invalid hex values - result = self.instrumentor._extract_span_context_from_traceparent("00-invalid-hex-01") - self.assertIsNone(result) + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance + ): + with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): + import sys + sys.modules["mcp.types"].CallToolRequest = MockCallToolRequest -class TestComplexRequestStructures(unittest.TestCase): - """Test handling of complex request structures""" + self.instrumentor._generate_mcp_attributes(mock_span, request, False) - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() + # Should set both mcp.call_tool and mcp.tool.name attributes + mock_span.set_attribute.assert_any_call("mcp.call_tool", True) + mock_span.set_attribute.assert_any_call("mcp.tool.name", "server_tool") + mock_span.update_name.assert_not_called() - def test_inject_trace_context_nested_params(self) -> None: - """Test trace context injection with nested params structure""" - request_data = { - "method": "call_tool", - "params": {"name": "test_tool", "arguments": {"key": "value"}, "existing_meta": "should_be_preserved"}, - } - span_ctx = SimpleSpanContext(trace_id=999, span_id=888) + def test_get_mcp_operation_list_tools_request(self) -> None: + """Test _get_mcp_operation for ListToolsRequest""" - self.instrumentor._inject_trace_context(request_data, span_ctx) + class MockListToolsRequest: + pass - # Verify existing structure is preserved - self.assertEqual(request_data["method"], "call_tool") - self.assertEqual(request_data["params"]["name"], "test_tool") - self.assertEqual(request_data["params"]["arguments"]["key"], "value") - self.assertEqual(request_data["params"]["existing_meta"], "should_be_preserved") + request = MockListToolsRequest() - # Verify trace context is added - self.assertIn("_meta", request_data["params"]) - self.assertIn("traceparent", request_data["params"]["_meta"]) + def mock_isinstance(obj, cls): + return cls.__name__ == "ListToolsRequest" - traceparent = request_data["params"]["_meta"]["traceparent"] - parts = traceparent.split("-") - self.assertEqual(int(parts[1], 16), 999) - self.assertEqual(int(parts[2], 16), 888) + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance + ): + with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): + import sys + sys.modules["mcp.types"].ListToolsRequest = MockListToolsRequest -class TestMCPInstrumentorCoverage(unittest.TestCase): - """Additional tests to improve coverage without importing MCP""" + result = self.instrumentor._get_mcp_operation(request) - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() + self.assertEqual(result, "tools/list") - def test_generate_mcp_attributes_unknown_operation(self) -> None: - """Test _generate_mcp_attributes with unknown request type""" - mock_span = MagicMock() + def test_get_mcp_operation_call_tool_request(self) -> None: + """Test _get_mcp_operation for CallToolRequest""" - class MockUnknownRequest: - pass + class MockCallToolRequest: + def __init__(self): + self.params = MockParams() - request = MockUnknownRequest() + class MockParams: + def __init__(self): + self.name = "test_tool" + + request = MockCallToolRequest() + + def mock_isinstance(obj, cls): + return cls.__name__ == "CallToolRequest" - # Mock isinstance to return False for all checks with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", return_value=False + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance ): - self.instrumentor._generate_mcp_attributes(mock_span, request, True) + with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): + import sys + + sys.modules["mcp.types"].CallToolRequest = MockCallToolRequest - # Should call _add_client_attributes with unknown operation - mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") - mock_span.set_attribute.assert_any_call("rpc.method", "UnknownOperation") + result = self.instrumentor._get_mcp_operation(request) - def test_add_client_attributes_no_name_attribute(self) -> None: - """Test _add_client_attributes when request params has no name attribute""" + self.assertEqual(result, "tools/test_tool") + + def test_add_client_attributes_with_params_but_no_name(self) -> None: + """Test _add_client_attributes when params exists but has no name""" mock_span = MagicMock() class MockRequest: def __init__(self): - self.params = MockParams() + self.params = MockParamsNoName() - class MockParams: + class MockParamsNoName: def __init__(self): - pass # No name attribute + self.other_field = "value" request = MockRequest() self.instrumentor._add_client_attributes(mock_span, "test_op", request) - # Should set service and method but not tool name mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") mock_span.set_attribute.assert_any_call("rpc.method", "test_op") - # Verify only 2 calls were made (service and method, no tool name) self.assertEqual(mock_span.set_attribute.call_count, 2) - def test_add_server_attributes_no_name_attribute(self) -> None: - """Test _add_server_attributes when request params has no name attribute""" + def test_add_server_attributes_with_params_but_no_name(self) -> None: + """Test _add_server_attributes when params exists but has no name""" mock_span = MagicMock() class MockRequest: def __init__(self): - self.params = MockParams() + self.params = MockParamsNoName() - class MockParams: + class MockParamsNoName: def __init__(self): - pass # No name attribute + self.other_field = "value" request = MockRequest() self.instrumentor._add_server_attributes(mock_span, "test_op", request) - # Should not set any attributes when no name mock_span.set_attribute.assert_not_called() - def test_extract_span_context_zero_trace_id(self) -> None: - """Test _extract_span_context_from_traceparent with zero trace_id""" - # Zero trace_id should still create a valid span context in OpenTelemetry + def test_extract_span_context_with_wrong_part_count(self) -> None: + """Test _extract_span_context_from_traceparent with wrong number of parts""" + result = self.instrumentor._extract_span_context_from_traceparent("00-12345") + self.assertIsNone(result) + + result = self.instrumentor._extract_span_context_from_traceparent("00-12345-67890-01-extra") + self.assertIsNone(result) + + def test_wrap_handle_request_with_hasattr_false(self) -> None: + """Test _wrap_handle_request when hasattr returns False for meta""" + + class MockRequestNoMetaAttr: + def __init__(self): + self.params = MockParamsNoMetaAttr() + + class MockParamsNoMetaAttr: + def __init__(self): + pass + + async def mock_wrapped(*args, **kwargs): + return {"result": "no_meta_attr"} + + request = MockRequestNoMetaAttr() + result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) + + self.assertEqual(result["result"], "no_meta_attr") + + def test_add_client_attributes_missing_params_attribute(self) -> None: + """Test _add_client_attributes when request has no params attribute""" + mock_span = MagicMock() + + class MockRequestNoParams: + pass + + request = MockRequestNoParams() + self.instrumentor._add_client_attributes(mock_span, "test_op", request) + + mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") + mock_span.set_attribute.assert_any_call("rpc.method", "test_op") + + def test_add_server_attributes_missing_params_attribute(self) -> None: + """Test _add_server_attributes when request has no params attribute""" + mock_span = MagicMock() + + class MockRequestNoParams: + pass + + request = MockRequestNoParams() + self.instrumentor._add_server_attributes(mock_span, "test_op", request) + + mock_span.set_attribute.assert_not_called() + + def test_inject_trace_context_existing_meta(self) -> None: + """Test _inject_trace_context when _meta already exists""" + request_data = {"params": {"_meta": {"existing_field": "should_be_preserved"}}} + span_ctx = SimpleSpanContext(trace_id=123, span_id=456) + + self.instrumentor._inject_trace_context(request_data, span_ctx) + + self.assertEqual(request_data["params"]["_meta"]["existing_field"], "should_be_preserved") + self.assertIn("traceparent", request_data["params"]["_meta"]) + + def test_extract_span_context_zero_values(self) -> None: + """Test _extract_span_context_from_traceparent with zero values""" result = self.instrumentor._extract_span_context_from_traceparent( "00-00000000000000000000000000000000-1234567890123456-01" ) @@ -1122,9 +765,6 @@ def test_extract_span_context_zero_trace_id(self) -> None: self.assertEqual(result.trace_id, 0) self.assertEqual(result.span_id, 0x1234567890123456) - def test_extract_span_context_zero_span_id(self) -> None: - """Test _extract_span_context_from_traceparent with zero span_id""" - # Zero span_id should still create a valid span context in OpenTelemetry result = self.instrumentor._extract_span_context_from_traceparent( "00-12345678901234567890123456789012-0000000000000000-01" ) @@ -1132,94 +772,155 @@ def test_extract_span_context_zero_span_id(self) -> None: self.assertEqual(result.trace_id, 0x12345678901234567890123456789012) self.assertEqual(result.span_id, 0) - def test_inject_trace_context_existing_meta(self) -> None: - """Test _inject_trace_context when _meta already exists""" - request_data = {"params": {"_meta": {"existing_field": "should_be_preserved"}}} - span_ctx = SimpleSpanContext(trace_id=123, span_id=456) - self.instrumentor._inject_trace_context(request_data, span_ctx) +class TestMCPCoverage(unittest.TestCase): + """Essential tests for missing coverage""" - # Should preserve existing _meta fields - self.assertEqual(request_data["params"]["_meta"]["existing_field"], "should_be_preserved") - self.assertIn("traceparent", request_data["params"]["_meta"]) + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() - def test_wrap_send_request_exception_handling(self) -> None: - """Test _wrap_send_request handles exceptions gracefully""" + def test_generate_mcp_attributes_server_side(self) -> None: + """Test server-side MCP attribute generation""" + mock_span = MagicMock() - async def mock_wrapped_that_raises(): - raise ValueError("Test exception") + class MockCallToolRequest: + def __init__(self): + self.params = MockParams() - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_span_context = MagicMock() - mock_span_context.trace_id = 12345 - mock_span_context.span_id = 67890 - mock_span.get_span_context.return_value = mock_span_context - mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span - mock_tracer.start_as_current_span.return_value.__exit__.return_value = None - self.instrumentor.tracer = mock_tracer + class MockParams: + def __init__(self): + self.name = "server_tool" - with self.assertRaises(ValueError): - asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped_that_raises, None, (), {})) + request = MockCallToolRequest() - def test_wrap_handle_request_exception_handling(self) -> None: - """Test _wrap_handle_request handles exceptions gracefully""" + def mock_isinstance(obj, cls): + return cls.__name__ == "CallToolRequest" - async def mock_wrapped_that_raises(*args, **kwargs): - raise ValueError("Test exception") + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance + ): + with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): + import sys - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span - mock_tracer.start_as_current_span.return_value.__exit__.return_value = None - self.instrumentor.tracer = mock_tracer + sys.modules["mcp.types"].CallToolRequest = MockCallToolRequest - with self.assertRaises(ValueError): - asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped_that_raises, None, ("session", None), {})) + self.instrumentor._generate_mcp_attributes(mock_span, request, False) - def test_instrumentation_dependencies_return_type(self) -> None: - """Test that instrumentation_dependencies returns a Collection""" - deps = self.instrumentor.instrumentation_dependencies() - self.assertIsInstance(deps, tuple) - self.assertEqual(len(deps), 1) - self.assertEqual(deps[0], "mcp >= 1.6.0") + # Should set both mcp.call_tool and mcp.tool.name attributes + mock_span.set_attribute.assert_any_call("mcp.call_tool", True) + mock_span.set_attribute.assert_any_call("mcp.tool.name", "server_tool") + mock_span.update_name.assert_not_called() - def test_instrument_with_none_tracer_provider(self) -> None: - """Test _instrument method when tracer_provider is None""" - with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer, unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" + def test_get_mcp_operation_list_tools(self) -> None: + """Test _get_mcp_operation for ListToolsRequest""" + + class MockListToolsRequest: + pass + + request = MockListToolsRequest() + + def mock_isinstance(obj, cls): + return cls.__name__ == "ListToolsRequest" + + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance ): - mock_get_tracer.return_value = "default_tracer" + with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): + import sys - self.instrumentor._instrument(tracer_provider=None) + sys.modules["mcp.types"].ListToolsRequest = MockListToolsRequest - # Should use default tracer when tracer_provider is None - self.assertEqual(self.instrumentor.tracer, "default_tracer") - mock_get_tracer.assert_called_with("instrumentation.mcp") + result = self.instrumentor._get_mcp_operation(request) + self.assertEqual(result, "tools/list") - def test_add_client_attributes_missing_params_attribute(self) -> None: - """Test _add_client_attributes when request has no params attribute""" + def test_get_mcp_operation_call_tool(self) -> None: + """Test _get_mcp_operation for CallToolRequest""" + + class MockCallToolRequest: + def __init__(self): + self.params = MockParams() + + class MockParams: + def __init__(self): + self.name = "test_tool" + + request = MockCallToolRequest() + + def mock_isinstance(obj, cls): + return cls.__name__ == "CallToolRequest" + + with unittest.mock.patch( + "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance + ): + with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): + import sys + + sys.modules["mcp.types"].CallToolRequest = MockCallToolRequest + + result = self.instrumentor._get_mcp_operation(request) + self.assertEqual(result, "tools/test_tool") + + def test_add_attributes_edge_cases(self) -> None: + """Test attribute setting edge cases""" mock_span = MagicMock() + # Test client attributes with no params class MockRequestNoParams: - pass # No params attribute at all + def __init__(self): + self.params = None request = MockRequestNoParams() self.instrumentor._add_client_attributes(mock_span, "test_op", request) - - # Should still set service and method mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") mock_span.set_attribute.assert_any_call("rpc.method", "test_op") - def test_add_server_attributes_missing_params_attribute(self) -> None: - """Test _add_server_attributes when request has no params attribute""" - mock_span = MagicMock() + # Test server attributes with no params + mock_span.reset_mock() + self.instrumentor._add_server_attributes(mock_span, "test_op", request) + mock_span.set_attribute.assert_not_called() - class MockRequestNoParams: - pass # No params attribute at all + def test_extract_span_context_edge_cases(self) -> None: + """Test span context extraction edge cases""" + # Test wrong part count + result = self.instrumentor._extract_span_context_from_traceparent("00-12345") + self.assertIsNone(result) - request = MockRequestNoParams() - self.instrumentor._add_server_attributes(mock_span, "test_op", request) + # Test invalid hex + result = self.instrumentor._extract_span_context_from_traceparent("00-invalid-hex-values-01") + self.assertIsNone(result) - # Should not set any attributes - mock_span.set_attribute.assert_not_called() + # Test zero values (should still work) + result = self.instrumentor._extract_span_context_from_traceparent( + "00-00000000000000000000000000000000-1234567890123456-01" + ) + self.assertIsNotNone(result) + self.assertEqual(result.trace_id, 0) + + def test_wrap_handle_request_no_meta_attr(self) -> None: + """Test _wrap_handle_request when meta attribute doesn't exist""" + + class MockRequestNoMetaAttr: + def __init__(self): + self.params = MockParamsNoMetaAttr() + + class MockParamsNoMetaAttr: + def __init__(self): + pass + + async def mock_wrapped(*args, **kwargs): + return {"result": "no_meta_attr"} + + request = MockRequestNoMetaAttr() + result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) + + self.assertEqual(result["result"], "no_meta_attr") + + def test_inject_trace_context_existing_meta(self) -> None: + """Test trace context injection preserves existing _meta""" + request_data = {"params": {"_meta": {"existing": "preserved"}}} + span_ctx = SimpleSpanContext(trace_id=123, span_id=456) + + self.instrumentor._inject_trace_context(request_data, span_ctx) + + self.assertEqual(request_data["params"]["_meta"]["existing"], "preserved") + self.assertIn("traceparent", request_data["params"]["_meta"]) From e6c479e25c007e6791d407870ebf2742c9ae0b73 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Sun, 3 Aug 2025 20:40:40 -0700 Subject: [PATCH 25/41] test coverage and lint --- .../distro/test_mcpinstrumentor.py | 849 +++++++++--------- 1 file changed, 428 insertions(+), 421 deletions(-) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index 12d95065b..4e66bfbb0 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -6,15 +6,41 @@ """ import asyncio +import sys import unittest -from typing import Any, Dict, Optional -from unittest.mock import MagicMock +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, patch from amazon.opentelemetry.distro.instrumentation.mcp import version from amazon.opentelemetry.distro.instrumentation.mcp.constants import MCPEnvironmentVariables from amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor import MCPInstrumentor +# Mock the mcp module to prevent import errors +class MockMCPTypes: + class CallToolRequest: + pass + + class ListToolsRequest: + pass + + class InitializeRequest: + pass + + class ClientResult: + pass + + +mock_mcp_types = MockMCPTypes() +sys.modules["mcp"] = MagicMock() +sys.modules["mcp.types"] = mock_mcp_types +sys.modules["mcp.shared"] = MagicMock() +sys.modules["mcp.shared.session"] = MagicMock() +sys.modules["mcp.server"] = MagicMock() +sys.modules["mcp.server.lowlevel"] = MagicMock() +sys.modules["mcp.server.lowlevel.server"] = MagicMock() + + class SimpleSpanContext: """Simple mock span context without using MagicMock""" @@ -214,6 +240,262 @@ def __init__(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> Non self.assertEqual(client_data["params"]["arguments"]["metric_name"], "response_time") +class TestInstrumentedMCPServer(unittest.TestCase): + """Test mcpinstrumentor with a mock MCP server to verify end-to-end functionality""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + # Initialize tracer so the instrumentor can work + mock_tracer = MagicMock() + self.instrumentor.tracer = mock_tracer + + def test_no_trace_context_fallback(self) -> None: + """Test graceful handling when no trace context is present on server side""" + + class MockServerNoTrace: + @staticmethod + async def _handle_request(session: Any, request: Any) -> Dict[str, Any]: + return {"success": True, "handled_without_trace": True} + + class MockServerRequestNoTrace: + def __init__(self, tool_name: str) -> None: + self.params = MockServerRequestParamsNoTrace(tool_name) + + class MockServerRequestParamsNoTrace: + def __init__(self, name: str) -> None: + self.name = name + self.meta: Optional[Any] = None # No trace context + + mock_server = MockServerNoTrace() + server_request = MockServerRequestNoTrace("create_metric") + + # Setup mocks + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + + # Test server handling without trace context (fallback scenario) + with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( + "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} + ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( + self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" + ): + + result = asyncio.run( + self.instrumentor._wrap_handle_request(mock_server._handle_request, None, (None, server_request), {}) + ) + + # Verify graceful fallback - no tracing spans should be created when no trace context + # The wrapper should call the original function without creating distributed trace spans + self.assertEqual(result["success"], True) + self.assertEqual(result["handled_without_trace"], True) + + # Should not create traced spans when no trace context is present + mock_tracer.start_as_current_span.assert_not_called() + + # pylint: disable=too-many-locals,too-many-statements + def test_end_to_end_client_server_communication( + self, + ) -> None: + """Test where server actually receives what client sends (including injected trace context)""" + + # Create realistic request/response classes + class MCPRequest: + def __init__( + self, tool_name: str, arguments: Optional[Dict[str, Any]] = None, method: str = "call_tool" + ) -> None: + self.root = self + self.params = MCPRequestParams(tool_name, arguments) + self.method = method + + def model_dump( + self, by_alias: bool = True, mode: str = "json", exclude_none: bool = True + ) -> Dict[str, Any]: + result = {"method": self.method, "params": {"name": self.params.name}} + if self.params.arguments: + result["params"]["arguments"] = self.params.arguments + # Include _meta if it exists (for trace context) + if hasattr(self.params, "_meta") and self.params._meta: + result["params"]["_meta"] = self.params._meta + return result + + @classmethod + def model_validate(cls, data: Dict[str, Any]) -> "MCPRequest": + method = data.get("method", "call_tool") + instance = cls(data["params"]["name"], data["params"].get("arguments"), method) + # Restore _meta field if present + if "_meta" in data["params"]: + instance.params._meta = data["params"]["_meta"] + return instance + + class MCPRequestParams: + def __init__(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> None: + self.name = name + self.arguments = arguments + self._meta: Optional[Dict[str, Any]] = None + + class MCPServerRequest: + def __init__(self, client_request_data: Dict[str, Any]) -> None: + """Server request created from client's serialized data""" + self.method = client_request_data.get("method", "call_tool") + self.params = MCPServerRequestParams(client_request_data["params"]) + + class MCPServerRequestParams: + def __init__(self, params_data: Dict[str, Any]) -> None: + self.name = params_data["name"] + self.arguments = params_data.get("arguments") + # Extract traceparent from _meta if present + if "_meta" in params_data and "traceparent" in params_data["_meta"]: + self.meta = MCPServerRequestMeta(params_data["_meta"]["traceparent"]) + else: + self.meta = None + + class MCPServerRequestMeta: + def __init__(self, traceparent: str) -> None: + self.traceparent = traceparent + + # Mock client and server that actually communicate + class EndToEndMCPSystem: + def __init__(self) -> None: + self.communication_log: List[str] = [] + self.last_sent_request: Optional[Any] = None + + async def client_send_request(self, request: Any) -> Dict[str, Any]: + """Client sends request - captures what gets sent""" + self.communication_log.append("CLIENT: Preparing to send request") + self.last_sent_request = request # Capture the modified request + + # Simulate sending over network - serialize the request + serialized_request = request.model_dump() + self.communication_log.append(f"CLIENT: Sent {serialized_request}") + + # Return client response + return {"success": True, "client_response": "Request sent successfully"} + + async def server_handle_request(self, session: Any, server_request: Any) -> Dict[str, Any]: + """Server handles the request it received""" + self.communication_log.append(f"SERVER: Received request for {server_request.params.name}") + + # Check if traceparent was received + if server_request.params.meta and server_request.params.meta.traceparent: + traceparent = server_request.params.meta.traceparent + # Parse traceparent to extract trace_id and span_id + parts = traceparent.split("-") + if len(parts) == 4: + trace_id = int(parts[1], 16) + span_id = int(parts[2], 16) + self.communication_log.append( + f"SERVER: Found trace context - trace_id: {trace_id}, " f"span_id: {span_id}" + ) + else: + self.communication_log.append("SERVER: Invalid traceparent format") + else: + self.communication_log.append("SERVER: No trace context found") + + return {"success": True, "server_response": f"Handled {server_request.params.name}"} + + # Create the end-to-end system + e2e_system = EndToEndMCPSystem() + + # Create original client request + original_request = MCPRequest("create_metric", {"name": "cpu_usage", "value": 85}) + + # Setup OpenTelemetry mocks + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_span_context = MagicMock() + mock_span_context.trace_id = 12345 + mock_span_context.span_id = 67890 + mock_span.get_span_context.return_value = mock_span_context + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + + # STEP 1: Client sends request through instrumentation + with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( + "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} + ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): + # Override the setup tracer with the properly mocked one + self.instrumentor.tracer = mock_tracer + + client_result = asyncio.run( + self.instrumentor._wrap_send_request(e2e_system.client_send_request, None, (original_request,), {}) + ) + + # Verify client side worked + self.assertEqual(client_result["success"], True) + self.assertIn("CLIENT: Preparing to send request", e2e_system.communication_log) + + # Get the request that was actually sent (with trace context injected) + sent_request = e2e_system.last_sent_request + sent_request_data = sent_request.model_dump() + + # Verify traceparent was injected by client instrumentation + self.assertIn("_meta", sent_request_data["params"]) + self.assertIn("traceparent", sent_request_data["params"]["_meta"]) + + # Parse and verify traceparent contains correct trace/span IDs + traceparent = sent_request_data["params"]["_meta"]["traceparent"] + parts = traceparent.split("-") + self.assertEqual(int(parts[1], 16), 12345) # trace_id + self.assertEqual(int(parts[2], 16), 67890) # span_id + + # STEP 2: Server receives the EXACT request that client sent + # Create server request from the client's serialized data + server_request = MCPServerRequest(sent_request_data) + + # Reset tracer mock for server side + mock_tracer.reset_mock() + + # Server processes the request it received + with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( + "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} + ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( + self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" + ): + + server_result = asyncio.run( + self.instrumentor._wrap_handle_request( + e2e_system.server_handle_request, None, (None, server_request), {} + ) + ) + + # Verify server side worked + self.assertEqual(server_result["success"], True) + + # Verify end-to-end trace context propagation + self.assertIn("SERVER: Found trace context - trace_id: 12345, span_id: 67890", e2e_system.communication_log) + + # Verify the server received the exact same data the client sent + self.assertEqual(server_request.params.name, "create_metric") + self.assertEqual(server_request.params.arguments["name"], "cpu_usage") + self.assertEqual(server_request.params.arguments["value"], 85) + + # Verify the traceparent made it through end-to-end + self.assertIsNotNone(server_request.params.meta) + self.assertIsNotNone(server_request.params.meta.traceparent) + + # Parse traceparent and verify trace/span IDs + traceparent = server_request.params.meta.traceparent + parts = traceparent.split("-") + self.assertEqual(int(parts[1], 16), 12345) # trace_id + self.assertEqual(int(parts[2], 16), 67890) # span_id + + # Verify complete communication flow + expected_log_entries = [ + "CLIENT: Preparing to send request", + "CLIENT: Sent", # Part of the serialized request log + "SERVER: Received request for create_metric", + "SERVER: Found trace context - trace_id: 12345, span_id: 67890", + ] + + for expected_entry in expected_log_entries: + self.assertTrue( + any(expected_entry in log_entry for log_entry in e2e_system.communication_log), + f"Expected log entry '{expected_entry}' not found in: {e2e_system.communication_log}", + ) + + class TestMCPInstrumentorEdgeCases(unittest.TestCase): """Test edge cases and error conditions for MCP instrumentor""" @@ -317,8 +599,6 @@ def test_uninstrument(self) -> None: "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap" ) as mock_unwrap: self.instrumentor._uninstrument() - - # Verify both unwrap calls are made self.assertEqual(mock_unwrap.call_count, 2) mock_unwrap.assert_any_call("mcp.shared.session", "BaseSession.send_request") mock_unwrap.assert_any_call("mcp.server.lowlevel.server", "Server._handle_request") @@ -328,7 +608,6 @@ def test_extract_span_context_valid_traceparent(self) -> None: # Use correct hex values: 12345 = 0x3039, 67890 = 0x10932 valid_traceparent = "00-0000000000003039-0000000000010932-01" result = self.instrumentor._extract_span_context_from_traceparent(valid_traceparent) - self.assertIsNotNone(result) self.assertEqual(result.trace_id, 12345) self.assertEqual(result.span_id, 67890) @@ -338,7 +617,6 @@ def test_extract_span_context_value_error(self) -> None: """Test _extract_span_context_from_traceparent with invalid hex values""" invalid_hex_traceparent = "00-invalid-hex-values-01" result = self.instrumentor._extract_span_context_from_traceparent(invalid_hex_traceparent) - self.assertIsNone(result) def test_instrument_method_coverage(self) -> None: @@ -347,71 +625,87 @@ def test_instrument_method_coverage(self) -> None: "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" ) as mock_register: self.instrumentor._instrument() - - # Should register two hooks self.assertEqual(mock_register.call_count, 2) -class TestGenerateMCPAttributes(unittest.TestCase): - """Test _generate_mcp_attributes method with mocked imports""" +class TestWrapSendRequestEdgeCases(unittest.TestCase): + """Test _wrap_send_request edge cases""" def setUp(self) -> None: self.instrumentor = MCPInstrumentor() - - def test_generate_attributes_with_mock_types(self) -> None: - """Test _generate_mcp_attributes with mocked MCP types""" + mock_tracer = MagicMock() mock_span = MagicMock() + mock_span_context = MagicMock() + mock_span_context.trace_id = 12345 + mock_span_context.span_id = 67890 + mock_span.get_span_context.return_value = mock_span_context + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + self.instrumentor.tracer = mock_tracer + + def test_wrap_send_request_no_request_in_args(self) -> None: + """Test _wrap_send_request when no request in args""" + + async def mock_wrapped(): + return {"result": "no_request"} + + result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (), {})) + self.assertEqual(result["result"], "no_request") + + def test_wrap_send_request_request_in_kwargs(self) -> None: + """Test _wrap_send_request when request is in kwargs""" class MockRequest: def __init__(self): + self.root = self self.params = MockParams() + def model_dump(self, **kwargs): + return {"method": "test", "params": {"name": "test_tool"}} + + @classmethod + def model_validate(cls, data): + return cls() + class MockParams: def __init__(self): self.name = "test_tool" - request = MockRequest() - - # Mock the isinstance checks to avoid importing mcp.types - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance" - ) as mock_isinstance: - mock_isinstance.side_effect = lambda obj, cls: False # No matches - - self.instrumentor._generate_mcp_attributes(mock_span, request, True) - - # Should call _add_client_attributes since is_client=True - self.assertTrue(mock_isinstance.called) - + async def mock_wrapped(*args, **kwargs): + return {"result": "kwargs_request"} -class TestGetMCPOperation(unittest.TestCase): - """Test _get_mcp_operation method with mocked imports""" + request = MockRequest() - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() + with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): + result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (), {"request": request})) + self.assertEqual(result["result"], "kwargs_request") - def test_get_operation_with_mock_types(self) -> None: - """Test _get_mcp_operation with mocked MCP types""" + def test_wrap_send_request_no_root_attribute(self) -> None: + """Test _wrap_send_request when request has no root attribute""" - class MockRequest: + class MockRequestNoRoot: def __init__(self): self.params = MockParams() + def model_dump(self, **kwargs): + return {"method": "test", "params": {"name": "test_tool"}} + + @classmethod + def model_validate(cls, data): + return cls() + class MockParams: def __init__(self): self.name = "test_tool" - request = MockRequest() - - # Mock the isinstance checks to return unknown - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance" - ) as mock_isinstance: - mock_isinstance.side_effect = lambda obj, cls: False # No matches + async def mock_wrapped(*args, **kwargs): + return {"result": "no_root"} - result = self.instrumentor._get_mcp_operation(request) + request = MockRequestNoRoot() - self.assertEqual(result, "unknown") + with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): + result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (request,), {})) + self.assertEqual(result["result"], "no_root") class TestWrapHandleRequestEdgeCases(unittest.TestCase): @@ -432,7 +726,6 @@ async def mock_wrapped(*args, **kwargs): return {"result": "no_request"} result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session",), {})) - self.assertEqual(result["result"], "no_request") def test_wrap_handle_request_no_params(self) -> None: @@ -447,7 +740,6 @@ async def mock_wrapped(*args, **kwargs): request = MockRequestNoParams() result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) - self.assertEqual(result["result"], "no_params") def test_wrap_handle_request_no_meta(self) -> None: @@ -466,7 +758,6 @@ async def mock_wrapped(*args, **kwargs): request = MockRequestNoMeta() result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) - self.assertEqual(result["result"], "no_meta") def test_wrap_handle_request_with_valid_traceparent(self) -> None: @@ -493,147 +784,39 @@ async def mock_wrapped(*args, **kwargs): self.instrumentor, "_get_mcp_operation", return_value="tools/test" ): result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) - self.assertEqual(result["result"], "with_trace") class TestInstrumentorStaticMethods(unittest.TestCase): """Test static methods of MCPInstrumentor""" - @staticmethod - def test_instrumentation_dependencies_static() -> None: + def test_instrumentation_dependencies_static(self) -> None: """Test instrumentation_dependencies as static method""" deps = MCPInstrumentor.instrumentation_dependencies() - assert "mcp >= 1.6.0" in deps + self.assertIn("mcp >= 1.6.0", deps) - @staticmethod - def test_uninstrument_static() -> None: + def test_uninstrument_static(self) -> None: """Test _uninstrument as static method""" with unittest.mock.patch( "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap" ) as mock_unwrap: MCPInstrumentor._uninstrument() - - assert mock_unwrap.call_count == 2 + self.assertEqual(mock_unwrap.call_count, 2) mock_unwrap.assert_any_call("mcp.shared.session", "BaseSession.send_request") mock_unwrap.assert_any_call("mcp.server.lowlevel.server", "Server._handle_request") -class TestMCPInstrumentorMissingCoverage(unittest.TestCase): - """Tests targeting specific uncovered lines in MCPInstrumentor""" +class TestEnvironmentVariableHandling(unittest.TestCase): + """Test environment variable handling""" def setUp(self) -> None: self.instrumentor = MCPInstrumentor() - def test_generate_mcp_attributes_list_tools_server_side(self) -> None: - """Test _generate_mcp_attributes for ListToolsRequest on server side""" - mock_span = MagicMock() - - class MockListToolsRequest: - pass - - request = MockListToolsRequest() - - def mock_isinstance(obj, cls): - return cls.__name__ == "ListToolsRequest" - - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance - ): - with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): - import sys - - sys.modules["mcp.types"].ListToolsRequest = MockListToolsRequest - - self.instrumentor._generate_mcp_attributes(mock_span, request, False) - - mock_span.set_attribute.assert_called_with("mcp.list_tools", True) - mock_span.update_name.assert_not_called() - - def test_generate_mcp_attributes_initialize_server_side(self) -> None: - """Test _generate_mcp_attributes for InitializeRequest on server side""" - mock_span = MagicMock() - - class MockInitializeRequest: - pass - - request = MockInitializeRequest() - - def mock_isinstance(obj, cls): - return cls.__name__ == "InitializeRequest" - - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance - ): - with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): - import sys - - sys.modules["mcp.types"].InitializeRequest = MockInitializeRequest - - self.instrumentor._generate_mcp_attributes(mock_span, request, False) - - mock_span.set_attribute.assert_called_with("notifications/initialize", True) - mock_span.update_name.assert_not_called() - - def test_generate_mcp_attributes_call_tool_server_side(self) -> None: - """Test _generate_mcp_attributes for CallToolRequest on server side""" + def test_server_name_environment_variable(self) -> None: + """Test that MCP_INSTRUMENTATION_SERVER_NAME environment variable is used""" mock_span = MagicMock() - class MockCallToolRequest: - def __init__(self): - self.params = MockParams() - - class MockParams: - def __init__(self): - self.name = "server_tool" - - request = MockCallToolRequest() - - def mock_isinstance(obj, cls): - return cls.__name__ == "CallToolRequest" - - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance - ): - with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): - import sys - - sys.modules["mcp.types"].CallToolRequest = MockCallToolRequest - - self.instrumentor._generate_mcp_attributes(mock_span, request, False) - - # Should set both mcp.call_tool and mcp.tool.name attributes - mock_span.set_attribute.assert_any_call("mcp.call_tool", True) - mock_span.set_attribute.assert_any_call("mcp.tool.name", "server_tool") - mock_span.update_name.assert_not_called() - - def test_get_mcp_operation_list_tools_request(self) -> None: - """Test _get_mcp_operation for ListToolsRequest""" - - class MockListToolsRequest: - pass - - request = MockListToolsRequest() - - def mock_isinstance(obj, cls): - return cls.__name__ == "ListToolsRequest" - - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance - ): - with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): - import sys - - sys.modules["mcp.types"].ListToolsRequest = MockListToolsRequest - - result = self.instrumentor._get_mcp_operation(request) - - self.assertEqual(result, "tools/list") - - def test_get_mcp_operation_call_tool_request(self) -> None: - """Test _get_mcp_operation for CallToolRequest""" - - class MockCallToolRequest: + class MockRequest: def __init__(self): self.params = MockParams() @@ -641,202 +824,17 @@ class MockParams: def __init__(self): self.name = "test_tool" - request = MockCallToolRequest() - - def mock_isinstance(obj, cls): - return cls.__name__ == "CallToolRequest" + # Test with environment variable set + with unittest.mock.patch.dict("os.environ", {"MCP_INSTRUMENTATION_SERVER_NAME": "my-custom-server"}): + request = MockRequest() + self.instrumentor._add_client_attributes(mock_span, "test_operation", request) + mock_span.set_attribute.assert_any_call("rpc.service", "my-custom-server") - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance - ): - with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): - import sys - - sys.modules["mcp.types"].CallToolRequest = MockCallToolRequest - - result = self.instrumentor._get_mcp_operation(request) - - self.assertEqual(result, "tools/test_tool") - - def test_add_client_attributes_with_params_but_no_name(self) -> None: - """Test _add_client_attributes when params exists but has no name""" + def test_server_name_default_value(self) -> None: + """Test that default server name is used when environment variable is not set""" mock_span = MagicMock() class MockRequest: - def __init__(self): - self.params = MockParamsNoName() - - class MockParamsNoName: - def __init__(self): - self.other_field = "value" - - request = MockRequest() - self.instrumentor._add_client_attributes(mock_span, "test_op", request) - - mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") - mock_span.set_attribute.assert_any_call("rpc.method", "test_op") - self.assertEqual(mock_span.set_attribute.call_count, 2) - - def test_add_server_attributes_with_params_but_no_name(self) -> None: - """Test _add_server_attributes when params exists but has no name""" - mock_span = MagicMock() - - class MockRequest: - def __init__(self): - self.params = MockParamsNoName() - - class MockParamsNoName: - def __init__(self): - self.other_field = "value" - - request = MockRequest() - self.instrumentor._add_server_attributes(mock_span, "test_op", request) - - mock_span.set_attribute.assert_not_called() - - def test_extract_span_context_with_wrong_part_count(self) -> None: - """Test _extract_span_context_from_traceparent with wrong number of parts""" - result = self.instrumentor._extract_span_context_from_traceparent("00-12345") - self.assertIsNone(result) - - result = self.instrumentor._extract_span_context_from_traceparent("00-12345-67890-01-extra") - self.assertIsNone(result) - - def test_wrap_handle_request_with_hasattr_false(self) -> None: - """Test _wrap_handle_request when hasattr returns False for meta""" - - class MockRequestNoMetaAttr: - def __init__(self): - self.params = MockParamsNoMetaAttr() - - class MockParamsNoMetaAttr: - def __init__(self): - pass - - async def mock_wrapped(*args, **kwargs): - return {"result": "no_meta_attr"} - - request = MockRequestNoMetaAttr() - result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) - - self.assertEqual(result["result"], "no_meta_attr") - - def test_add_client_attributes_missing_params_attribute(self) -> None: - """Test _add_client_attributes when request has no params attribute""" - mock_span = MagicMock() - - class MockRequestNoParams: - pass - - request = MockRequestNoParams() - self.instrumentor._add_client_attributes(mock_span, "test_op", request) - - mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") - mock_span.set_attribute.assert_any_call("rpc.method", "test_op") - - def test_add_server_attributes_missing_params_attribute(self) -> None: - """Test _add_server_attributes when request has no params attribute""" - mock_span = MagicMock() - - class MockRequestNoParams: - pass - - request = MockRequestNoParams() - self.instrumentor._add_server_attributes(mock_span, "test_op", request) - - mock_span.set_attribute.assert_not_called() - - def test_inject_trace_context_existing_meta(self) -> None: - """Test _inject_trace_context when _meta already exists""" - request_data = {"params": {"_meta": {"existing_field": "should_be_preserved"}}} - span_ctx = SimpleSpanContext(trace_id=123, span_id=456) - - self.instrumentor._inject_trace_context(request_data, span_ctx) - - self.assertEqual(request_data["params"]["_meta"]["existing_field"], "should_be_preserved") - self.assertIn("traceparent", request_data["params"]["_meta"]) - - def test_extract_span_context_zero_values(self) -> None: - """Test _extract_span_context_from_traceparent with zero values""" - result = self.instrumentor._extract_span_context_from_traceparent( - "00-00000000000000000000000000000000-1234567890123456-01" - ) - self.assertIsNotNone(result) - self.assertEqual(result.trace_id, 0) - self.assertEqual(result.span_id, 0x1234567890123456) - - result = self.instrumentor._extract_span_context_from_traceparent( - "00-12345678901234567890123456789012-0000000000000000-01" - ) - self.assertIsNotNone(result) - self.assertEqual(result.trace_id, 0x12345678901234567890123456789012) - self.assertEqual(result.span_id, 0) - - -class TestMCPCoverage(unittest.TestCase): - """Essential tests for missing coverage""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - - def test_generate_mcp_attributes_server_side(self) -> None: - """Test server-side MCP attribute generation""" - mock_span = MagicMock() - - class MockCallToolRequest: - def __init__(self): - self.params = MockParams() - - class MockParams: - def __init__(self): - self.name = "server_tool" - - request = MockCallToolRequest() - - def mock_isinstance(obj, cls): - return cls.__name__ == "CallToolRequest" - - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance - ): - with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): - import sys - - sys.modules["mcp.types"].CallToolRequest = MockCallToolRequest - - self.instrumentor._generate_mcp_attributes(mock_span, request, False) - - # Should set both mcp.call_tool and mcp.tool.name attributes - mock_span.set_attribute.assert_any_call("mcp.call_tool", True) - mock_span.set_attribute.assert_any_call("mcp.tool.name", "server_tool") - mock_span.update_name.assert_not_called() - - def test_get_mcp_operation_list_tools(self) -> None: - """Test _get_mcp_operation for ListToolsRequest""" - - class MockListToolsRequest: - pass - - request = MockListToolsRequest() - - def mock_isinstance(obj, cls): - return cls.__name__ == "ListToolsRequest" - - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance - ): - with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): - import sys - - sys.modules["mcp.types"].ListToolsRequest = MockListToolsRequest - - result = self.instrumentor._get_mcp_operation(request) - self.assertEqual(result, "tools/list") - - def test_get_mcp_operation_call_tool(self) -> None: - """Test _get_mcp_operation for CallToolRequest""" - - class MockCallToolRequest: def __init__(self): self.params = MockParams() @@ -844,83 +842,92 @@ class MockParams: def __init__(self): self.name = "test_tool" - request = MockCallToolRequest() - - def mock_isinstance(obj, cls): - return cls.__name__ == "CallToolRequest" - - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.isinstance", side_effect=mock_isinstance - ): - with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): - import sys - - sys.modules["mcp.types"].CallToolRequest = MockCallToolRequest + # Test without environment variable + with unittest.mock.patch.dict("os.environ", {}, clear=True): + request = MockRequest() + self.instrumentor._add_client_attributes(mock_span, "test_operation", request) + mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") - result = self.instrumentor._get_mcp_operation(request) - self.assertEqual(result, "tools/test_tool") - - def test_add_attributes_edge_cases(self) -> None: - """Test attribute setting edge cases""" - mock_span = MagicMock() - - # Test client attributes with no params - class MockRequestNoParams: - def __init__(self): - self.params = None - - request = MockRequestNoParams() - self.instrumentor._add_client_attributes(mock_span, "test_op", request) - mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") - mock_span.set_attribute.assert_any_call("rpc.method", "test_op") - - # Test server attributes with no params - mock_span.reset_mock() - self.instrumentor._add_server_attributes(mock_span, "test_op", request) - mock_span.set_attribute.assert_not_called() - def test_extract_span_context_edge_cases(self) -> None: - """Test span context extraction edge cases""" - # Test wrong part count - result = self.instrumentor._extract_span_context_from_traceparent("00-12345") - self.assertIsNone(result) +class TestTraceContextFormats(unittest.TestCase): + """Test trace context format handling""" - # Test invalid hex - result = self.instrumentor._extract_span_context_from_traceparent("00-invalid-hex-values-01") - self.assertIsNone(result) + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() - # Test zero values (should still work) - result = self.instrumentor._extract_span_context_from_traceparent( - "00-00000000000000000000000000000000-1234567890123456-01" - ) - self.assertIsNotNone(result) - self.assertEqual(result.trace_id, 0) + def test_inject_trace_context_format(self) -> None: + """Test that injected trace context follows W3C format""" + request_data = {} + span_ctx = SimpleSpanContext(trace_id=0x12345678901234567890123456789012, span_id=0x1234567890123456) + self.instrumentor._inject_trace_context(request_data, span_ctx) + traceparent = request_data["params"]["_meta"]["traceparent"] + self.assertEqual(traceparent, "00-12345678901234567890123456789012-1234567890123456-01") - def test_wrap_handle_request_no_meta_attr(self) -> None: - """Test _wrap_handle_request when meta attribute doesn't exist""" - class MockRequestNoMetaAttr: - def __init__(self): - self.params = MockParamsNoMetaAttr() +class FakeParams: + def __init__(self, name=None, meta=None): + self.name = name + self.meta = meta - class MockParamsNoMetaAttr: - def __init__(self): - pass - async def mock_wrapped(*args, **kwargs): - return {"result": "no_meta_attr"} +class FakeRequest: + def __init__(self, params=None): + self.params = params + self.root = self - request = MockRequestNoMetaAttr() - result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) + def model_dump(self, **kwargs): + return {"method": "call_tool", "params": {"name": self.params.name if self.params else None}} - self.assertEqual(result["result"], "no_meta_attr") + @classmethod + def model_validate(cls, data): + return cls(params=FakeParams(name=data["params"].get("name"))) - def test_inject_trace_context_existing_meta(self) -> None: - """Test trace context injection preserves existing _meta""" - request_data = {"params": {"_meta": {"existing": "preserved"}}} - span_ctx = SimpleSpanContext(trace_id=123, span_id=456) - self.instrumentor._inject_trace_context(request_data, span_ctx) +class TestAdditionalCoverage(unittest.TestCase): + def setUp(self): + self.instrumentor = MCPInstrumentor() + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_span_context = MagicMock(trace_id=123, span_id=456) + mock_span.get_span_context.return_value = mock_span_context + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + self.instrumentor.tracer = mock_tracer - self.assertEqual(request_data["params"]["_meta"]["existing"], "preserved") - self.assertIn("traceparent", request_data["params"]["_meta"]) + def test_wrap_send_request_kwargs_only_with_extra_args(self): + """Covers case where kwargs['request'] is set and args has extra elements.""" + + async def dummy_wrapped(*args, **kwargs): + return {"ok": True} + + req = FakeRequest(params=FakeParams("tool1")) + with patch.object(self.instrumentor, "_generate_mcp_attributes"): + result = asyncio.run(self.instrumentor._wrap_send_request(dummy_wrapped, None, (), {"request": req})) + self.assertTrue(result["ok"]) + + def test_wrap_handle_request_invalid_traceparent_path(self): + """Covers case where traceparent exists but is invalid -> no span created.""" + bad_meta = type("Meta", (), {"traceparent": "invalid-format"}) + req = FakeRequest(params=FakeParams(name="tool1", meta=bad_meta)) + + async def dummy_wrapped(*args, **kwargs): + return {"ok": True} + + result = asyncio.run(self.instrumentor._wrap_handle_request(dummy_wrapped, None, (None, req), {})) + self.assertTrue(result["ok"]) + self.instrumentor.tracer.start_as_current_span.assert_not_called() + + def test_add_client_attributes_missing_name_env_var(self): + """Covers _add_client_attributes with env var set but no name.""" + span = MagicMock() + req = FakeRequest(params=FakeParams(name=None)) + with patch.dict("os.environ", {"MCP_INSTRUMENTATION_SERVER_NAME": "env_server"}): + self.instrumentor._add_client_attributes(span, "op", req) + span.set_attribute.assert_any_call("rpc.service", "env_server") + + def test_add_server_attributes_name_missing(self): + """Covers _add_server_attributes when .params exists but .name is None.""" + span = MagicMock() + req = FakeRequest(params=FakeParams(name=None)) + self.instrumentor._add_server_attributes(span, "op", req) + # The method should still be called but with None value, so we just verify it was called + span.set_attribute.assert_called_once() From 1f4f47602a94fa10f26d22ff36f5919ee60d7faa Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Sun, 3 Aug 2025 21:39:21 -0700 Subject: [PATCH 26/41] test coverage and lint --- .../instrumentation/mcp/mcp_instrumentor.py | 6 +- .../distro/test_mcpinstrumentor.py | 102 +++++++++++++++++- 2 files changed, 101 insertions(+), 7 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py index b12961064..2090f64e6 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any, Callable, Collection, Dict, Tuple +import mcp.types as types from wrapt import register_post_import_hook, wrap_function_wrapper from opentelemetry import trace @@ -128,8 +129,6 @@ async def _wrap_handle_request( @staticmethod def _generate_mcp_attributes(span: trace.Span, request: Any, is_client: bool) -> None: - import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import - operation = MCPOperations.UNKNOWN_OPERATION if isinstance(request, types.ListToolsRequest): @@ -185,9 +184,6 @@ def _extract_span_context_from_traceparent(traceparent: str): @staticmethod def _get_mcp_operation(req: Any) -> str: - - import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import - span_name = "unknown" if isinstance(req, types.ListToolsRequest): diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index 4e66bfbb0..6f2af29ca 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -660,7 +660,8 @@ def __init__(self): self.root = self self.params = MockParams() - def model_dump(self, **kwargs): + @staticmethod + def model_dump(**kwargs): return {"method": "test", "params": {"name": "test_tool"}} @classmethod @@ -687,7 +688,8 @@ class MockRequestNoRoot: def __init__(self): self.params = MockParams() - def model_dump(self, **kwargs): + @staticmethod + def model_dump(**kwargs): return {"method": "test", "params": {"name": "test_tool"}} @classmethod @@ -883,6 +885,102 @@ def model_validate(cls, data): return cls(params=FakeParams(name=data["params"].get("name"))) +class TestMCPTypesCoverage(unittest.TestCase): + """Test isinstance checks in _generate_mcp_attributes and _get_mcp_operation""" + + def setUp(self) -> None: + self.instrumentor = MCPInstrumentor() + self.mock_span = MagicMock() + + def test_generate_mcp_attributes_list_tools(self) -> None: + """Test _generate_mcp_attributes with ListToolsRequest""" + from amazon.opentelemetry.distro.instrumentation.mcp.semconv import MCPAttributes + + class MockListToolsRequest: + pass + + with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: + mock_types.ListToolsRequest = MockListToolsRequest + mock_types.CallToolRequest = type("CallToolRequest", (), {}) + mock_types.InitializeRequest = type("InitializeRequest", (), {}) + + request = MockListToolsRequest() + self.instrumentor._generate_mcp_attributes(self.mock_span, request, is_client=True) + self.mock_span.set_attribute.assert_any_call(MCPAttributes.MCP_LIST_TOOLS, True) + + def test_generate_mcp_attributes_call_tool(self) -> None: + """Test _generate_mcp_attributes with CallToolRequest""" + from amazon.opentelemetry.distro.instrumentation.mcp.semconv import MCPAttributes + + class MockCallToolRequest: + def __init__(self): + self.params = type("Params", (), {"name": "test_tool"})() + + with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: + mock_types.ListToolsRequest = type("ListToolsRequest", (), {}) + mock_types.CallToolRequest = MockCallToolRequest + mock_types.InitializeRequest = type("InitializeRequest", (), {}) + + request = MockCallToolRequest() + self.instrumentor._generate_mcp_attributes(self.mock_span, request, is_client=True) + self.mock_span.set_attribute.assert_any_call(MCPAttributes.MCP_CALL_TOOL, True) + + def test_generate_mcp_attributes_initialize(self) -> None: + """Test _generate_mcp_attributes with InitializeRequest""" + from amazon.opentelemetry.distro.instrumentation.mcp.semconv import MCPAttributes + + class MockInitializeRequest: + pass + + with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: + mock_types.ListToolsRequest = type("ListToolsRequest", (), {}) + mock_types.CallToolRequest = type("CallToolRequest", (), {}) + mock_types.InitializeRequest = MockInitializeRequest + + request = MockInitializeRequest() + self.instrumentor._generate_mcp_attributes(self.mock_span, request, is_client=True) + self.mock_span.set_attribute.assert_any_call(MCPAttributes.MCP_INITIALIZE, True) + + def test_get_mcp_operation_list_tools(self) -> None: + """Test _get_mcp_operation with ListToolsRequest""" + + class MockListToolsRequest: + pass + + with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: + mock_types.ListToolsRequest = MockListToolsRequest + mock_types.CallToolRequest = type("CallToolRequest", (), {}) + + request = MockListToolsRequest() + result = self.instrumentor._get_mcp_operation(request) + self.assertEqual(result, "tools/list") + + def test_get_mcp_operation_call_tool(self) -> None: + """Test _get_mcp_operation with CallToolRequest""" + + class MockCallToolRequest: + def __init__(self): + self.params = type("Params", (), {"name": "my_tool"})() + + with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: + mock_types.ListToolsRequest = type("ListToolsRequest", (), {}) + mock_types.CallToolRequest = MockCallToolRequest + + request = MockCallToolRequest() + result = self.instrumentor._get_mcp_operation(request) + self.assertEqual(result, "tools/my_tool") + + def test_get_mcp_operation_unknown(self) -> None: + """Test _get_mcp_operation with unknown request type""" + with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: + mock_types.ListToolsRequest = type("ListToolsRequest", (), {}) + mock_types.CallToolRequest = type("CallToolRequest", (), {}) + + unknown_request = object() + result = self.instrumentor._get_mcp_operation(unknown_request) + self.assertEqual(result, "unknown") + + class TestAdditionalCoverage(unittest.TestCase): def setUp(self): self.instrumentor = MCPInstrumentor() From 8621644c88f8f78bd04461dc241fdf65a1419f8c Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Mon, 4 Aug 2025 01:10:57 -0700 Subject: [PATCH 27/41] test coverage and lint --- .../instrumentation/mcp/mcp_instrumentor.py | 6 +- .../distro/test_mcpinstrumentor.py | 162 ++++-------------- 2 files changed, 43 insertions(+), 125 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py index 2090f64e6..b12961064 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any, Callable, Collection, Dict, Tuple -import mcp.types as types from wrapt import register_post_import_hook, wrap_function_wrapper from opentelemetry import trace @@ -129,6 +128,8 @@ async def _wrap_handle_request( @staticmethod def _generate_mcp_attributes(span: trace.Span, request: Any, is_client: bool) -> None: + import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import + operation = MCPOperations.UNKNOWN_OPERATION if isinstance(request, types.ListToolsRequest): @@ -184,6 +185,9 @@ def _extract_span_context_from_traceparent(traceparent: str): @staticmethod def _get_mcp_operation(req: Any) -> str: + + import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import + span_name = "unknown" if isinstance(req, types.ListToolsRequest): diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py index 6f2af29ca..b5ae9c78a 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py @@ -9,7 +9,7 @@ import sys import unittest from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from amazon.opentelemetry.distro.instrumentation.mcp import version from amazon.opentelemetry.distro.instrumentation.mcp.constants import MCPEnvironmentVariables @@ -885,147 +885,61 @@ def model_validate(cls, data): return cls(params=FakeParams(name=data["params"].get("name"))) -class TestMCPTypesCoverage(unittest.TestCase): - """Test isinstance checks in _generate_mcp_attributes and _get_mcp_operation""" +class TestGetMCPOperation(unittest.TestCase): + """Test _get_mcp_operation function with patched types""" - def setUp(self) -> None: + def setUp(self): self.instrumentor = MCPInstrumentor() - self.mock_span = MagicMock() - def test_generate_mcp_attributes_list_tools(self) -> None: - """Test _generate_mcp_attributes with ListToolsRequest""" - from amazon.opentelemetry.distro.instrumentation.mcp.semconv import MCPAttributes + @unittest.mock.patch("mcp.types") + def test_get_mcp_operation_coverage(self, mock_types): + """Test _get_mcp_operation with all request types""" - class MockListToolsRequest: + # Create actual classes for isinstance checks + class FakeListToolsRequest: pass - with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: - mock_types.ListToolsRequest = MockListToolsRequest - mock_types.CallToolRequest = type("CallToolRequest", (), {}) - mock_types.InitializeRequest = type("InitializeRequest", (), {}) - - request = MockListToolsRequest() - self.instrumentor._generate_mcp_attributes(self.mock_span, request, is_client=True) - self.mock_span.set_attribute.assert_any_call(MCPAttributes.MCP_LIST_TOOLS, True) - - def test_generate_mcp_attributes_call_tool(self) -> None: - """Test _generate_mcp_attributes with CallToolRequest""" - from amazon.opentelemetry.distro.instrumentation.mcp.semconv import MCPAttributes - - class MockCallToolRequest: + class FakeCallToolRequest: def __init__(self): self.params = type("Params", (), {"name": "test_tool"})() - with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: - mock_types.ListToolsRequest = type("ListToolsRequest", (), {}) - mock_types.CallToolRequest = MockCallToolRequest - mock_types.InitializeRequest = type("InitializeRequest", (), {}) - - request = MockCallToolRequest() - self.instrumentor._generate_mcp_attributes(self.mock_span, request, is_client=True) - self.mock_span.set_attribute.assert_any_call(MCPAttributes.MCP_CALL_TOOL, True) + # Set up mock types + mock_types.ListToolsRequest = FakeListToolsRequest + mock_types.CallToolRequest = FakeCallToolRequest - def test_generate_mcp_attributes_initialize(self) -> None: - """Test _generate_mcp_attributes with InitializeRequest""" - from amazon.opentelemetry.distro.instrumentation.mcp.semconv import MCPAttributes + # Test ListToolsRequest path + result1 = self.instrumentor._get_mcp_operation(FakeListToolsRequest()) + self.assertEqual(result1, "tools/list") - class MockInitializeRequest: - pass - - with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: - mock_types.ListToolsRequest = type("ListToolsRequest", (), {}) - mock_types.CallToolRequest = type("CallToolRequest", (), {}) - mock_types.InitializeRequest = MockInitializeRequest + # Test CallToolRequest path + result2 = self.instrumentor._get_mcp_operation(FakeCallToolRequest()) + self.assertEqual(result2, "tools/test_tool") - request = MockInitializeRequest() - self.instrumentor._generate_mcp_attributes(self.mock_span, request, is_client=True) - self.mock_span.set_attribute.assert_any_call(MCPAttributes.MCP_INITIALIZE, True) + # Test unknown request type + result3 = self.instrumentor._get_mcp_operation(object()) + self.assertEqual(result3, "unknown") - def test_get_mcp_operation_list_tools(self) -> None: - """Test _get_mcp_operation with ListToolsRequest""" + @unittest.mock.patch("mcp.types") + def test_generate_mcp_attributes_coverage(self, mock_types): + """Test _generate_mcp_attributes with all request types""" - class MockListToolsRequest: + class FakeListToolsRequest: pass - with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: - mock_types.ListToolsRequest = MockListToolsRequest - mock_types.CallToolRequest = type("CallToolRequest", (), {}) - - request = MockListToolsRequest() - result = self.instrumentor._get_mcp_operation(request) - self.assertEqual(result, "tools/list") - - def test_get_mcp_operation_call_tool(self) -> None: - """Test _get_mcp_operation with CallToolRequest""" - - class MockCallToolRequest: + class FakeCallToolRequest: def __init__(self): - self.params = type("Params", (), {"name": "my_tool"})() - - with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: - mock_types.ListToolsRequest = type("ListToolsRequest", (), {}) - mock_types.CallToolRequest = MockCallToolRequest - - request = MockCallToolRequest() - result = self.instrumentor._get_mcp_operation(request) - self.assertEqual(result, "tools/my_tool") - - def test_get_mcp_operation_unknown(self) -> None: - """Test _get_mcp_operation with unknown request type""" - with patch("amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.types") as mock_types: - mock_types.ListToolsRequest = type("ListToolsRequest", (), {}) - mock_types.CallToolRequest = type("CallToolRequest", (), {}) + self.params = type("Params", (), {"name": "test_tool"})() - unknown_request = object() - result = self.instrumentor._get_mcp_operation(unknown_request) - self.assertEqual(result, "unknown") + class FakeInitializeRequest: + pass + mock_types.ListToolsRequest = FakeListToolsRequest + mock_types.CallToolRequest = FakeCallToolRequest + mock_types.InitializeRequest = FakeInitializeRequest -class TestAdditionalCoverage(unittest.TestCase): - def setUp(self): - self.instrumentor = MCPInstrumentor() - mock_tracer = MagicMock() mock_span = MagicMock() - mock_span_context = MagicMock(trace_id=123, span_id=456) - mock_span.get_span_context.return_value = mock_span_context - mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span - self.instrumentor.tracer = mock_tracer - - def test_wrap_send_request_kwargs_only_with_extra_args(self): - """Covers case where kwargs['request'] is set and args has extra elements.""" - - async def dummy_wrapped(*args, **kwargs): - return {"ok": True} - - req = FakeRequest(params=FakeParams("tool1")) - with patch.object(self.instrumentor, "_generate_mcp_attributes"): - result = asyncio.run(self.instrumentor._wrap_send_request(dummy_wrapped, None, (), {"request": req})) - self.assertTrue(result["ok"]) - - def test_wrap_handle_request_invalid_traceparent_path(self): - """Covers case where traceparent exists but is invalid -> no span created.""" - bad_meta = type("Meta", (), {"traceparent": "invalid-format"}) - req = FakeRequest(params=FakeParams(name="tool1", meta=bad_meta)) - - async def dummy_wrapped(*args, **kwargs): - return {"ok": True} - - result = asyncio.run(self.instrumentor._wrap_handle_request(dummy_wrapped, None, (None, req), {})) - self.assertTrue(result["ok"]) - self.instrumentor.tracer.start_as_current_span.assert_not_called() - - def test_add_client_attributes_missing_name_env_var(self): - """Covers _add_client_attributes with env var set but no name.""" - span = MagicMock() - req = FakeRequest(params=FakeParams(name=None)) - with patch.dict("os.environ", {"MCP_INSTRUMENTATION_SERVER_NAME": "env_server"}): - self.instrumentor._add_client_attributes(span, "op", req) - span.set_attribute.assert_any_call("rpc.service", "env_server") - - def test_add_server_attributes_name_missing(self): - """Covers _add_server_attributes when .params exists but .name is None.""" - span = MagicMock() - req = FakeRequest(params=FakeParams(name=None)) - self.instrumentor._add_server_attributes(span, "op", req) - # The method should still be called but with None value, so we just verify it was called - span.set_attribute.assert_called_once() + self.instrumentor._generate_mcp_attributes(mock_span, FakeListToolsRequest(), True) + self.instrumentor._generate_mcp_attributes(mock_span, FakeCallToolRequest(), True) + self.instrumentor._generate_mcp_attributes(mock_span, FakeInitializeRequest(), True) + mock_span.set_attribute.assert_any_call("mcp.list_tools", True) + mock_span.set_attribute.assert_any_call("mcp.call_tool", True) From 5f70570cffeff882d43820b73044d36d52c53d3c Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Tue, 5 Aug 2025 16:02:21 -0700 Subject: [PATCH 28/41] fixes --- .../distro/instrumentation/mcp/__init__.py | 2 - .../distro/instrumentation/mcp/constants.py | 38 ----- .../README.md | 0 .../instrumentation/mcp/__init__.py | 7 + .../instrumentation/mcp/instrumentation.py} | 138 +++++++----------- .../instrumentation}/mcp/semconv.py | 30 ++-- .../instrumentation}/mcp/version.py | 0 7 files changed, 81 insertions(+), 134 deletions(-) delete mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py delete mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/constants.py rename aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/{mcp => opentelemetry-instrumentation-mcp}/README.md (100%) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/__init__.py rename aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/{mcp/mcp_instrumentor.py => opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/instrumentation.py} (53%) rename aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/{ => opentelemetry-instrumentation-mcp/opentelemetry/instrumentation}/mcp/semconv.py (81%) rename aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/{ => opentelemetry-instrumentation-mcp/opentelemetry/instrumentation}/mcp/version.py (100%) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py deleted file mode 100644 index 04f8b7b76..000000000 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/constants.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/constants.py deleted file mode 100644 index 3ed60baa1..000000000 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/constants.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -MCP (Model Context Protocol) Constants for OpenTelemetry instrumentation. - -This module defines constants and configuration variables used by the MCP instrumentor. -""" - - -class MCPTraceContext: - """Constants for MCP distributed tracing context propagation.""" - - TRACEPARENT_HEADER = "traceparent" - """ - W3C Trace Context traceparent header name. - Used for propagating trace context in MCP request metadata. - """ - - TRACE_FLAGS_SAMPLED = "01" - """ - W3C Trace Context flags indicating the trace is sampled. - """ - - TRACEPARENT_VERSION = "00" - """ - W3C Trace Context version identifier. - """ - - -class MCPEnvironmentVariables: - """Environment variable names for MCP instrumentation configuration.""" - - SERVER_NAME = "MCP_INSTRUMENTATION_SERVER_NAME" - """ - Environment variable to override the default MCP server name. - Default value: "mcp server" - """ diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/README.md similarity index 100% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md rename to aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/README.md diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/__init__.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/__init__.py new file mode 100644 index 000000000..df8422a8c --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/__init__.py @@ -0,0 +1,7 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from opentelemetry.instrumentation.mcp.version import __version__ +from opentelemetry.instrumentation.mcp.instrumentation import McpInstrumentor + +__all__ = ["McpInstrumentor", "__version__"] diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/instrumentation.py similarity index 53% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py rename to aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/instrumentation.py index b12961064..7b904c4b8 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/mcp_instrumentor.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/instrumentation.py @@ -1,6 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Collection, Dict, Tuple +from typing import Any, Callable, Collection, Dict, Optional, Tuple from wrapt import register_post_import_hook, wrap_function_wrapper @@ -8,30 +8,26 @@ from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.instrumentation.mcp.version import __version__ +from opentelemetry.propagate import get_global_textmap -from .constants import MCPEnvironmentVariables, MCPTraceContext -from .semconv import MCPAttributes, MCPOperations, MCPSpanNames +from .semconv import CLIENT_INITIALIZED, MCP_METHOD_NAME, TOOLS_CALL, TOOLS_LIST, MCPAttributes, MCPOperations, MCPSpanNames -class MCPInstrumentor(BaseInstrumentor): +class McpInstrumentor(BaseInstrumentor): """ An instrumenter for MCP. """ - def __init__(self): + def __init__(self, **kwargs): super().__init__() - self.tracer = None + self.propagators = kwargs.get("propagators") or get_global_textmap() + self.tracer = trace.get_tracer(__name__, __version__, tracer_provider=kwargs.get("tracer_provider", None)) - @staticmethod - def instrumentation_dependencies() -> Collection[str]: - return ("mcp >= 1.6.0",) + def instrumentation_dependencies(self) -> Collection[str]: + return "mcp >= 1.6.0" def _instrument(self, **kwargs: Any) -> None: - tracer_provider = kwargs.get("tracer_provider") - if tracer_provider: - self.tracer = tracer_provider.get_tracer("instrumentation.mcp") - else: - self.tracer = trace.get_tracer("instrumentation.mcp") register_post_import_hook( lambda _: wrap_function_wrapper( "mcp.shared.session", @@ -49,48 +45,61 @@ def _instrument(self, **kwargs: Any) -> None: "mcp.server.lowlevel.server", ) - @staticmethod - def _uninstrument(**kwargs: Any) -> None: + def _uninstrument(self, **kwargs: Any) -> None: unwrap("mcp.shared.session", "BaseSession.send_request") unwrap("mcp.server.lowlevel.server", "Server._handle_request") - - # Send Request Wrapper + + def _wrap_send_request( self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Callable: - """ - Changes made: - The wrapper intercepts the request before sending, injects distributed tracing context into the - request's params._meta field and creates OpenTelemetry spans. The wrapper does not change anything - else from the original function's behavior because it reconstructs the request object with the same - type and calling the original function with identical parameters. + import mcp.types as types + """ + Patches BaseSession.send_request which is responsible for sending requests from the client to the MCP server. + This patched MCP client intercepts the request to obtain attributes for creating client-side span, extracts + the current trace context, and embeds it into the request's params._meta.traceparent field + before forwarding the request to the MCP server. + + Args: + wrapped: The original BaseSession.send_request function + instance: The BaseSession instance + args: Positional arguments, where args[0] is typically the request object + kwargs: Keyword arguments, may contain 'request' parameter + + Returns: + Callable: Async wrapper function that handles trace context injection """ async def async_wrapper(): + request: Optional[types.ClientRequest] = args[0] if len(args) > 0 else None + + if not request: + return await wrapped(*args, **kwargs) + with self.tracer.start_as_current_span( MCPSpanNames.CLIENT_SEND_REQUEST, kind=trace.SpanKind.CLIENT ) as span: - span_ctx = span.get_span_context() - request = args[0] if len(args) > 0 else kwargs.get("request") + if request: - req_root = request.root if hasattr(request, "root") else request - - self._generate_mcp_attributes(span, req_root, is_client=True) + span_ctx = trace.set_span_in_context(span) + parent_span = {} + self.propagators.inject(carrier=parent_span, context=span_ctx) + request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) - self._inject_trace_context(request_data, span_ctx) + + if "params" not in request_data: + request_data["params"] = {} + if "_meta" not in request_data["params"]: + request_data["params"]["_meta"] = {} + request_data["params"]["_meta"].update(parent_span) + # Reconstruct request object with injected trace context - modified_request = type(request).model_validate(request_data) - if len(args) > 0: - new_args = (modified_request,) + args[1:] - result = await wrapped(*new_args, **kwargs) - else: - kwargs["request"] = modified_request - result = await wrapped(*args, **kwargs) - else: - result = await wrapped(*args, **kwargs) - return result + modified_request = request.model_validate(request_data) + new_args = (modified_request,) + args[1:] + + return await wrapped(*new_args, **kwargs) - return async_wrapper() + return async_wrapper # Handle Request Wrapper async def _wrap_handle_request( @@ -111,7 +120,7 @@ async def _wrap_handle_request( traceparent = None if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: - traceparent = getattr(req.params.meta, MCPTraceContext.TRACEPARENT_HEADER, None) + traceparent = None span_context = self._extract_span_context_from_traceparent(traceparent) if traceparent else None if span_context: span_name = self._get_mcp_operation(req) @@ -130,40 +139,18 @@ async def _wrap_handle_request( def _generate_mcp_attributes(span: trace.Span, request: Any, is_client: bool) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import - operation = MCPOperations.UNKNOWN_OPERATION - if isinstance(request, types.ListToolsRequest): - operation = MCPOperations.LIST_TOOL - span.set_attribute(MCPAttributes.MCP_LIST_TOOLS, True) + span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST) if is_client: span.update_name(MCPSpanNames.CLIENT_LIST_TOOLS) elif isinstance(request, types.CallToolRequest): - operation = request.params.name - span.set_attribute(MCPAttributes.MCP_CALL_TOOL, True) + span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) if is_client: span.update_name(MCPSpanNames.client_call_tool(request.params.name)) elif isinstance(request, types.InitializeRequest): - operation = MCPOperations.INITIALIZE - span.set_attribute(MCPAttributes.MCP_INITIALIZE, True) - if is_client: - span.update_name(MCPSpanNames.CLIENT_INITIALIZE) + span.set_attribute(MCP_METHOD_NAME, CLIENT_INITIALIZED) - if is_client: - MCPInstrumentor._add_client_attributes(span, operation, request) - else: - MCPInstrumentor._add_server_attributes(span, operation, request) - - @staticmethod - def _inject_trace_context(request_data: Dict[str, Any], span_ctx) -> None: - if "params" not in request_data: - request_data["params"] = {} - if "_meta" not in request_data["params"]: - request_data["params"]["_meta"] = {} - trace_id_hex = f"{span_ctx.trace_id:032x}" - span_id_hex = f"{span_ctx.span_id:016x}" - trace_flags = MCPTraceContext.TRACE_FLAGS_SAMPLED - traceparent = f"{MCPTraceContext.TRACEPARENT_VERSION}-{trace_id_hex}-{span_id_hex}-{trace_flags}" - request_data["params"]["_meta"][MCPTraceContext.TRACEPARENT_HEADER] = traceparent + # Additional attributes can be added here if needed @staticmethod def _extract_span_context_from_traceparent(traceparent: str): @@ -195,18 +182,3 @@ def _get_mcp_operation(req: Any) -> str: elif isinstance(req, types.CallToolRequest): span_name = MCPSpanNames.tools_call(req.params.name) return span_name - - @staticmethod - def _add_client_attributes(span: trace.Span, operation: str, request: Any) -> None: - import os # pylint: disable=import-outside-toplevel - - service_name = os.environ.get(MCPEnvironmentVariables.SERVER_NAME, "mcp server") - span.set_attribute(SpanAttributes.RPC_SERVICE, service_name) - span.set_attribute(SpanAttributes.RPC_METHOD, operation) - if hasattr(request, "params") and request.params and hasattr(request.params, "name"): - span.set_attribute(MCPAttributes.MCP_TOOL_NAME, request.params.name) - - @staticmethod - def _add_server_attributes(span: trace.Span, operation: str, request: Any) -> None: - if hasattr(request, "params") and request.params and hasattr(request.params, "name"): - span.set_attribute(MCPAttributes.MCP_TOOL_NAME, request.params.name) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/semconv.py similarity index 81% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py rename to aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/semconv.py index 409092c2a..b498d08b2 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/semconv.py @@ -3,14 +3,27 @@ """ MCP (Model Context Protocol) Semantic Conventions for OpenTelemetry. - -This module defines semantic conventions for MCP instrumentation following -OpenTelemetry standards for consistent telemetry data. """ +MCP_METHOD_NAME = "mcp.method.name" +MCP_REQUEST_ID = "mcp.request.id" +MCP_SESSION_ID = "mcp.session.id" +MCP_TOOL_NAME = "mcp.tool.name" +MCP_PROMPT_NAME = "mcp.prompt.name" +MCP_REQUEST_ARGUMENT = "mcp.request.argument" + + +NOTIFICATIONS_CANCELLED = "notifications/cancelled" +NOTIFICATIONS_INITIALIZED = "notifications/initialized" +NOTIFICATIONS_PROGRESS = "notifications/progress" +RESOURCES_LIST = "resources/list" +TOOLS_LIST = "tools/list" +TOOLS_CALL = "tools/call" +CLIENT_INITIALIZED = "initialize" + + class MCPAttributes: - """MCP-specific span attributes for OpenTelemetry instrumentation.""" # MCP Operation Type Attributes MCP_INITIALIZE = "notifications/initialize" @@ -43,18 +56,13 @@ class MCPSpanNames: """Standard span names for MCP operations.""" # Client-side span names - CLIENT_SEND_REQUEST = "client.send_request" + CLIENT_SEND_REQUEST = "span.mcp.client" """ Span name for client-side MCP request operations. Used for all outgoing MCP requests (initialize, list tools, call tool). """ - CLIENT_INITIALIZE = "notifications/initialize" - """ - Span name for client-side MCP initialization requests. - """ - - CLIENT_LIST_TOOLS = "mcp.list_tools" + CLIENT_LIST_TOOLS = "span.mcp.server" """ Span name for client-side MCP list tools requests. """ diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/version.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/version.py similarity index 100% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/version.py rename to aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/version.py From 064c71cc46fb42bf49ce3e2f42c8921eceb17204 Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Tue, 5 Aug 2025 18:19:43 -0700 Subject: [PATCH 29/41] add otel propagation library --- aws-opentelemetry-distro/pyproject.toml | 7 +- .../README.md | 0 .../instrumentation => }/mcp/__init__.py | 4 +- .../instrumentation/mcp/instrumentation.py | 158 +++++++++++++++ .../instrumentation => }/mcp/semconv.py | 4 +- .../instrumentation => }/mcp/version.py | 0 .../instrumentation/mcp/instrumentation.py | 184 ------------------ 7 files changed, 167 insertions(+), 190 deletions(-) rename aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/{opentelemetry-instrumentation-mcp => mcp}/README.md (100%) rename aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/{opentelemetry-instrumentation-mcp/opentelemetry/instrumentation => }/mcp/__init__.py (51%) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py rename aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/{opentelemetry-instrumentation-mcp/opentelemetry/instrumentation => }/mcp/semconv.py (97%) rename aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/{opentelemetry-instrumentation-mcp/opentelemetry/instrumentation => }/mcp/version.py (100%) delete mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/instrumentation.py diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index f36c66e4b..861a1f80f 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -89,7 +89,10 @@ dependencies = [ # If a new patch is added into the list, it must also be added into tox.ini, dev-requirements.txt and _instrumentation_patch patch = [ "botocore ~= 1.0", - "mcp >= 1.6.0" + "mcp >= 0.1.6", +] +instruments = [ + "mcp >= 0.1.6", ] test = [] @@ -97,7 +100,7 @@ test = [] aws_configurator = "amazon.opentelemetry.distro.aws_opentelemetry_configurator:AwsOpenTelemetryConfigurator" [project.entry-points.opentelemetry_instrumentor] -mcp = "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor:MCPInstrumentor" +mcp = "amazon.opentelemetry.distro.instrumentation.mcp.instrumentation:McpInstrumentor" [project.entry-points.opentelemetry_distro] aws_distro = "amazon.opentelemetry.distro.aws_opentelemetry_distro:AwsOpenTelemetryDistro" diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/README.md b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md similarity index 100% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/README.md rename to aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/README.md diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/__init__.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py similarity index 51% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/__init__.py rename to aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py index df8422a8c..571452d28 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/__init__.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/__init__.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from opentelemetry.instrumentation.mcp.version import __version__ -from opentelemetry.instrumentation.mcp.instrumentation import McpInstrumentor +from .version import __version__ +from .instrumentation import McpInstrumentor __all__ = ["McpInstrumentor", "__version__"] diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py new file mode 100644 index 000000000..bc6e9eba7 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -0,0 +1,158 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, cast + +from wrapt import register_post_import_hook, wrap_function_wrapper + +from opentelemetry import trace +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor +from opentelemetry.instrumentation.utils import unwrap +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.propagate import get_global_textmap + +from .version import __version__ + +from .semconv import ( + CLIENT_INITIALIZED, + MCP_METHOD_NAME, + TOOLS_CALL, + TOOLS_LIST, + MCPAttributes, + MCPOperations, + MCPSpanNames, +) + + +class McpInstrumentor(BaseInstrumentor): + """ + An instrumentor class for MCP. + """ + + def __init__(self, **kwargs): + super().__init__() + self.propagators = kwargs.get("propagators") or get_global_textmap() + self.tracer = trace.get_tracer(__name__, __version__, tracer_provider=kwargs.get("tracer_provider", None)) + + def instrumentation_dependencies(self) -> Collection[str]: + return "mcp >= 1.6.0" + + def _instrument(self, **kwargs: Any) -> None: + + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.shared.session", + "BaseSession.send_request", + self._wrap_send_request, + ), + "mcp.shared.session", + ) + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.server.lowlevel.server", + "Server._handle_request", + self._wrap_handle_request, + ), + "mcp.server.lowlevel.server", + ) + + def _uninstrument(self, **kwargs: Any) -> None: + unwrap("mcp.shared.session", "BaseSession.send_request") + unwrap("mcp.server.lowlevel.server", "Server._handle_request") + + def _wrap_send_request( + self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Callable: + import mcp.types as types + + """ + Patches BaseSession.send_request which is responsible for sending requests from the client to the MCP server. + This patched MCP client intercepts the request to obtain attributes for creating client-side span, extracts + the current trace context, and embeds it into the request's params._meta.traceparent field + before forwarding the request to the MCP server. + """ + + async def async_wrapper(): + request: Optional[types.ClientRequest] = args[0] if len(args) > 0 else None + + if not request: + return await wrapped(*args, **kwargs) + + request_as_json = request.model_dump(by_alias=True, mode="json", exclude_none=True) + + if "params" not in request_as_json: + request_as_json["params"] = {} + + if "_meta" not in request_as_json["params"]: + request_as_json["params"]["_meta"] = {} + + with self.tracer.start_as_current_span( + MCPSpanNames.SPAN_MCP_CLIENT, kind=trace.SpanKind.CLIENT + ) as mcp_client_span: + + if request: + span_ctx = trace.set_span_in_context(mcp_client_span) + parent_span = {} + self.propagators.inject(carrier=parent_span, context=span_ctx) + + McpInstrumentor._set_mcp_client_attributes(mcp_client_span, request) + + request_as_json["params"]["_meta"].update(parent_span) + + # Reconstruct request object with injected trace context + modified_request = request.model_validate(request_as_json) + new_args = (modified_request,) + args[1:] + + return await wrapped(*new_args, **kwargs) + + return async_wrapper + + # Handle Request Wrapper + async def _wrap_handle_request( + self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: + """ + Patches Server._handle_request which is responsible for processing requests on the MCP server. + This patched MCP server intercepts incoming requests to extract tracing context from + the request's params._meta field and creates server-side spans linked to the client spans. + """ + req = args[1] if len(args) > 1 else None + carrier = {} + + if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: + carrier = req.params.meta.__dict__ + + parent_ctx = self.propagators.extract(carrier=carrier) + + if parent_ctx: + with self.tracer.start_as_current_span( + MCPSpanNames.SPAN_MCP_SERVER, kind=trace.SpanKind.SERVER, context=parent_ctx + ) as mcp_server_span: + self._set_mcp_server_attributes(mcp_server_span, req) + + return await wrapped(*args, **kwargs) + + @staticmethod + def _set_mcp_client_attributes(span: trace.Span, request: Any) -> None: + import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import + + if isinstance(request, types.ListToolsRequest): + span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST) + if isinstance(request, types.CallToolRequest): + tool_name = request.params.name + span.update_name(f"{TOOLS_CALL} {tool_name}") + span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) + span.set_attribute(MCPAttributes.MCP_TOOL_NAME, tool_name) + if isinstance(request, types.InitializeRequest): + span.set_attribute(MCP_METHOD_NAME, CLIENT_INITIALIZED) + + @staticmethod + def _set_mcp_server_attributes(span: trace.Span, request: Any) -> None: + import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import + + if isinstance(span, types.ListToolsRequest): + span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST) + if isinstance(span, types.CallToolRequest): + tool_name = request.params.name + span.update_name(f"{TOOLS_CALL} {tool_name}") + span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) + span.set_attribute(MCPAttributes.MCP_TOOL_NAME, tool_name) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/semconv.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py similarity index 97% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/semconv.py rename to aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py index b498d08b2..eab16ce02 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/semconv.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py @@ -56,13 +56,13 @@ class MCPSpanNames: """Standard span names for MCP operations.""" # Client-side span names - CLIENT_SEND_REQUEST = "span.mcp.client" + SPAN_MCP_CLIENT = "span.mcp.client" """ Span name for client-side MCP request operations. Used for all outgoing MCP requests (initialize, list tools, call tool). """ - CLIENT_LIST_TOOLS = "span.mcp.server" + SPAN_MCP_SERVER = "span.mcp.server" """ Span name for client-side MCP list tools requests. """ diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/version.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/version.py similarity index 100% rename from aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/version.py rename to aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/version.py diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/instrumentation.py deleted file mode 100644 index 7b904c4b8..000000000 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/opentelemetry-instrumentation-mcp/opentelemetry/instrumentation/mcp/instrumentation.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Collection, Dict, Optional, Tuple - -from wrapt import register_post_import_hook, wrap_function_wrapper - -from opentelemetry import trace -from opentelemetry.instrumentation.instrumentor import BaseInstrumentor -from opentelemetry.instrumentation.utils import unwrap -from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.instrumentation.mcp.version import __version__ -from opentelemetry.propagate import get_global_textmap - -from .semconv import CLIENT_INITIALIZED, MCP_METHOD_NAME, TOOLS_CALL, TOOLS_LIST, MCPAttributes, MCPOperations, MCPSpanNames - - -class McpInstrumentor(BaseInstrumentor): - """ - An instrumenter for MCP. - """ - - def __init__(self, **kwargs): - super().__init__() - self.propagators = kwargs.get("propagators") or get_global_textmap() - self.tracer = trace.get_tracer(__name__, __version__, tracer_provider=kwargs.get("tracer_provider", None)) - - def instrumentation_dependencies(self) -> Collection[str]: - return "mcp >= 1.6.0" - - def _instrument(self, **kwargs: Any) -> None: - register_post_import_hook( - lambda _: wrap_function_wrapper( - "mcp.shared.session", - "BaseSession.send_request", - self._wrap_send_request, - ), - "mcp.shared.session", - ) - register_post_import_hook( - lambda _: wrap_function_wrapper( - "mcp.server.lowlevel.server", - "Server._handle_request", - self._wrap_handle_request, - ), - "mcp.server.lowlevel.server", - ) - - def _uninstrument(self, **kwargs: Any) -> None: - unwrap("mcp.shared.session", "BaseSession.send_request") - unwrap("mcp.server.lowlevel.server", "Server._handle_request") - - - def _wrap_send_request( - self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] - ) -> Callable: - import mcp.types as types - """ - Patches BaseSession.send_request which is responsible for sending requests from the client to the MCP server. - This patched MCP client intercepts the request to obtain attributes for creating client-side span, extracts - the current trace context, and embeds it into the request's params._meta.traceparent field - before forwarding the request to the MCP server. - - Args: - wrapped: The original BaseSession.send_request function - instance: The BaseSession instance - args: Positional arguments, where args[0] is typically the request object - kwargs: Keyword arguments, may contain 'request' parameter - - Returns: - Callable: Async wrapper function that handles trace context injection - """ - - async def async_wrapper(): - request: Optional[types.ClientRequest] = args[0] if len(args) > 0 else None - - if not request: - return await wrapped(*args, **kwargs) - - with self.tracer.start_as_current_span( - MCPSpanNames.CLIENT_SEND_REQUEST, kind=trace.SpanKind.CLIENT - ) as span: - - if request: - span_ctx = trace.set_span_in_context(span) - parent_span = {} - self.propagators.inject(carrier=parent_span, context=span_ctx) - - request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) - - if "params" not in request_data: - request_data["params"] = {} - if "_meta" not in request_data["params"]: - request_data["params"]["_meta"] = {} - request_data["params"]["_meta"].update(parent_span) - - # Reconstruct request object with injected trace context - modified_request = request.model_validate(request_data) - new_args = (modified_request,) + args[1:] - - return await wrapped(*new_args, **kwargs) - - return async_wrapper - - # Handle Request Wrapper - async def _wrap_handle_request( - self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] - ) -> Any: - """ - Changes made: - This wrapper intercepts requests before processing, extracts distributed tracing context from - the request's params._meta field, and creates server-side OpenTelemetry spans linked to the client spans. - The wrapper also does not change the original function's behavior by calling it with identical parameters - ensuring no breaking changes to the MCP server functionality. - - request (args[1]) is typically an instance of CallToolRequest or ListToolsRequest - and should have the structure: - request.params.meta.traceparent -> "00---01" - """ - req = args[1] if len(args) > 1 else None - traceparent = None - - if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: - traceparent = None - span_context = self._extract_span_context_from_traceparent(traceparent) if traceparent else None - if span_context: - span_name = self._get_mcp_operation(req) - with self.tracer.start_as_current_span( - span_name, - kind=trace.SpanKind.SERVER, - context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)), - ) as span: - self._generate_mcp_attributes(span, req, False) - result = await wrapped(*args, **kwargs) - return result - else: - return await wrapped(*args, **kwargs) - - @staticmethod - def _generate_mcp_attributes(span: trace.Span, request: Any, is_client: bool) -> None: - import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import - - if isinstance(request, types.ListToolsRequest): - span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST) - if is_client: - span.update_name(MCPSpanNames.CLIENT_LIST_TOOLS) - elif isinstance(request, types.CallToolRequest): - span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) - if is_client: - span.update_name(MCPSpanNames.client_call_tool(request.params.name)) - elif isinstance(request, types.InitializeRequest): - span.set_attribute(MCP_METHOD_NAME, CLIENT_INITIALIZED) - - # Additional attributes can be added here if needed - - @staticmethod - def _extract_span_context_from_traceparent(traceparent: str): - parts = traceparent.split("-") - if len(parts) == 4: - try: - trace_id = int(parts[1], 16) - span_id = int(parts[2], 16) - return trace.SpanContext( - trace_id=trace_id, - span_id=span_id, - is_remote=True, - trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED), - trace_state=trace.TraceState(), - ) - except ValueError: - return None - return None - - @staticmethod - def _get_mcp_operation(req: Any) -> str: - - import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import - - span_name = "unknown" - - if isinstance(req, types.ListToolsRequest): - span_name = MCPSpanNames.TOOLS_LIST - elif isinstance(req, types.CallToolRequest): - span_name = MCPSpanNames.tools_call(req.params.name) - return span_name From c97c8df467e4bbb7542ea85daff16a939a7f3536 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Wed, 6 Aug 2025 01:56:36 -0700 Subject: [PATCH 30/41] Fix MCP instrumentation for distributed tracing --- aws-opentelemetry-distro/pyproject.toml | 6 ++--- .../instrumentation/mcp/instrumentation.py | 23 +++++++------------ 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/aws-opentelemetry-distro/pyproject.toml b/aws-opentelemetry-distro/pyproject.toml index 861a1f80f..fbfffcff5 100644 --- a/aws-opentelemetry-distro/pyproject.toml +++ b/aws-opentelemetry-distro/pyproject.toml @@ -89,11 +89,9 @@ dependencies = [ # If a new patch is added into the list, it must also be added into tox.ini, dev-requirements.txt and _instrumentation_patch patch = [ "botocore ~= 1.0", - "mcp >= 0.1.6", -] -instruments = [ - "mcp >= 0.1.6", + "mcp >= 1.6.0", ] + test = [] [project.entry-points.opentelemetry_configurator] diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index bc6e9eba7..3a487f525 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -34,10 +34,10 @@ def __init__(self, **kwargs): self.tracer = trace.get_tracer(__name__, __version__, tracer_provider=kwargs.get("tracer_provider", None)) def instrumentation_dependencies(self) -> Collection[str]: - return "mcp >= 1.6.0" + return ("mcp >= 1.6.0",) def _instrument(self, **kwargs: Any) -> None: - + register_post_import_hook( lambda _: wrap_function_wrapper( "mcp.shared.session", @@ -63,6 +63,7 @@ def _wrap_send_request( self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Callable: import mcp.types as types + """ Patches BaseSession.send_request which is responsible for sending requests from the client to the MCP server. @@ -103,23 +104,16 @@ async def async_wrapper(): new_args = (modified_request,) + args[1:] return await wrapped(*new_args, **kwargs) + return async_wrapper() - return async_wrapper - - # Handle Request Wrapper async def _wrap_handle_request( self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Any: - """ - Patches Server._handle_request which is responsible for processing requests on the MCP server. - This patched MCP server intercepts incoming requests to extract tracing context from - the request's params._meta field and creates server-side spans linked to the client spans. - """ req = args[1] if len(args) > 1 else None carrier = {} if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: - carrier = req.params.meta.__dict__ + carrier = req.params.meta.model_dump() parent_ctx = self.propagators.extract(carrier=carrier) @@ -128,8 +122,7 @@ async def _wrap_handle_request( MCPSpanNames.SPAN_MCP_SERVER, kind=trace.SpanKind.SERVER, context=parent_ctx ) as mcp_server_span: self._set_mcp_server_attributes(mcp_server_span, req) - - return await wrapped(*args, **kwargs) + return await wrapped(*args, **kwargs) @staticmethod def _set_mcp_client_attributes(span: trace.Span, request: Any) -> None: @@ -149,9 +142,9 @@ def _set_mcp_client_attributes(span: trace.Span, request: Any) -> None: def _set_mcp_server_attributes(span: trace.Span, request: Any) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import - if isinstance(span, types.ListToolsRequest): + if isinstance(request, types.ListToolsRequest): span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST) - if isinstance(span, types.CallToolRequest): + if isinstance(request, types.CallToolRequest): tool_name = request.params.name span.update_name(f"{TOOLS_CALL} {tool_name}") span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) From dc2c73014c5fc01a4b16936911d5dcaf64aa2fba Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Wed, 6 Aug 2025 13:06:40 -0700 Subject: [PATCH 31/41] add further tool instrumentation --- .../instrumentation/mcp/instrumentation.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index 3a487f525..0057c55c5 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -1,10 +1,12 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +import json from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, cast -from wrapt import register_post_import_hook, wrap_function_wrapper +from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper -from opentelemetry import trace +from opentelemetry import context, trace from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap from opentelemetry.semconv.trace import SpanAttributes @@ -15,6 +17,7 @@ from .semconv import ( CLIENT_INITIALIZED, MCP_METHOD_NAME, + MCP_REQUEST_ARGUMENT, TOOLS_CALL, TOOLS_LIST, MCPAttributes, @@ -34,7 +37,7 @@ def __init__(self, **kwargs): self.tracer = trace.get_tracer(__name__, __version__, tracer_provider=kwargs.get("tracer_provider", None)) def instrumentation_dependencies(self) -> Collection[str]: - return ("mcp >= 1.6.0",) + return ("mcp >= 1.8.1",) def _instrument(self, **kwargs: Any) -> None: @@ -88,14 +91,14 @@ async def async_wrapper(): with self.tracer.start_as_current_span( MCPSpanNames.SPAN_MCP_CLIENT, kind=trace.SpanKind.CLIENT - ) as mcp_client_span: + ) as client_span: if request: - span_ctx = trace.set_span_in_context(mcp_client_span) + span_ctx = trace.set_span_in_context(client_span) parent_span = {} self.propagators.inject(carrier=parent_span, context=span_ctx) - McpInstrumentor._set_mcp_client_attributes(mcp_client_span, request) + McpInstrumentor._set_mcp_client_attributes(client_span, request) request_as_json["params"]["_meta"].update(parent_span) @@ -132,6 +135,10 @@ def _set_mcp_client_attributes(span: trace.Span, request: Any) -> None: span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST) if isinstance(request, types.CallToolRequest): tool_name = request.params.name + tool_arguments = request.params.arguments + if tool_arguments: + for arg_name, arg_val in tool_arguments.items(): + span.set_attribute(f"{MCP_REQUEST_ARGUMENT}.{arg_name}", McpInstrumentor.serialize(arg_val)) span.update_name(f"{TOOLS_CALL} {tool_name}") span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) span.set_attribute(MCPAttributes.MCP_TOOL_NAME, tool_name) @@ -149,3 +156,13 @@ def _set_mcp_server_attributes(span: trace.Span, request: Any) -> None: span.update_name(f"{TOOLS_CALL} {tool_name}") span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) span.set_attribute(MCPAttributes.MCP_TOOL_NAME, tool_name) + + + @staticmethod + def serialize(args): + try: + return json.dumps(args) + except Exception: + return str(args) + + From 76c24051fc1e652f54c6f36411f1e18db6ca1010 Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Wed, 6 Aug 2025 17:01:33 -0700 Subject: [PATCH 32/41] cleanup code --- .../instrumentation/mcp/instrumentation.py | 175 +++++++++++------- .../distro/instrumentation/mcp/semconv.py | 126 ++++--------- 2 files changed, 153 insertions(+), 148 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index 0057c55c5..2b252b8a4 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -7,6 +7,7 @@ from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper from opentelemetry import context, trace +from opentelemetry.trace.status import Status, StatusCode from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap from opentelemetry.semconv.trace import SpanAttributes @@ -15,20 +16,14 @@ from .version import __version__ from .semconv import ( - CLIENT_INITIALIZED, - MCP_METHOD_NAME, - MCP_REQUEST_ARGUMENT, - TOOLS_CALL, - TOOLS_LIST, - MCPAttributes, - MCPOperations, - MCPSpanNames, + MCPSpanAttributes, + MCPMethodNameValue, ) class McpInstrumentor(BaseInstrumentor): """ - An instrumentor class for MCP. + An instrumentation class for MCP: https://modelcontextprotocol.io/overview. """ def __init__(self, **kwargs): @@ -40,12 +35,14 @@ def instrumentation_dependencies(self) -> Collection[str]: return ("mcp >= 1.8.1",) def _instrument(self, **kwargs: Any) -> None: - + # TODO: add instrumentation for Streamable Http transport + # See: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports + register_post_import_hook( lambda _: wrap_function_wrapper( "mcp.shared.session", "BaseSession.send_request", - self._wrap_send_request, + self._wrap_session_send_request, ), "mcp.shared.session", ) @@ -53,7 +50,7 @@ def _instrument(self, **kwargs: Any) -> None: lambda _: wrap_function_wrapper( "mcp.server.lowlevel.server", "Server._handle_request", - self._wrap_handle_request, + self._wrap_stdio_handle_request, ), "mcp.server.lowlevel.server", ) @@ -62,17 +59,26 @@ def _uninstrument(self, **kwargs: Any) -> None: unwrap("mcp.shared.session", "BaseSession.send_request") unwrap("mcp.server.lowlevel.server", "Server._handle_request") - def _wrap_send_request( + def _wrap_session_send_request( self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Callable: import mcp.types as types - """ - Patches BaseSession.send_request which is responsible for sending requests from the client to the MCP server. - This patched MCP client intercepts the request to obtain attributes for creating client-side span, extracts - the current trace context, and embeds it into the request's params._meta.traceparent field + Instruments MCP client-side stdio request sending, see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports + + This is the master function responsible for sending requests from the client to the MCP server. See: + https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/shared/session.py#L220 + + The instrumented MCP client intercepts the request to obtain attributes for creating client-side span, extracts + the current trace context, and embeds it into the request's params._meta field before forwarding the request to the MCP server. + + Args: + wrapped: The original BaseSession.send_request method being instrumented + instance: The BaseSession instance handling the stdio communication + args: Positional arguments passed to the original send_request method, containing the ClientRequest + kwargs: Keyword arguments passed to the original send_request method """ async def async_wrapper(): @@ -81,88 +87,131 @@ async def async_wrapper(): if not request: return await wrapped(*args, **kwargs) + request_id = None + + if hasattr(instance, "_request_id"): + request_id = instance._request_id + request_as_json = request.model_dump(by_alias=True, mode="json", exclude_none=True) if "params" not in request_as_json: request_as_json["params"] = {} - if "_meta" not in request_as_json["params"]: request_as_json["params"]["_meta"] = {} - with self.tracer.start_as_current_span( - MCPSpanNames.SPAN_MCP_CLIENT, kind=trace.SpanKind.CLIENT - ) as client_span: + with self.tracer.start_as_current_span("span.mcp.client", kind=trace.SpanKind.CLIENT) as client_span: - if request: - span_ctx = trace.set_span_in_context(client_span) - parent_span = {} - self.propagators.inject(carrier=parent_span, context=span_ctx) + span_ctx = trace.set_span_in_context(client_span) + parent_span = {} + self.propagators.inject(carrier=parent_span, context=span_ctx) - McpInstrumentor._set_mcp_client_attributes(client_span, request) + McpInstrumentor._configure_mcp_span(client_span, request, request_id) + request_as_json["params"]["_meta"].update(parent_span) - request_as_json["params"]["_meta"].update(parent_span) + # Reconstruct request object with injected trace context + modified_request = request.model_validate(request_as_json) + new_args = (modified_request,) + args[1:] - # Reconstruct request object with injected trace context - modified_request = request.model_validate(request_as_json) - new_args = (modified_request,) + args[1:] + try: + result = await wrapped(*new_args, **kwargs) + client_span.set_status(Status(StatusCode.OK)) + return result + except Exception as e: + client_span.set_status(Status(StatusCode.ERROR, str(e))) + client_span.record_exception(e) + raise - return await wrapped(*new_args, **kwargs) return async_wrapper() - async def _wrap_handle_request( + async def _wrap_stdio_handle_request( self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Any: - req = args[1] if len(args) > 1 else None + """ + Instruments MCP server-side stdio request handling, see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports + + This is the core function responsible for processing incoming requests on the MCP server. See: + https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616 + + The instrumented MCP server intercepts incoming requests to extract tracing context from + the request's params._meta field, creates server-side spans linked to the originating client spans, + and processes the request while maintaining trace continuity. + + Args: + wrapped: The original Server._handle_request method being instrumented + instance: The MCP Server instance processing the stdio communication + args: Positional arguments passed to the original _handle_request method, containing the incoming request + kwargs: Keyword arguments passed to the original _handle_request method + """ + incoming_req = args[1] if len(args) > 1 else None + request_id = None carrier = {} - if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: - carrier = req.params.meta.model_dump() + if incoming_req and hasattr(incoming_req, "id"): + request_id = incoming_req.id + if incoming_req and hasattr(incoming_req, "params") and hasattr(incoming_req.params, "meta"): + carrier = incoming_req.params.meta.model_dump() parent_ctx = self.propagators.extract(carrier=carrier) if parent_ctx: with self.tracer.start_as_current_span( - MCPSpanNames.SPAN_MCP_SERVER, kind=trace.SpanKind.SERVER, context=parent_ctx - ) as mcp_server_span: - self._set_mcp_server_attributes(mcp_server_span, req) - return await wrapped(*args, **kwargs) + "span.mcp.server", kind=trace.SpanKind.SERVER, context=parent_ctx + ) as server_span: - @staticmethod - def _set_mcp_client_attributes(span: trace.Span, request: Any) -> None: - import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import + self._configure_mcp_span(server_span, incoming_req, request_id) - if isinstance(request, types.ListToolsRequest): - span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST) - if isinstance(request, types.CallToolRequest): - tool_name = request.params.name - tool_arguments = request.params.arguments - if tool_arguments: - for arg_name, arg_val in tool_arguments.items(): - span.set_attribute(f"{MCP_REQUEST_ARGUMENT}.{arg_name}", McpInstrumentor.serialize(arg_val)) - span.update_name(f"{TOOLS_CALL} {tool_name}") - span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) - span.set_attribute(MCPAttributes.MCP_TOOL_NAME, tool_name) - if isinstance(request, types.InitializeRequest): - span.set_attribute(MCP_METHOD_NAME, CLIENT_INITIALIZED) + try: + result = await wrapped(*args, **kwargs) + server_span.set_status(Status(StatusCode.OK)) + return result + except Exception as e: + server_span.set_status(Status(StatusCode.ERROR, str(e))) + server_span.record_exception(e) + raise @staticmethod - def _set_mcp_server_attributes(span: trace.Span, request: Any) -> None: + def _configure_mcp_span(span: trace.Span, request, request_id: Optional[str]) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import + if hasattr(request, "root"): + request = request.root + + if request_id: + span.set_attribute(MCPSpanAttributes.MCP_REQUEST_ID, request_id) + if isinstance(request, types.ListToolsRequest): - span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST) + span.update_name(MCPMethodNameValue.TOOLS_LIST) + span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.TOOLS_LIST) + return + if isinstance(request, types.CallToolRequest): tool_name = request.params.name - span.update_name(f"{TOOLS_CALL} {tool_name}") - span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) - span.set_attribute(MCPAttributes.MCP_TOOL_NAME, tool_name) + span.update_name(f"{MCPMethodNameValue.TOOLS_CALL} {request.params.name}") + span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.TOOLS_CALL) + span.set_attribute(MCPSpanAttributes.MCP_TOOL_NAME, tool_name) - @staticmethod + if hasattr(request.params, "arguments"): + for arg_name, arg_val in request.params.arguments.items(): + span.set_attribute( + f"{MCPSpanAttributes.MCP_REQUEST_ARGUMENT}.{arg_name}", McpInstrumentor.serialize(arg_val) + ) + + if isinstance(request, types.InitializeRequest): + span.update_name(MCPMethodNameValue.INITIALIZED) + span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.INITIALIZED) + + if isinstance(request, types.CancelledNotification): + span.update_name(MCPMethodNameValue.NOTIFICATIONS_CANCELLED) + span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.NOTIFICATIONS_CANCELLED) + + if isinstance(request, types.ToolListChangedNotification): + span.update_name(MCPMethodNameValue.NOTIFICATIONS_CANCELLED) + span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.NOTIFICATIONS_CANCELLED) + + @staticmethod def serialize(args): try: return json.dumps(args) except Exception: return str(args) - - diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py index eab16ce02..099b8c7da 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py @@ -2,112 +2,68 @@ # SPDX-License-Identifier: Apache-2.0 """ -MCP (Model Context Protocol) Semantic Conventions for OpenTelemetry. -""" - +MCP (Model Context Protocol) Semantic Conventions. -MCP_METHOD_NAME = "mcp.method.name" -MCP_REQUEST_ID = "mcp.request.id" -MCP_SESSION_ID = "mcp.session.id" -MCP_TOOL_NAME = "mcp.tool.name" -MCP_PROMPT_NAME = "mcp.prompt.name" -MCP_REQUEST_ARGUMENT = "mcp.request.argument" +Based off of: https://github.com/open-telemetry/semantic-conventions/pull/2083 - -NOTIFICATIONS_CANCELLED = "notifications/cancelled" -NOTIFICATIONS_INITIALIZED = "notifications/initialized" -NOTIFICATIONS_PROGRESS = "notifications/progress" -RESOURCES_LIST = "resources/list" -TOOLS_LIST = "tools/list" -TOOLS_CALL = "tools/call" -CLIENT_INITIALIZED = "initialize" +WARNING: These semantic conventions are currently in development and are considered unstable. +They may change at any time without notice. Use with caution in production environments. +""" -class MCPAttributes: +class MCPSpanAttributes: - # MCP Operation Type Attributes - MCP_INITIALIZE = "notifications/initialize" + MCP_METHOD_NAME = "mcp.method.name" """ - Boolean attribute indicating this span represents an MCP initialize operation. - Set to True when the span tracks session initialization between client and server. + The name of the request or notification method. + Examples: notifications/cancelled; initialize; notifications/initialized """ - - MCP_LIST_TOOLS = "mcp.list_tools" - """ - Boolean attribute indicating this span represents an MCP list tools operation. - Set to True when the span tracks discovery of available tools on the server. + MCP_REQUEST_ID = "mcp.request.id" """ - - MCP_CALL_TOOL = "mcp.call_tool" + This is a unique identifier for the request. + Conditionally Required when the client executes a request. """ - Boolean attribute indicating this span represents an MCP call tool operation. - Set to True when the span tracks execution of a specific tool. + MCP_TOOL_NAME = "mcp.tool.name" """ + The name of the tool provided in the request. - # MCP Tool Information - MCP_TOOL_NAME = "mcp.tool.name" + Conditionally Required when operation is related to a specific tool. + """ + MCP_REQUEST_ARGUMENT = "mcp.request.argument" """ - The name of the MCP tool being called. - Example: "echo", "search", "calculator" + Full attribute: mcp.request.argument. + Additional arguments passed to the request within params object. being the normalized argument name (lowercase), the value being the argument value. """ -class MCPSpanNames: - """Standard span names for MCP operations.""" +class MCPMethodNameValue: - # Client-side span names - SPAN_MCP_CLIENT = "span.mcp.client" + NOTIFICATIONS_CANCELLED = "notifications/cancelled" """ - Span name for client-side MCP request operations. - Used for all outgoing MCP requests (initialize, list tools, call tool). + Notification cancelling a previously-issued request. """ - SPAN_MCP_SERVER = "span.mcp.server" + NOTIFICATIONS_INITIALIZED = "notifications/initialized" """ - Span name for client-side MCP list tools requests. + Notification indicating that the MCP client has been initialized. + """ + NOTIFICATIONS_PROGRESS = "notifications/progress" + """ + Notification indicating the progress for a long-running operation. + """ + RESOURCES_LIST = "resources/list" + """ + Request to list resources available on server. """ - - @staticmethod - def client_call_tool(tool_name: str) -> str: - """ - Generate span name for client-side MCP tool call requests. - - Args: - tool_name: Name of the tool being called - - Returns: - Formatted span name like "mcp.call_tool.echo", "mcp.call_tool.search" - """ - return f"mcp.call_tool.{tool_name}" - TOOLS_LIST = "tools/list" """ - Span name for server-side MCP list tools handling. - Tracks server processing of tool discovery requests. + Request to list tools available on server. + """ + TOOLS_CALL = "tools/call" + """ + Request to call a tool. + """ + INITIALIZED = "initialize" + """ + Request to initialize the MCP client. """ - - @staticmethod - def tools_call(tool_name: str) -> str: - """ - Generate span name for server-side MCP tool call handling. - - Args: - tool_name: Name of the tool being called - - Returns: - Formatted span name like "tools/echo", "tools/search" - """ - return f"tools/{tool_name}" - - -class MCPOperations: - """Standard operation names for MCP semantic conventions.""" - - INITIALIZE = "Notifications/Initialize" - """Operation name for MCP session initialization.""" - - LIST_TOOL = "ListTool" - """Operation name for MCP tool discovery.""" - - UNKNOWN_OPERATION = "UnknownOperation" - """Fallback operation name for unrecognized MCP operations.""" From 323b87a6530906ab271026ec3515ce4501465736 Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Wed, 6 Aug 2025 22:46:54 -0700 Subject: [PATCH 33/41] add more span attribute logic --- .../instrumentation/mcp/instrumentation.py | 76 ++++++++++--------- .../distro/instrumentation/mcp/semconv.py | 24 +++++- 2 files changed, 62 insertions(+), 38 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index 2b252b8a4..7f0c221bc 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -17,13 +17,13 @@ from .semconv import ( MCPSpanAttributes, - MCPMethodNameValue, + MCPMethodValue, ) class McpInstrumentor(BaseInstrumentor): """ - An instrumentation class for MCP: https://modelcontextprotocol.io/overview. + An instrumentation class for MCP: https://modelcontextprotocol.io/overview """ def __init__(self, **kwargs): @@ -46,11 +46,13 @@ def _instrument(self, **kwargs: Any) -> None: ), "mcp.shared.session", ) + + register_post_import_hook( lambda _: wrap_function_wrapper( "mcp.server.lowlevel.server", "Server._handle_request", - self._wrap_stdio_handle_request, + self._wrap_server_handle_request, ), "mcp.server.lowlevel.server", ) @@ -65,11 +67,14 @@ def _wrap_session_send_request( import mcp.types as types """ - Instruments MCP client-side stdio request sending, see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports + Instruments MCP client-side request sending for both stdio and Streamable HTTP transport, + see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports + + This is the master function responsible for sending requests from the client to the MCP server. + See: + - https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/shared/session.py#L220 + - https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/client/session_group.py#L233 - This is the master function responsible for sending requests from the client to the MCP server. See: - https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/shared/session.py#L220 - The instrumented MCP client intercepts the request to obtain attributes for creating client-side span, extracts the current trace context, and embeds it into the request's params._meta field before forwarding the request to the MCP server. @@ -105,7 +110,7 @@ async def async_wrapper(): parent_span = {} self.propagators.inject(carrier=parent_span, context=span_ctx) - McpInstrumentor._configure_mcp_span(client_span, request, request_id) + McpInstrumentor._generate_mcp_span_attrs(client_span, request, request_id) request_as_json["params"]["_meta"].update(parent_span) # Reconstruct request object with injected trace context @@ -123,13 +128,15 @@ async def async_wrapper(): return async_wrapper() - async def _wrap_stdio_handle_request( + async def _wrap_server_handle_request( self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Any: """ - Instruments MCP server-side stdio request handling, see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports + Instruments MCP server-side request handling for both stdio and Streamable HTTP transport, + see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports - This is the core function responsible for processing incoming requests on the MCP server. See: + This is the core function responsible for processing incoming requests on the MCP server. + See: https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616 The instrumented MCP server intercepts incoming requests to extract tracing context from @@ -158,7 +165,7 @@ async def _wrap_stdio_handle_request( "span.mcp.server", kind=trace.SpanKind.SERVER, context=parent_ctx ) as server_span: - self._configure_mcp_span(server_span, incoming_req, request_id) + self._generate_mcp_span_attrs(server_span, incoming_req, request_id) try: result = await wrapped(*args, **kwargs) @@ -170,45 +177,44 @@ async def _wrap_stdio_handle_request( raise @staticmethod - def _configure_mcp_span(span: trace.Span, request, request_id: Optional[str]) -> None: + def _generate_mcp_span_attrs(span: trace.Span, request, request_id: Optional[str]) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import + # Client-side: request is of type ClientRequest which contains the Union of different RootModel types + # Server-side: request is passed the RootModel + # See: https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/types.py#L1220 if hasattr(request, "root"): request = request.root if request_id: span.set_attribute(MCPSpanAttributes.MCP_REQUEST_ID, request_id) - - if isinstance(request, types.ListToolsRequest): - span.update_name(MCPMethodNameValue.TOOLS_LIST) - span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.TOOLS_LIST) - return + + span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, request.method) if isinstance(request, types.CallToolRequest): tool_name = request.params.name - - span.update_name(f"{MCPMethodNameValue.TOOLS_CALL} {request.params.name}") - span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.TOOLS_CALL) + span.update_name(f"{MCPMethodValue.TOOLS_CALL} {tool_name}") span.set_attribute(MCPSpanAttributes.MCP_TOOL_NAME, tool_name) - if hasattr(request.params, "arguments"): + if request.params.arguments: for arg_name, arg_val in request.params.arguments.items(): span.set_attribute( f"{MCPSpanAttributes.MCP_REQUEST_ARGUMENT}.{arg_name}", McpInstrumentor.serialize(arg_val) ) - - if isinstance(request, types.InitializeRequest): - span.update_name(MCPMethodNameValue.INITIALIZED) - span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.INITIALIZED) - - if isinstance(request, types.CancelledNotification): - span.update_name(MCPMethodNameValue.NOTIFICATIONS_CANCELLED) - span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.NOTIFICATIONS_CANCELLED) - - if isinstance(request, types.ToolListChangedNotification): - span.update_name(MCPMethodNameValue.NOTIFICATIONS_CANCELLED) - span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.NOTIFICATIONS_CANCELLED) - + return + if isinstance(request, types.GetPromptRequest): + prompt_name = request.params.name + span.update_name(f"{MCPMethodValue.PROMPTS_GET} {prompt_name}") + span.set_attribute(MCPSpanAttributes.MCP_PROMPT_NAME, prompt_name) + return + if isinstance(request, (types.ReadResourceRequest, types.SubscribeRequest, types.UnsubscribeRequest)): + resource_uri = str(request.params.uri) + span.update_name(f"{MCPSpanAttributes.MCP_RESOURCE_URI} {resource_uri}") + span.set_attribute(MCPSpanAttributes.MCP_RESOURCE_URI, resource_uri) + return + + span.update_name(request.method) + @staticmethod def serialize(args): try: diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py index 099b8c7da..611c5dcaa 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py @@ -26,7 +26,6 @@ class MCPSpanAttributes: MCP_TOOL_NAME = "mcp.tool.name" """ The name of the tool provided in the request. - Conditionally Required when operation is related to a specific tool. """ MCP_REQUEST_ARGUMENT = "mcp.request.argument" @@ -34,9 +33,23 @@ class MCPSpanAttributes: Full attribute: mcp.request.argument. Additional arguments passed to the request within params object. being the normalized argument name (lowercase), the value being the argument value. """ + MCP_PROMPT_NAME = "mcp.prompt.name" + """ + The name of the prompt or prompt template provided in the request or response + Conditionally Required when operation is related to a specific prompt. + """ + MCP_RESOURCE_URI = "mcp.resource.uri" + """ + The value of the resource uri. + Conditionally Required when the client executes a request type that includes a resource URI parameter. + """ + MCP_TRANSPORT_TYPE = "mcp.transport.type" + """ + The transport type used for MCP communication. + Examples: stdio, streamable_http + """ - -class MCPMethodNameValue: +class MCPMethodValue: NOTIFICATIONS_CANCELLED = "notifications/cancelled" """ @@ -67,3 +80,8 @@ class MCPMethodNameValue: """ Request to initialize the MCP client. """ + + PROMPTS_GET = "prompts/get" + """ + Request to get a prompt. + """ From 3902c60e88a6a577f1753bc5e9dcfb6beb6d605c Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Thu, 7 Aug 2025 13:40:44 -0700 Subject: [PATCH 34/41] add span support for notifications --- .../instrumentation/mcp/instrumentation.py | 165 +++++++++--------- 1 file changed, 84 insertions(+), 81 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index 7f0c221bc..c6f07ab09 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -7,7 +7,7 @@ from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper from opentelemetry import context, trace -from opentelemetry.trace.status import Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap from opentelemetry.semconv.trace import SpanAttributes @@ -22,6 +22,9 @@ class McpInstrumentor(BaseInstrumentor): + _DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client" + _DEFAULT_SERVER_SPAN_NAME = "span.mcp.server" + """ An instrumentation class for MCP: https://modelcontextprotocol.io/overview """ @@ -35,19 +38,22 @@ def instrumentation_dependencies(self) -> Collection[str]: return ("mcp >= 1.8.1",) def _instrument(self, **kwargs: Any) -> None: - # TODO: add instrumentation for Streamable Http transport - # See: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports - register_post_import_hook( lambda _: wrap_function_wrapper( "mcp.shared.session", "BaseSession.send_request", - self._wrap_session_send_request, + self._wrap_session_send, + ), + "mcp.shared.session", + ) + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.shared.session", + "BaseSession.send_notification", + self._wrap_session_send, ), "mcp.shared.session", ) - - register_post_import_hook( lambda _: wrap_function_wrapper( "mcp.server.lowlevel.server", @@ -61,69 +67,50 @@ def _uninstrument(self, **kwargs: Any) -> None: unwrap("mcp.shared.session", "BaseSession.send_request") unwrap("mcp.server.lowlevel.server", "Server._handle_request") - def _wrap_session_send_request( + def _wrap_session_send( self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Callable: import mcp.types as types - """ - Instruments MCP client-side request sending for both stdio and Streamable HTTP transport, - see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports - - This is the master function responsible for sending requests from the client to the MCP server. - See: - - https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/shared/session.py#L220 - - https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/client/session_group.py#L233 - - The instrumented MCP client intercepts the request to obtain attributes for creating client-side span, extracts - the current trace context, and embeds it into the request's params._meta field - before forwarding the request to the MCP server. - - Args: - wrapped: The original BaseSession.send_request method being instrumented - instance: The BaseSession instance handling the stdio communication - args: Positional arguments passed to the original send_request method, containing the ClientRequest - kwargs: Keyword arguments passed to the original send_request method - """ - async def async_wrapper(): - request: Optional[types.ClientRequest] = args[0] if len(args) > 0 else None - - if not request: + message = args[0] if len(args) > 0 else None + if not message: return await wrapped(*args, **kwargs) - request_id = None + is_client = isinstance(message, (types.ClientRequest, types.ClientNotification)) + request_id: Optional[int] = getattr(instance, "_request_id", None) + span_name = self._DEFAULT_SERVER_SPAN_NAME + span_kind = SpanKind.SERVER - if hasattr(instance, "_request_id"): - request_id = instance._request_id + if is_client: + span_name = self._DEFAULT_CLIENT_SPAN_NAME + span_kind = SpanKind.CLIENT - request_as_json = request.model_dump(by_alias=True, mode="json", exclude_none=True) + message_json = message.model_dump(by_alias=True, mode="json", exclude_none=True) - if "params" not in request_as_json: - request_as_json["params"] = {} - if "_meta" not in request_as_json["params"]: - request_as_json["params"]["_meta"] = {} + if "params" not in message_json: + message_json["params"] = {} + if "_meta" not in message_json["params"]: + message_json["params"]["_meta"] = {} - with self.tracer.start_as_current_span("span.mcp.client", kind=trace.SpanKind.CLIENT) as client_span: + with self.tracer.start_as_current_span(name=span_name, kind=span_kind) as span: + ctx = trace.set_span_in_context(span) + carrier = {} + self.propagators.inject(carrier=carrier, context=ctx) + message_json["params"]["_meta"].update(carrier) - span_ctx = trace.set_span_in_context(client_span) - parent_span = {} - self.propagators.inject(carrier=parent_span, context=span_ctx) + McpInstrumentor._generate_mcp_req_attrs(span, message, request_id) - McpInstrumentor._generate_mcp_span_attrs(client_span, request, request_id) - request_as_json["params"]["_meta"].update(parent_span) - - # Reconstruct request object with injected trace context - modified_request = request.model_validate(request_as_json) - new_args = (modified_request,) + args[1:] + modified_message = message.model_validate(message_json) + new_args = (modified_message,) + args[1:] try: result = await wrapped(*new_args, **kwargs) - client_span.set_status(Status(StatusCode.OK)) + span.set_status(Status(StatusCode.OK)) return result except Exception as e: - client_span.set_status(Status(StatusCode.ERROR, str(e))) - client_span.record_exception(e) + span.set_status(Status(StatusCode.ERROR, str(e))) + span.record_exception(e) raise return async_wrapper() @@ -132,10 +119,10 @@ async def _wrap_server_handle_request( self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Any: """ - Instruments MCP server-side request handling for both stdio and Streamable HTTP transport, + Instruments MCP server-side request handling for both stdio and Streamable HTTP transport, see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports - This is the core function responsible for processing incoming requests on the MCP server. + This is the core function responsible for processing incoming requests on the MCP server. See: https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616 @@ -150,74 +137,90 @@ async def _wrap_server_handle_request( kwargs: Keyword arguments passed to the original _handle_request method """ incoming_req = args[1] if len(args) > 1 else None + + if not incoming_req: + return await wrapped(*args, **kwargs) + request_id = None carrier = {} - if incoming_req and hasattr(incoming_req, "id"): + if hasattr(incoming_req, "id") and incoming_req.id: request_id = incoming_req.id - if incoming_req and hasattr(incoming_req, "params") and hasattr(incoming_req.params, "meta"): + if hasattr(incoming_req, "params") and hasattr(incoming_req.params, "meta") and incoming_req.meta: carrier = incoming_req.params.meta.model_dump() + # If MCP client is instrumented then params._meta field will contain the + # parent trace context. parent_ctx = self.propagators.extract(carrier=carrier) - if parent_ctx: - with self.tracer.start_as_current_span( - "span.mcp.server", kind=trace.SpanKind.SERVER, context=parent_ctx - ) as server_span: + with self.tracer.start_as_current_span( + "span.mcp.server", kind=trace.SpanKind.SERVER, context=parent_ctx + ) as server_span: - self._generate_mcp_span_attrs(server_span, incoming_req, request_id) + self._generate_mcp_req_attrs(server_span, incoming_req, request_id) - try: - result = await wrapped(*args, **kwargs) - server_span.set_status(Status(StatusCode.OK)) - return result - except Exception as e: - server_span.set_status(Status(StatusCode.ERROR, str(e))) - server_span.record_exception(e) - raise + try: + result = await wrapped(*args, **kwargs) + server_span.set_status(Status(StatusCode.OK)) + return result + except Exception as e: + server_span.set_status(Status(StatusCode.ERROR, str(e))) + server_span.record_exception(e) + raise @staticmethod - def _generate_mcp_span_attrs(span: trace.Span, request, request_id: Optional[str]) -> None: + def _generate_mcp_req_attrs(span: trace.Span, request, request_id: Optional[int]) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import - # Client-side: request is of type ClientRequest which contains the Union of different RootModel types - # Server-side: request is passed the RootModel + """ + Populates the given span with MCP semantic convention attributes based on the request type. + These semantic conventions are based off: https://github.com/open-telemetry/semantic-conventions/pull/2083 + which are currently in development and are considered unstable. + + Args: + span: The MCP span to be enriched with MCP attributes + request: The MCP request object, from Client Side it is of type ClientRequestModel and from server side it's of type RootModel + request_id: Unique identifier for the request. In theory, this should never be Optional since all requests made from MCP client to server will contain a request id. + """ + + # Client-side request type will be ClientRequest which has root as field + # Server-side: request type will be the root object passed from ClientRequest # See: https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/types.py#L1220 if hasattr(request, "root"): request = request.root if request_id: span.set_attribute(MCPSpanAttributes.MCP_REQUEST_ID, request_id) - + span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, request.method) if isinstance(request, types.CallToolRequest): tool_name = request.params.name span.update_name(f"{MCPMethodValue.TOOLS_CALL} {tool_name}") span.set_attribute(MCPSpanAttributes.MCP_TOOL_NAME, tool_name) - + request.params.arguments if request.params.arguments: for arg_name, arg_val in request.params.arguments.items(): span.set_attribute( f"{MCPSpanAttributes.MCP_REQUEST_ARGUMENT}.{arg_name}", McpInstrumentor.serialize(arg_val) ) - return + return if isinstance(request, types.GetPromptRequest): prompt_name = request.params.name span.update_name(f"{MCPMethodValue.PROMPTS_GET} {prompt_name}") span.set_attribute(MCPSpanAttributes.MCP_PROMPT_NAME, prompt_name) - return + return if isinstance(request, (types.ReadResourceRequest, types.SubscribeRequest, types.UnsubscribeRequest)): resource_uri = str(request.params.uri) span.update_name(f"{MCPSpanAttributes.MCP_RESOURCE_URI} {resource_uri}") span.set_attribute(MCPSpanAttributes.MCP_RESOURCE_URI, resource_uri) - return - + return + span.update_name(request.method) - + @staticmethod - def serialize(args): + def serialize(args: dict[str, Any]) -> str: try: return json.dumps(args) except Exception: - return str(args) + return "" From 51aaba908eea080d02fe6dbff0e8cb510f81028f Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Thu, 7 Aug 2025 14:50:10 -0700 Subject: [PATCH 35/41] add support for notifications + refactoring --- .../instrumentation/mcp/instrumentation.py | 166 +++++++++++++----- 1 file changed, 121 insertions(+), 45 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index c6f07ab09..7553bc4c4 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -2,11 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass import json -from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, cast +from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, Union, cast -from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper +from wrapt import register_post_import_hook, wrap_function_wrapper -from opentelemetry import context, trace +from opentelemetry import trace from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap @@ -22,13 +22,13 @@ class McpInstrumentor(BaseInstrumentor): - _DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client" - _DEFAULT_SERVER_SPAN_NAME = "span.mcp.server" - """ An instrumentation class for MCP: https://modelcontextprotocol.io/overview """ + _DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client" + _DEFAULT_SERVER_SPAN_NAME = "span.mcp.server" + def __init__(self, **kwargs): super().__init__() self.propagators = kwargs.get("propagators") or get_global_textmap() @@ -62,29 +62,56 @@ def _instrument(self, **kwargs: Any) -> None: ), "mcp.server.lowlevel.server", ) + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.server.lowlevel.server", + "Server._handle_notification", + self._wrap_server_handle_notification, + ), + "mcp.server.lowlevel.server", + ) def _uninstrument(self, **kwargs: Any) -> None: unwrap("mcp.shared.session", "BaseSession.send_request") + unwrap("mcp.shared.session", "BaseSession.send_notification") unwrap("mcp.server.lowlevel.server", "Server._handle_request") + unwrap("mcp.server.lowlevel.server", "Server._handle_notification") def _wrap_session_send( self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Callable: - import mcp.types as types + """ + Instruments MCP client and server request/notification sending for both stdio and Streamable HTTP transport, + see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports + + See: + - https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/shared/session.py#L220 + - https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/shared/session.py#L296 + + This instrumentation intercepts the requests/notification messages sent between client and server to obtain attributes for creating span, injects + the current trace context, and embeds it into the request's params._meta field before forwarding the request to the MCP server. + + Args: + wrapped: The original BaseSession.send_request/send_notification method + instance: The BaseSession instance + args: Positional arguments passed to the original send_request/send_notification method + kwargs: Keyword arguments passed to the original send_request/send_notification method + """ + from mcp.types import ClientRequest, ClientNotification, ServerRequest, ServerNotification async def async_wrapper(): - message = args[0] if len(args) > 0 else None + message: Optional[Union[ClientRequest, ClientNotification, ServerRequest, ServerNotification]] = ( + args[0] if len(args) > 0 else None + ) + if not message: return await wrapped(*args, **kwargs) - is_client = isinstance(message, (types.ClientRequest, types.ClientNotification)) request_id: Optional[int] = getattr(instance, "_request_id", None) - span_name = self._DEFAULT_SERVER_SPAN_NAME - span_kind = SpanKind.SERVER + span_name, span_kind = self._DEFAULT_SERVER_SPAN_NAME, SpanKind.SERVER - if is_client: - span_name = self._DEFAULT_CLIENT_SPAN_NAME - span_kind = SpanKind.CLIENT + if isinstance(message, (ClientRequest, ClientNotification)): + span_name, span_kind = self._DEFAULT_CLIENT_SPAN_NAME, SpanKind.CLIENT message_json = message.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -99,7 +126,7 @@ async def async_wrapper(): self.propagators.inject(carrier=carrier, context=ctx) message_json["params"]["_meta"].update(carrier) - McpInstrumentor._generate_mcp_req_attrs(span, message, request_id) + McpInstrumentor._generate_mcp_message_attrs(span, message, request_id) modified_message = message.model_validate(message_json) new_args = (modified_message,) + args[1:] @@ -126,10 +153,6 @@ async def _wrap_server_handle_request( See: https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616 - The instrumented MCP server intercepts incoming requests to extract tracing context from - the request's params._meta field, creates server-side spans linked to the originating client spans, - and processes the request while maintaining trace continuity. - Args: wrapped: The original Server._handle_request method being instrumented instance: The MCP Server instance processing the stdio communication @@ -137,27 +160,72 @@ async def _wrap_server_handle_request( kwargs: Keyword arguments passed to the original _handle_request method """ incoming_req = args[1] if len(args) > 1 else None + return await self._wrap_server_message_handler(wrapped, instance, args, kwargs, incoming_msg=incoming_req) + + async def _wrap_server_handle_notification( + self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: + """ + Instruments MCP server-side notification handling for both stdio and Streamable HTTP transport, + This is the core function responsible for processing incoming notifications on the MCP server instance. + See: + https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616 + + Args: + wrapped: The original Server._handle_notification method being instrumented + instance: The MCP Server instance processing the stdio communication + args: Positional arguments passed to the original _handle_request method, containing the incoming request + kwargs: Keyword arguments passed to the original _handle_request method + """ + incoming_notif = args[0] if len(args) > 0 else None + return await self._wrap_server_message_handler(wrapped, instance, args, kwargs, incoming_msg=incoming_notif) + + async def _wrap_server_message_handler( + self, + wrapped: Callable, + instance: Any, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + incoming_msg: Optional[Any], + ) -> Any: + """ + Instruments MCP server-side request/notification handling for both stdio and Streamable HTTP transport, + see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports + + See: + https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616 + + The instrumented MCP server intercepts incoming requests/notification messages from the client to extract tracing context from + the messages's params._meta field and creates server-side spans linked to the originating client spans. - if not incoming_req: + Args: + wrapped: The original Server._handle_notification/_handle_request method being instrumented + instance: The Server instance + args: Positional arguments passed to the original _handle_request/ method, containing the incoming request + kwargs: Keyword arguments passed to the original _handle_request method + incoming_msg: The incoming message from the client, can be one of: ClientRequest or ClientNotification + """ + if not incoming_msg: return await wrapped(*args, **kwargs) request_id = None carrier = {} - if hasattr(incoming_req, "id") and incoming_req.id: - request_id = incoming_req.id - if hasattr(incoming_req, "params") and hasattr(incoming_req.params, "meta") and incoming_req.meta: - carrier = incoming_req.params.meta.model_dump() + # Request IDs are only present in Request messages not Notifications. + if hasattr(incoming_msg, "id") and incoming_msg.id: + request_id = incoming_msg.id + + # If the client is instrumented then params._meta field will contain the trace context. + if hasattr(incoming_msg, "params") and hasattr(incoming_msg.params, "meta") and incoming_msg.params.meta: + carrier = incoming_msg.params.meta.model_dump() - # If MCP client is instrumented then params._meta field will contain the - # parent trace context. parent_ctx = self.propagators.extract(carrier=carrier) with self.tracer.start_as_current_span( - "span.mcp.server", kind=trace.SpanKind.SERVER, context=parent_ctx + self._DEFAULT_SERVER_SPAN_NAME, kind=SpanKind.SERVER, context=parent_ctx ) as server_span: - self._generate_mcp_req_attrs(server_span, incoming_req, request_id) + self._generate_mcp_message_attrs(server_span, incoming_msg, request_id) try: result = await wrapped(*args, **kwargs) @@ -169,54 +237,62 @@ async def _wrap_server_handle_request( raise @staticmethod - def _generate_mcp_req_attrs(span: trace.Span, request, request_id: Optional[int]) -> None: + def _generate_mcp_message_attrs(span: trace.Span, message, request_id: Optional[int]) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import """ - Populates the given span with MCP semantic convention attributes based on the request type. + Populates the given span with MCP semantic convention attributes based on the message type. These semantic conventions are based off: https://github.com/open-telemetry/semantic-conventions/pull/2083 which are currently in development and are considered unstable. Args: span: The MCP span to be enriched with MCP attributes - request: The MCP request object, from Client Side it is of type ClientRequestModel and from server side it's of type RootModel - request_id: Unique identifier for the request. In theory, this should never be Optional since all requests made from MCP client to server will contain a request id. + message: The MCP message object, from client side it is of type ClientRequestModel/ClientNotificationModel and from server side it gets passed as type RootModel + request_id: Unique identifier for the request or None if the message is a notification. """ # Client-side request type will be ClientRequest which has root as field # Server-side: request type will be the root object passed from ClientRequest # See: https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/types.py#L1220 - if hasattr(request, "root"): - request = request.root + if hasattr(message, "root"): + message = message.root if request_id: span.set_attribute(MCPSpanAttributes.MCP_REQUEST_ID, request_id) - span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, request.method) + span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, message.method) - if isinstance(request, types.CallToolRequest): - tool_name = request.params.name + if isinstance(message, types.CallToolRequest): + tool_name = message.params.name span.update_name(f"{MCPMethodValue.TOOLS_CALL} {tool_name}") span.set_attribute(MCPSpanAttributes.MCP_TOOL_NAME, tool_name) - request.params.arguments - if request.params.arguments: - for arg_name, arg_val in request.params.arguments.items(): + message.params.arguments + if message.params.arguments: + for arg_name, arg_val in message.params.arguments.items(): span.set_attribute( f"{MCPSpanAttributes.MCP_REQUEST_ARGUMENT}.{arg_name}", McpInstrumentor.serialize(arg_val) ) return - if isinstance(request, types.GetPromptRequest): - prompt_name = request.params.name + if isinstance(message, types.GetPromptRequest): + prompt_name = message.params.name span.update_name(f"{MCPMethodValue.PROMPTS_GET} {prompt_name}") span.set_attribute(MCPSpanAttributes.MCP_PROMPT_NAME, prompt_name) return - if isinstance(request, (types.ReadResourceRequest, types.SubscribeRequest, types.UnsubscribeRequest)): - resource_uri = str(request.params.uri) + if isinstance( + message, + ( + types.ReadResourceRequest, + types.SubscribeRequest, + types.UnsubscribeRequest, + types.ResourceUpdatedNotification, + ), + ): + resource_uri = str(message.params.uri) span.update_name(f"{MCPSpanAttributes.MCP_RESOURCE_URI} {resource_uri}") span.set_attribute(MCPSpanAttributes.MCP_RESOURCE_URI, resource_uri) return - span.update_name(request.method) + span.update_name(message.method) @staticmethod def serialize(args: dict[str, Any]) -> str: From 9c0e57aef1c96f5497d54033d989667fc52386bf Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Thu, 7 Aug 2025 16:01:08 -0700 Subject: [PATCH 36/41] add rpc service name for MCP server --- .../distro/_aws_metric_attribute_generator.py | 5 +++++ .../distro/instrumentation/mcp/instrumentation.py | 7 +++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_metric_attribute_generator.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_metric_attribute_generator.py index 88eedb152..6e5ad6695 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_metric_attribute_generator.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/_aws_metric_attribute_generator.py @@ -41,6 +41,7 @@ AWS_STEPFUNCTIONS_STATEMACHINE_ARN, ) from amazon.opentelemetry.distro._aws_resource_attribute_configurator import get_service_attribute +from amazon.opentelemetry.distro.instrumentation.mcp.semconv import MCPSpanAttributes from amazon.opentelemetry.distro._aws_span_processing_util import ( LOCAL_ROOT, MAX_KEYWORD_LENGTH, @@ -97,6 +98,7 @@ _SERVER_SOCKET_PORT: str = SpanAttributes.SERVER_SOCKET_PORT _AWS_TABLE_NAMES: str = SpanAttributes.AWS_DYNAMODB_TABLE_NAMES _AWS_BUCKET_NAME: str = SpanAttributes.AWS_S3_BUCKET +_MCP_METHOD_NAME: str = MCPSpanAttributes.MCP_METHOD_NAME # Normalized remote service names for supported AWS services _NORMALIZED_DYNAMO_DB_SERVICE_NAME: str = "AWS::DynamoDB" @@ -263,6 +265,9 @@ def _set_remote_service_and_operation(span: ReadableSpan, attributes: BoundedAtt elif is_key_present(span, _GRAPHQL_OPERATION_TYPE): remote_service = _GRAPHQL remote_operation = _get_remote_operation(span, _GRAPHQL_OPERATION_TYPE) + elif is_key_present(span, _MCP_METHOD_NAME) and is_key_present(span, _RPC_SERVICE): + remote_service = _normalize_remote_service_name(span, _get_remote_service(span, _RPC_SERVICE)) + remote_operation = _get_remote_operation(span, _MCP_METHOD_NAME) # Peer service takes priority as RemoteService over everything but AWS Remote. if is_key_present(span, _PEER_SERVICE) and not is_key_present(span, AWS_REMOTE_SERVICE): diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index 7553bc4c4..a5cadadb8 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -108,10 +108,12 @@ async def async_wrapper(): return await wrapped(*args, **kwargs) request_id: Optional[int] = getattr(instance, "_request_id", None) - span_name, span_kind = self._DEFAULT_SERVER_SPAN_NAME, SpanKind.SERVER + span_name = self._DEFAULT_SERVER_SPAN_NAME + span_kind = SpanKind.SERVER if isinstance(message, (ClientRequest, ClientNotification)): - span_name, span_kind = self._DEFAULT_CLIENT_SPAN_NAME, SpanKind.CLIENT + span_name = self._DEFAULT_CLIENT_SPAN_NAME + span_kind = SpanKind.CLIENT message_json = message.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -225,6 +227,7 @@ async def _wrap_server_message_handler( self._DEFAULT_SERVER_SPAN_NAME, kind=SpanKind.SERVER, context=parent_ctx ) as server_span: + server_span.set_attribute(SpanAttributes.RPC_SERVICE, instance.name) self._generate_mcp_message_attrs(server_span, incoming_msg, request_id) try: From dc91c6262bfa4cd761b0d2609173d2b42056590e Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Thu, 7 Aug 2025 16:22:22 -0700 Subject: [PATCH 37/41] removed typo --- .../opentelemetry/distro/instrumentation/mcp/instrumentation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index a5cadadb8..951918fc2 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -269,7 +269,6 @@ def _generate_mcp_message_attrs(span: trace.Span, message, request_id: Optional[ tool_name = message.params.name span.update_name(f"{MCPMethodValue.TOOLS_CALL} {tool_name}") span.set_attribute(MCPSpanAttributes.MCP_TOOL_NAME, tool_name) - message.params.arguments if message.params.arguments: for arg_name, arg_val in message.params.arguments.items(): span.set_attribute( From d7666ca15fece43972b4015435ef6d21b4af0050 Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Thu, 7 Aug 2025 17:10:48 -0700 Subject: [PATCH 38/41] modify contract tests --- .../mcp/test_mcpinstrumentor.py | 943 +++++++++++++++++ .../distro/test_mcpinstrumentor.py | 945 ------------------ .../images/applications/mcp/client.py | 52 +- .../images/applications/mcp/mcp_server.py | 11 + .../tests/test/amazon/mcp/mcp_test.py | 92 +- 5 files changed, 1013 insertions(+), 1030 deletions(-) create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/instrumentation/mcp/test_mcpinstrumentor.py delete mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/instrumentation/mcp/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/instrumentation/mcp/test_mcpinstrumentor.py new file mode 100644 index 000000000..096209ce9 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/instrumentation/mcp/test_mcpinstrumentor.py @@ -0,0 +1,943 @@ +# # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# # SPDX-License-Identifier: Apache-2.0 + +# """ +# Unit tests for MCPInstrumentor - testing actual mcpinstrumentor methods +# """ + +# import asyncio +# import sys +# import unittest +# from typing import Any, Dict, List, Optional +# from unittest.mock import MagicMock + +# from amazon.opentelemetry.distro.instrumentation.mcp import version + + +# # Mock the mcp module to prevent import errors +# class MockMCPTypes: +# class CallToolRequest: +# pass + +# class ListToolsRequest: +# pass + +# class InitializeRequest: +# pass + +# class ClientResult: +# pass + + +# mock_mcp_types = MockMCPTypes() +# sys.modules["mcp"] = MagicMock() +# sys.modules["mcp.types"] = mock_mcp_types +# sys.modules["mcp.shared"] = MagicMock() +# sys.modules["mcp.shared.session"] = MagicMock() +# sys.modules["mcp.server"] = MagicMock() +# sys.modules["mcp.server.lowlevel"] = MagicMock() +# sys.modules["mcp.server.lowlevel.server"] = MagicMock() + + +# class SimpleSpanContext: +# """Simple mock span context without using MagicMock""" + +# def __init__(self, trace_id: int, span_id: int) -> None: +# self.trace_id = trace_id +# self.span_id = span_id + + +# class SimpleTracerProvider: +# """Simple mock tracer provider without using MagicMock""" + +# def __init__(self) -> None: +# self.get_tracer_called = False +# self.tracer_name: Optional[str] = None + +# def get_tracer(self, name: str) -> str: +# self.get_tracer_called = True +# self.tracer_name = name +# return "mock_tracer_from_provider" + + +# class TestInjectTraceContext(unittest.TestCase): +# """Test the _inject_trace_context method""" + +# def setUp(self) -> None: +# self.instrumentor = MCPInstrumentor() + +# def test_inject_trace_context_empty_dict(self) -> None: +# """Test injecting trace context into empty dictionary""" +# # Setup +# request_data = {} +# span_ctx = SimpleSpanContext(trace_id=12345, span_id=67890) + +# # Execute - Actually test the mcpinstrumentor method +# self.instrumentor._inject_trace_context(request_data, span_ctx) + +# # Verify - now uses traceparent W3C format +# self.assertIn("params", request_data) +# self.assertIn("_meta", request_data["params"]) +# self.assertIn("traceparent", request_data["params"]["_meta"]) + +# # Verify traceparent format: "00-{trace_id:032x}-{span_id:016x}-01" +# traceparent = request_data["params"]["_meta"]["traceparent"] +# self.assertTrue(traceparent.startswith("00-")) +# self.assertTrue(traceparent.endswith("-01")) +# parts = traceparent.split("-") +# self.assertEqual(len(parts), 4) +# self.assertEqual(int(parts[1], 16), 12345) # trace_id +# self.assertEqual(int(parts[2], 16), 67890) # span_id + +# def test_inject_trace_context_existing_params(self) -> None: +# """Test injecting trace context when params already exist""" +# # Setup +# request_data = {"params": {"existing_field": "test_value"}} +# span_ctx = SimpleSpanContext(trace_id=99999, span_id=11111) + +# # Execute - Actually test the mcpinstrumentor method +# self.instrumentor._inject_trace_context(request_data, span_ctx) + +# # Verify the existing field is preserved and traceparent is added +# self.assertEqual(request_data["params"]["existing_field"], "test_value") +# self.assertIn("_meta", request_data["params"]) +# self.assertIn("traceparent", request_data["params"]["_meta"]) + +# # Verify traceparent format contains correct trace/span IDs +# traceparent = request_data["params"]["_meta"]["traceparent"] +# parts = traceparent.split("-") +# self.assertEqual(int(parts[1], 16), 99999) # trace_id +# self.assertEqual(int(parts[2], 16), 11111) # span_id + + +# class TestTracerProvider(unittest.TestCase): +# """Test the tracer provider kwargs logic in _instrument method""" + +# def setUp(self) -> None: +# self.instrumentor = MCPInstrumentor() +# # Reset tracer to ensure test isolation +# if hasattr(self.instrumentor, "tracer"): +# delattr(self.instrumentor, "tracer") + +# def test_instrument_without_tracer_provider_kwargs(self) -> None: +# """Test _instrument method when no tracer_provider in kwargs - should use default tracer""" +# # Execute - Actually test the mcpinstrumentor method +# with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer, unittest.mock.patch( +# "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" +# ): +# mock_get_tracer.return_value = "default_tracer" +# self.instrumentor._instrument() + +# # Verify - tracer should be set from trace.get_tracer +# self.assertTrue(hasattr(self.instrumentor, "tracer")) +# self.assertEqual(self.instrumentor.tracer, "default_tracer") +# mock_get_tracer.assert_called_with("instrumentation.mcp") + +# def test_instrument_with_tracer_provider_kwargs(self) -> None: +# """Test _instrument method when tracer_provider is in kwargs - should use provider's tracer""" +# # Setup +# provider = SimpleTracerProvider() + +# # Execute - Actually test the mcpinstrumentor method +# with unittest.mock.patch( +# "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" +# ): +# self.instrumentor._instrument(tracer_provider=provider) + +# # Verify - tracer should be set from the provided tracer_provider +# self.assertTrue(hasattr(self.instrumentor, "tracer")) +# self.assertEqual(self.instrumentor.tracer, "mock_tracer_from_provider") +# self.assertTrue(provider.get_tracer_called) +# self.assertEqual(provider.tracer_name, "instrumentation.mcp") + + +# class TestInstrumentationDependencies(unittest.TestCase): +# """Test the instrumentation_dependencies method""" + +# def setUp(self) -> None: +# self.instrumentor = MCPInstrumentor() + +# def test_instrumentation_dependencies(self) -> None: +# """Test that instrumentation_dependencies method returns the expected dependencies""" +# # Execute - Actually test the mcpinstrumentor method +# dependencies = self.instrumentor.instrumentation_dependencies() + +# # Verify - should return the _instruments collection +# self.assertIsNotNone(dependencies) +# # Should contain mcp dependency +# self.assertIn("mcp >= 1.6.0", dependencies) + + +# class TestTraceContextInjection(unittest.TestCase): +# """Test trace context injection using actual mcpinstrumentor methods""" + +# def setUp(self) -> None: +# self.instrumentor = MCPInstrumentor() + +# def test_trace_context_injection_with_realistic_request(self) -> None: +# """Test actual trace context injection using mcpinstrumentor._inject_trace_context with realistic MCP request""" + +# # Create a realistic MCP request structure +# class CallToolRequest: +# def __init__(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None) -> None: +# self.root = self +# self.params = CallToolParams(tool_name, arguments) + +# def model_dump( +# self, by_alias: bool = True, mode: str = "json", exclude_none: bool = True +# ) -> Dict[str, Any]: +# result = {"method": "call_tool", "params": {"name": self.params.name}} +# if self.params.arguments: +# result["params"]["arguments"] = self.params.arguments +# # Include _meta if it exists (trace context injection point) +# if hasattr(self.params, "_meta") and self.params._meta: +# result["params"]["_meta"] = self.params._meta +# return result + +# # converting raw dictionary data back into an instance of this class +# @classmethod +# def model_validate(cls, data: Dict[str, Any]) -> "CallToolRequest": +# instance = cls(data["params"]["name"], data["params"].get("arguments")) +# # Restore _meta field if present +# if "_meta" in data["params"]: +# instance.params._meta = data["params"]["_meta"] +# return instance + +# class CallToolParams: +# def __init__(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> None: +# self.name = name +# self.arguments = arguments +# self._meta: Optional[Dict[str, Any]] = None # Will hold trace context + +# # Client creates original request +# client_request = CallToolRequest("create_metric", {"metric_name": "response_time", "value": 250}) + +# # Client injects trace context using ACTUAL mcpinstrumentor method +# original_trace_context = SimpleSpanContext(trace_id=98765, span_id=43210) +# request_data = client_request.model_dump() + +# # This is the actual mcpinstrumentor method we're testing +# self.instrumentor._inject_trace_context(request_data, original_trace_context) + +# # Create modified request with trace context +# modified_request = CallToolRequest.model_validate(request_data) + +# # Verify the actual mcpinstrumentor method worked correctly +# client_data = modified_request.model_dump() +# self.assertIn("_meta", client_data["params"]) +# self.assertIn("traceparent", client_data["params"]["_meta"]) + +# # Verify traceparent format contains correct trace/span IDs +# traceparent = client_data["params"]["_meta"]["traceparent"] +# parts = traceparent.split("-") +# self.assertEqual(int(parts[1], 16), 98765) # trace_id +# self.assertEqual(int(parts[2], 16), 43210) # span_id + +# # Verify the tool call data is also preserved +# self.assertEqual(client_data["params"]["name"], "create_metric") +# self.assertEqual(client_data["params"]["arguments"]["metric_name"], "response_time") + + +# class TestInstrumentedMCPServer(unittest.TestCase): +# """Test mcpinstrumentor with a mock MCP server to verify end-to-end functionality""" + +# def setUp(self) -> None: +# self.instrumentor = MCPInstrumentor() +# # Initialize tracer so the instrumentor can work +# mock_tracer = MagicMock() +# self.instrumentor.tracer = mock_tracer + +# def test_no_trace_context_fallback(self) -> None: +# """Test graceful handling when no trace context is present on server side""" + +# class MockServerNoTrace: +# @staticmethod +# async def _handle_request(session: Any, request: Any) -> Dict[str, Any]: +# return {"success": True, "handled_without_trace": True} + +# class MockServerRequestNoTrace: +# def __init__(self, tool_name: str) -> None: +# self.params = MockServerRequestParamsNoTrace(tool_name) + +# class MockServerRequestParamsNoTrace: +# def __init__(self, name: str) -> None: +# self.name = name +# self.meta: Optional[Any] = None # No trace context + +# mock_server = MockServerNoTrace() +# server_request = MockServerRequestNoTrace("create_metric") + +# # Setup mocks +# mock_tracer = MagicMock() +# mock_span = MagicMock() +# mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span +# mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + +# # Test server handling without trace context (fallback scenario) +# with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( +# "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} +# ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( +# self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" +# ): + +# result = asyncio.run( +# self.instrumentor._wrap_handle_request(mock_server._handle_request, None, (None, server_request), {}) +# ) + +# # Verify graceful fallback - no tracing spans should be created when no trace context +# # The wrapper should call the original function without creating distributed trace spans +# self.assertEqual(result["success"], True) +# self.assertEqual(result["handled_without_trace"], True) + +# # Should not create traced spans when no trace context is present +# mock_tracer.start_as_current_span.assert_not_called() + +# # pylint: disable=too-many-locals,too-many-statements +# def test_end_to_end_client_server_communication( +# self, +# ) -> None: +# """Test where server actually receives what client sends (including injected trace context)""" + +# # Create realistic request/response classes +# class MCPRequest: +# def __init__( +# self, tool_name: str, arguments: Optional[Dict[str, Any]] = None, method: str = "call_tool" +# ) -> None: +# self.root = self +# self.params = MCPRequestParams(tool_name, arguments) +# self.method = method + +# def model_dump( +# self, by_alias: bool = True, mode: str = "json", exclude_none: bool = True +# ) -> Dict[str, Any]: +# result = {"method": self.method, "params": {"name": self.params.name}} +# if self.params.arguments: +# result["params"]["arguments"] = self.params.arguments +# # Include _meta if it exists (for trace context) +# if hasattr(self.params, "_meta") and self.params._meta: +# result["params"]["_meta"] = self.params._meta +# return result + +# @classmethod +# def model_validate(cls, data: Dict[str, Any]) -> "MCPRequest": +# method = data.get("method", "call_tool") +# instance = cls(data["params"]["name"], data["params"].get("arguments"), method) +# # Restore _meta field if present +# if "_meta" in data["params"]: +# instance.params._meta = data["params"]["_meta"] +# return instance + +# class MCPRequestParams: +# def __init__(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> None: +# self.name = name +# self.arguments = arguments +# self._meta: Optional[Dict[str, Any]] = None + +# class MCPServerRequest: +# def __init__(self, client_request_data: Dict[str, Any]) -> None: +# """Server request created from client's serialized data""" +# self.method = client_request_data.get("method", "call_tool") +# self.params = MCPServerRequestParams(client_request_data["params"]) + +# class MCPServerRequestParams: +# def __init__(self, params_data: Dict[str, Any]) -> None: +# self.name = params_data["name"] +# self.arguments = params_data.get("arguments") +# # Extract traceparent from _meta if present +# if "_meta" in params_data and "traceparent" in params_data["_meta"]: +# self.meta = MCPServerRequestMeta(params_data["_meta"]["traceparent"]) +# else: +# self.meta = None + +# class MCPServerRequestMeta: +# def __init__(self, traceparent: str) -> None: +# self.traceparent = traceparent + +# # Mock client and server that actually communicate +# class EndToEndMCPSystem: +# def __init__(self) -> None: +# self.communication_log: List[str] = [] +# self.last_sent_request: Optional[Any] = None + +# async def client_send_request(self, request: Any) -> Dict[str, Any]: +# """Client sends request - captures what gets sent""" +# self.communication_log.append("CLIENT: Preparing to send request") +# self.last_sent_request = request # Capture the modified request + +# # Simulate sending over network - serialize the request +# serialized_request = request.model_dump() +# self.communication_log.append(f"CLIENT: Sent {serialized_request}") + +# # Return client response +# return {"success": True, "client_response": "Request sent successfully"} + +# async def server_handle_request(self, session: Any, server_request: Any) -> Dict[str, Any]: +# """Server handles the request it received""" +# self.communication_log.append(f"SERVER: Received request for {server_request.params.name}") + +# # Check if traceparent was received +# if server_request.params.meta and server_request.params.meta.traceparent: +# traceparent = server_request.params.meta.traceparent +# # Parse traceparent to extract trace_id and span_id +# parts = traceparent.split("-") +# if len(parts) == 4: +# trace_id = int(parts[1], 16) +# span_id = int(parts[2], 16) +# self.communication_log.append( +# f"SERVER: Found trace context - trace_id: {trace_id}, " f"span_id: {span_id}" +# ) +# else: +# self.communication_log.append("SERVER: Invalid traceparent format") +# else: +# self.communication_log.append("SERVER: No trace context found") + +# return {"success": True, "server_response": f"Handled {server_request.params.name}"} + +# # Create the end-to-end system +# e2e_system = EndToEndMCPSystem() + +# # Create original client request +# original_request = MCPRequest("create_metric", {"name": "cpu_usage", "value": 85}) + +# # Setup OpenTelemetry mocks +# mock_tracer = MagicMock() +# mock_span = MagicMock() +# mock_span_context = MagicMock() +# mock_span_context.trace_id = 12345 +# mock_span_context.span_id = 67890 +# mock_span.get_span_context.return_value = mock_span_context +# mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span +# mock_tracer.start_as_current_span.return_value.__exit__.return_value = None + +# # STEP 1: Client sends request through instrumentation +# with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( +# "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} +# ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): +# # Override the setup tracer with the properly mocked one +# self.instrumentor.tracer = mock_tracer + +# client_result = asyncio.run( +# self.instrumentor._wrap_send_request(e2e_system.client_send_request, None, (original_request,), {}) +# ) + +# # Verify client side worked +# self.assertEqual(client_result["success"], True) +# self.assertIn("CLIENT: Preparing to send request", e2e_system.communication_log) + +# # Get the request that was actually sent (with trace context injected) +# sent_request = e2e_system.last_sent_request +# sent_request_data = sent_request.model_dump() + +# # Verify traceparent was injected by client instrumentation +# self.assertIn("_meta", sent_request_data["params"]) +# self.assertIn("traceparent", sent_request_data["params"]["_meta"]) + +# # Parse and verify traceparent contains correct trace/span IDs +# traceparent = sent_request_data["params"]["_meta"]["traceparent"] +# parts = traceparent.split("-") +# self.assertEqual(int(parts[1], 16), 12345) # trace_id +# self.assertEqual(int(parts[2], 16), 67890) # span_id + +# # STEP 2: Server receives the EXACT request that client sent +# # Create server request from the client's serialized data +# server_request = MCPServerRequest(sent_request_data) + +# # Reset tracer mock for server side +# mock_tracer.reset_mock() + +# # Server processes the request it received +# with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( +# "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} +# ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( +# self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" +# ): + +# server_result = asyncio.run( +# self.instrumentor._wrap_handle_request( +# e2e_system.server_handle_request, None, (None, server_request), {} +# ) +# ) + +# # Verify server side worked +# self.assertEqual(server_result["success"], True) + +# # Verify end-to-end trace context propagation +# self.assertIn("SERVER: Found trace context - trace_id: 12345, span_id: 67890", e2e_system.communication_log) + +# # Verify the server received the exact same data the client sent +# self.assertEqual(server_request.params.name, "create_metric") +# self.assertEqual(server_request.params.arguments["name"], "cpu_usage") +# self.assertEqual(server_request.params.arguments["value"], 85) + +# # Verify the traceparent made it through end-to-end +# self.assertIsNotNone(server_request.params.meta) +# self.assertIsNotNone(server_request.params.meta.traceparent) + +# # Parse traceparent and verify trace/span IDs +# traceparent = server_request.params.meta.traceparent +# parts = traceparent.split("-") +# self.assertEqual(int(parts[1], 16), 12345) # trace_id +# self.assertEqual(int(parts[2], 16), 67890) # span_id + +# # Verify complete communication flow +# expected_log_entries = [ +# "CLIENT: Preparing to send request", +# "CLIENT: Sent", # Part of the serialized request log +# "SERVER: Received request for create_metric", +# "SERVER: Found trace context - trace_id: 12345, span_id: 67890", +# ] + +# for expected_entry in expected_log_entries: +# self.assertTrue( +# any(expected_entry in log_entry for log_entry in e2e_system.communication_log), +# f"Expected log entry '{expected_entry}' not found in: {e2e_system.communication_log}", +# ) + + +# class TestMCPInstrumentorEdgeCases(unittest.TestCase): +# """Test edge cases and error conditions for MCP instrumentor""" + +# def setUp(self) -> None: +# self.instrumentor = MCPInstrumentor() + +# def test_invalid_traceparent_format(self) -> None: +# """Test handling of malformed traceparent headers""" +# invalid_formats = [ +# "invalid-format", +# "00-invalid-hex-01", +# "00-12345-67890", # Missing part +# "00-12345-67890-01-extra", # Too many parts +# "", # Empty string +# ] + +# for invalid_format in invalid_formats: +# with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): +# result = self.instrumentor._extract_span_context_from_traceparent(invalid_format) +# self.assertIsNone(result, f"Should return None for invalid format: {invalid_format}") + +# def test_version_import(self) -> None: +# """Test that version can be imported""" +# self.assertIsNotNone(version) + +# def test_constants_import(self) -> None: +# """Test that constants can be imported""" +# self.assertIsNotNone(MCPEnvironmentVariables.SERVER_NAME) + +# def test_add_client_attributes_default_server_name(self) -> None: +# """Test _add_client_attributes uses default server name""" +# mock_span = MagicMock() + +# class MockRequest: +# def __init__(self) -> None: +# self.params = MockParams() + +# class MockParams: +# def __init__(self) -> None: +# self.name = "test_tool" + +# request = MockRequest() +# self.instrumentor._add_client_attributes(mock_span, "test_operation", request) + +# # Verify default server name is used +# mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") +# mock_span.set_attribute.assert_any_call("rpc.method", "test_operation") +# mock_span.set_attribute.assert_any_call("mcp.tool.name", "test_tool") + +# def test_add_client_attributes_without_tool_name(self) -> None: +# """Test _add_client_attributes when request has no tool name""" +# mock_span = MagicMock() + +# class MockRequestNoTool: +# def __init__(self) -> None: +# self.params = None + +# request = MockRequestNoTool() +# self.instrumentor._add_client_attributes(mock_span, "test_operation", request) + +# # Should still set service and method, but not tool name +# mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") +# mock_span.set_attribute.assert_any_call("rpc.method", "test_operation") + +# def test_add_server_attributes_without_tool_name(self) -> None: +# """Test _add_server_attributes when request has no tool name""" +# mock_span = MagicMock() + +# class MockRequestNoTool: +# def __init__(self) -> None: +# self.params = None + +# request = MockRequestNoTool() +# self.instrumentor._add_server_attributes(mock_span, "test_operation", request) + +# # Should not set any attributes for server when no tool name +# mock_span.set_attribute.assert_not_called() + +# def test_inject_trace_context_empty_request(self) -> None: +# """Test trace context injection with minimal request data""" +# request_data = {} +# span_ctx = SimpleSpanContext(trace_id=111, span_id=222) + +# self.instrumentor._inject_trace_context(request_data, span_ctx) + +# # Should create params and _meta structure +# self.assertIn("params", request_data) +# self.assertIn("_meta", request_data["params"]) +# self.assertIn("traceparent", request_data["params"]["_meta"]) + +# # Verify traceparent format +# traceparent = request_data["params"]["_meta"]["traceparent"] +# parts = traceparent.split("-") +# self.assertEqual(len(parts), 4) +# self.assertEqual(int(parts[1], 16), 111) # trace_id +# self.assertEqual(int(parts[2], 16), 222) # span_id + +# def test_uninstrument(self) -> None: +# """Test _uninstrument method removes instrumentation""" +# with unittest.mock.patch( +# "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap" +# ) as mock_unwrap: +# self.instrumentor._uninstrument() +# self.assertEqual(mock_unwrap.call_count, 2) +# mock_unwrap.assert_any_call("mcp.shared.session", "BaseSession.send_request") +# mock_unwrap.assert_any_call("mcp.server.lowlevel.server", "Server._handle_request") + +# def test_extract_span_context_valid_traceparent(self) -> None: +# """Test _extract_span_context_from_traceparent with valid format""" +# # Use correct hex values: 12345 = 0x3039, 67890 = 0x10932 +# valid_traceparent = "00-0000000000003039-0000000000010932-01" +# result = self.instrumentor._extract_span_context_from_traceparent(valid_traceparent) +# self.assertIsNotNone(result) +# self.assertEqual(result.trace_id, 12345) +# self.assertEqual(result.span_id, 67890) +# self.assertTrue(result.is_remote) + +# def test_extract_span_context_value_error(self) -> None: +# """Test _extract_span_context_from_traceparent with invalid hex values""" +# invalid_hex_traceparent = "00-invalid-hex-values-01" +# result = self.instrumentor._extract_span_context_from_traceparent(invalid_hex_traceparent) +# self.assertIsNone(result) + +# def test_instrument_method_coverage(self) -> None: +# """Test _instrument method registers hooks""" +# with unittest.mock.patch( +# "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" +# ) as mock_register: +# self.instrumentor._instrument() +# self.assertEqual(mock_register.call_count, 2) + + +# class TestWrapSendRequestEdgeCases(unittest.TestCase): +# """Test _wrap_send_request edge cases""" + +# def setUp(self) -> None: +# self.instrumentor = MCPInstrumentor() +# mock_tracer = MagicMock() +# mock_span = MagicMock() +# mock_span_context = MagicMock() +# mock_span_context.trace_id = 12345 +# mock_span_context.span_id = 67890 +# mock_span.get_span_context.return_value = mock_span_context +# mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span +# mock_tracer.start_as_current_span.return_value.__exit__.return_value = None +# self.instrumentor.tracer = mock_tracer + +# def test_wrap_send_request_no_request_in_args(self) -> None: +# """Test _wrap_send_request when no request in args""" + +# async def mock_wrapped(): +# return {"result": "no_request"} + +# result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (), {})) +# self.assertEqual(result["result"], "no_request") + +# def test_wrap_send_request_request_in_kwargs(self) -> None: +# """Test _wrap_send_request when request is in kwargs""" + +# class MockRequest: +# def __init__(self): +# self.root = self +# self.params = MockParams() + +# @staticmethod +# def model_dump(**kwargs): +# return {"method": "test", "params": {"name": "test_tool"}} + +# @classmethod +# def model_validate(cls, data): +# return cls() + +# class MockParams: +# def __init__(self): +# self.name = "test_tool" + +# async def mock_wrapped(*args, **kwargs): +# return {"result": "kwargs_request"} + +# request = MockRequest() + +# with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): +# result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (), {"request": request})) +# self.assertEqual(result["result"], "kwargs_request") + +# def test_wrap_send_request_no_root_attribute(self) -> None: +# """Test _wrap_send_request when request has no root attribute""" + +# class MockRequestNoRoot: +# def __init__(self): +# self.params = MockParams() + +# @staticmethod +# def model_dump(**kwargs): +# return {"method": "test", "params": {"name": "test_tool"}} + +# @classmethod +# def model_validate(cls, data): +# return cls() + +# class MockParams: +# def __init__(self): +# self.name = "test_tool" + +# async def mock_wrapped(*args, **kwargs): +# return {"result": "no_root"} + +# request = MockRequestNoRoot() + +# with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): +# result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (request,), {})) +# self.assertEqual(result["result"], "no_root") + + +# class TestWrapHandleRequestEdgeCases(unittest.TestCase): +# """Test _wrap_handle_request edge cases""" + +# def setUp(self) -> None: +# self.instrumentor = MCPInstrumentor() +# mock_tracer = MagicMock() +# mock_span = MagicMock() +# mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span +# mock_tracer.start_as_current_span.return_value.__exit__.return_value = None +# self.instrumentor.tracer = mock_tracer + +# def test_wrap_handle_request_no_request(self) -> None: +# """Test _wrap_handle_request when no request in args""" + +# async def mock_wrapped(*args, **kwargs): +# return {"result": "no_request"} + +# result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session",), {})) +# self.assertEqual(result["result"], "no_request") + +# def test_wrap_handle_request_no_params(self) -> None: +# """Test _wrap_handle_request when request has no params""" + +# class MockRequestNoParams: +# def __init__(self): +# self.params = None + +# async def mock_wrapped(*args, **kwargs): +# return {"result": "no_params"} + +# request = MockRequestNoParams() +# result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) +# self.assertEqual(result["result"], "no_params") + +# def test_wrap_handle_request_no_meta(self) -> None: +# """Test _wrap_handle_request when request params has no meta""" + +# class MockRequestNoMeta: +# def __init__(self): +# self.params = MockParamsNoMeta() + +# class MockParamsNoMeta: +# def __init__(self): +# self.meta = None + +# async def mock_wrapped(*args, **kwargs): +# return {"result": "no_meta"} + +# request = MockRequestNoMeta() +# result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) +# self.assertEqual(result["result"], "no_meta") + +# def test_wrap_handle_request_with_valid_traceparent(self) -> None: +# """Test _wrap_handle_request with valid traceparent""" + +# class MockRequestWithTrace: +# def __init__(self): +# self.params = MockParamsWithTrace() + +# class MockParamsWithTrace: +# def __init__(self): +# self.meta = MockMeta() + +# class MockMeta: +# def __init__(self): +# self.traceparent = "00-0000000000003039-0000000000010932-01" + +# async def mock_wrapped(*args, **kwargs): +# return {"result": "with_trace"} + +# request = MockRequestWithTrace() + +# with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( +# self.instrumentor, "_get_mcp_operation", return_value="tools/test" +# ): +# result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) +# self.assertEqual(result["result"], "with_trace") + + +# class TestInstrumentorStaticMethods(unittest.TestCase): +# """Test static methods of MCPInstrumentor""" + +# def test_instrumentation_dependencies_static(self) -> None: +# """Test instrumentation_dependencies as static method""" +# deps = MCPInstrumentor.instrumentation_dependencies() +# self.assertIn("mcp >= 1.6.0", deps) + +# def test_uninstrument_static(self) -> None: +# """Test _uninstrument as static method""" +# with unittest.mock.patch( +# "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap" +# ) as mock_unwrap: +# MCPInstrumentor._uninstrument() +# self.assertEqual(mock_unwrap.call_count, 2) +# mock_unwrap.assert_any_call("mcp.shared.session", "BaseSession.send_request") +# mock_unwrap.assert_any_call("mcp.server.lowlevel.server", "Server._handle_request") + + +# class TestEnvironmentVariableHandling(unittest.TestCase): +# """Test environment variable handling""" + +# def setUp(self) -> None: +# self.instrumentor = MCPInstrumentor() + +# def test_server_name_environment_variable(self) -> None: +# """Test that MCP_INSTRUMENTATION_SERVER_NAME environment variable is used""" +# mock_span = MagicMock() + +# class MockRequest: +# def __init__(self): +# self.params = MockParams() + +# class MockParams: +# def __init__(self): +# self.name = "test_tool" + +# # Test with environment variable set +# with unittest.mock.patch.dict("os.environ", {"MCP_INSTRUMENTATION_SERVER_NAME": "my-custom-server"}): +# request = MockRequest() +# self.instrumentor._add_client_attributes(mock_span, "test_operation", request) +# mock_span.set_attribute.assert_any_call("rpc.service", "my-custom-server") + +# def test_server_name_default_value(self) -> None: +# """Test that default server name is used when environment variable is not set""" +# mock_span = MagicMock() + +# class MockRequest: +# def __init__(self): +# self.params = MockParams() + +# class MockParams: +# def __init__(self): +# self.name = "test_tool" + +# # Test without environment variable +# with unittest.mock.patch.dict("os.environ", {}, clear=True): +# request = MockRequest() +# self.instrumentor._add_client_attributes(mock_span, "test_operation", request) +# mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") + + +# class TestTraceContextFormats(unittest.TestCase): +# """Test trace context format handling""" + +# def setUp(self) -> None: +# self.instrumentor = MCPInstrumentor() + +# def test_inject_trace_context_format(self) -> None: +# """Test that injected trace context follows W3C format""" +# request_data = {} +# span_ctx = SimpleSpanContext(trace_id=0x12345678901234567890123456789012, span_id=0x1234567890123456) +# self.instrumentor._inject_trace_context(request_data, span_ctx) +# traceparent = request_data["params"]["_meta"]["traceparent"] +# self.assertEqual(traceparent, "00-12345678901234567890123456789012-1234567890123456-01") + + +# class FakeParams: +# def __init__(self, name=None, meta=None): +# self.name = name +# self.meta = meta + + +# class FakeRequest: +# def __init__(self, params=None): +# self.params = params +# self.root = self + +# def model_dump(self, **kwargs): +# return {"method": "call_tool", "params": {"name": self.params.name if self.params else None}} + +# @classmethod +# def model_validate(cls, data): +# return cls(params=FakeParams(name=data["params"].get("name"))) + + +# class TestGetMCPOperation(unittest.TestCase): +# """Test _get_mcp_operation function with patched types""" + +# def setUp(self): +# self.instrumentor = MCPInstrumentor() + +# @unittest.mock.patch("mcp.types") +# def test_get_mcp_operation_coverage(self, mock_types): +# """Test _get_mcp_operation with all request types""" + +# # Create actual classes for isinstance checks +# class FakeListToolsRequest: +# pass + +# class FakeCallToolRequest: +# def __init__(self): +# self.params = type("Params", (), {"name": "test_tool"})() + +# # Set up mock types +# mock_types.ListToolsRequest = FakeListToolsRequest +# mock_types.CallToolRequest = FakeCallToolRequest + +# # Test ListToolsRequest path +# result1 = self.instrumentor._get_mcp_operation(FakeListToolsRequest()) +# self.assertEqual(result1, "tools/list") + +# # Test CallToolRequest path +# result2 = self.instrumentor._get_mcp_operation(FakeCallToolRequest()) +# self.assertEqual(result2, "tools/test_tool") + +# # Test unknown request type +# result3 = self.instrumentor._get_mcp_operation(object()) +# self.assertEqual(result3, "unknown") + +# @unittest.mock.patch("mcp.types") +# def test_generate_mcp_attributes_coverage(self, mock_types): +# """Test _generate_mcp_attributes with all request types""" + +# class FakeListToolsRequest: +# pass + +# class FakeCallToolRequest: +# def __init__(self): +# self.params = type("Params", (), {"name": "test_tool"})() + +# class FakeInitializeRequest: +# pass + +# mock_types.ListToolsRequest = FakeListToolsRequest +# mock_types.CallToolRequest = FakeCallToolRequest +# mock_types.InitializeRequest = FakeInitializeRequest + +# mock_span = MagicMock() +# self.instrumentor._generate_mcp_attributes(mock_span, FakeListToolsRequest(), True) +# self.instrumentor._generate_mcp_attributes(mock_span, FakeCallToolRequest(), True) +# self.instrumentor._generate_mcp_attributes(mock_span, FakeInitializeRequest(), True) +# mock_span.set_attribute.assert_any_call("mcp.list_tools", True) +# mock_span.set_attribute.assert_any_call("mcp.call_tool", True) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py deleted file mode 100644 index b5ae9c78a..000000000 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test_mcpinstrumentor.py +++ /dev/null @@ -1,945 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Unit tests for MCPInstrumentor - testing actual mcpinstrumentor methods -""" - -import asyncio -import sys -import unittest -from typing import Any, Dict, List, Optional -from unittest.mock import MagicMock - -from amazon.opentelemetry.distro.instrumentation.mcp import version -from amazon.opentelemetry.distro.instrumentation.mcp.constants import MCPEnvironmentVariables -from amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor import MCPInstrumentor - - -# Mock the mcp module to prevent import errors -class MockMCPTypes: - class CallToolRequest: - pass - - class ListToolsRequest: - pass - - class InitializeRequest: - pass - - class ClientResult: - pass - - -mock_mcp_types = MockMCPTypes() -sys.modules["mcp"] = MagicMock() -sys.modules["mcp.types"] = mock_mcp_types -sys.modules["mcp.shared"] = MagicMock() -sys.modules["mcp.shared.session"] = MagicMock() -sys.modules["mcp.server"] = MagicMock() -sys.modules["mcp.server.lowlevel"] = MagicMock() -sys.modules["mcp.server.lowlevel.server"] = MagicMock() - - -class SimpleSpanContext: - """Simple mock span context without using MagicMock""" - - def __init__(self, trace_id: int, span_id: int) -> None: - self.trace_id = trace_id - self.span_id = span_id - - -class SimpleTracerProvider: - """Simple mock tracer provider without using MagicMock""" - - def __init__(self) -> None: - self.get_tracer_called = False - self.tracer_name: Optional[str] = None - - def get_tracer(self, name: str) -> str: - self.get_tracer_called = True - self.tracer_name = name - return "mock_tracer_from_provider" - - -class TestInjectTraceContext(unittest.TestCase): - """Test the _inject_trace_context method""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - - def test_inject_trace_context_empty_dict(self) -> None: - """Test injecting trace context into empty dictionary""" - # Setup - request_data = {} - span_ctx = SimpleSpanContext(trace_id=12345, span_id=67890) - - # Execute - Actually test the mcpinstrumentor method - self.instrumentor._inject_trace_context(request_data, span_ctx) - - # Verify - now uses traceparent W3C format - self.assertIn("params", request_data) - self.assertIn("_meta", request_data["params"]) - self.assertIn("traceparent", request_data["params"]["_meta"]) - - # Verify traceparent format: "00-{trace_id:032x}-{span_id:016x}-01" - traceparent = request_data["params"]["_meta"]["traceparent"] - self.assertTrue(traceparent.startswith("00-")) - self.assertTrue(traceparent.endswith("-01")) - parts = traceparent.split("-") - self.assertEqual(len(parts), 4) - self.assertEqual(int(parts[1], 16), 12345) # trace_id - self.assertEqual(int(parts[2], 16), 67890) # span_id - - def test_inject_trace_context_existing_params(self) -> None: - """Test injecting trace context when params already exist""" - # Setup - request_data = {"params": {"existing_field": "test_value"}} - span_ctx = SimpleSpanContext(trace_id=99999, span_id=11111) - - # Execute - Actually test the mcpinstrumentor method - self.instrumentor._inject_trace_context(request_data, span_ctx) - - # Verify the existing field is preserved and traceparent is added - self.assertEqual(request_data["params"]["existing_field"], "test_value") - self.assertIn("_meta", request_data["params"]) - self.assertIn("traceparent", request_data["params"]["_meta"]) - - # Verify traceparent format contains correct trace/span IDs - traceparent = request_data["params"]["_meta"]["traceparent"] - parts = traceparent.split("-") - self.assertEqual(int(parts[1], 16), 99999) # trace_id - self.assertEqual(int(parts[2], 16), 11111) # span_id - - -class TestTracerProvider(unittest.TestCase): - """Test the tracer provider kwargs logic in _instrument method""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - # Reset tracer to ensure test isolation - if hasattr(self.instrumentor, "tracer"): - delattr(self.instrumentor, "tracer") - - def test_instrument_without_tracer_provider_kwargs(self) -> None: - """Test _instrument method when no tracer_provider in kwargs - should use default tracer""" - # Execute - Actually test the mcpinstrumentor method - with unittest.mock.patch("opentelemetry.trace.get_tracer") as mock_get_tracer, unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" - ): - mock_get_tracer.return_value = "default_tracer" - self.instrumentor._instrument() - - # Verify - tracer should be set from trace.get_tracer - self.assertTrue(hasattr(self.instrumentor, "tracer")) - self.assertEqual(self.instrumentor.tracer, "default_tracer") - mock_get_tracer.assert_called_with("instrumentation.mcp") - - def test_instrument_with_tracer_provider_kwargs(self) -> None: - """Test _instrument method when tracer_provider is in kwargs - should use provider's tracer""" - # Setup - provider = SimpleTracerProvider() - - # Execute - Actually test the mcpinstrumentor method - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" - ): - self.instrumentor._instrument(tracer_provider=provider) - - # Verify - tracer should be set from the provided tracer_provider - self.assertTrue(hasattr(self.instrumentor, "tracer")) - self.assertEqual(self.instrumentor.tracer, "mock_tracer_from_provider") - self.assertTrue(provider.get_tracer_called) - self.assertEqual(provider.tracer_name, "instrumentation.mcp") - - -class TestInstrumentationDependencies(unittest.TestCase): - """Test the instrumentation_dependencies method""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - - def test_instrumentation_dependencies(self) -> None: - """Test that instrumentation_dependencies method returns the expected dependencies""" - # Execute - Actually test the mcpinstrumentor method - dependencies = self.instrumentor.instrumentation_dependencies() - - # Verify - should return the _instruments collection - self.assertIsNotNone(dependencies) - # Should contain mcp dependency - self.assertIn("mcp >= 1.6.0", dependencies) - - -class TestTraceContextInjection(unittest.TestCase): - """Test trace context injection using actual mcpinstrumentor methods""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - - def test_trace_context_injection_with_realistic_request(self) -> None: - """Test actual trace context injection using mcpinstrumentor._inject_trace_context with realistic MCP request""" - - # Create a realistic MCP request structure - class CallToolRequest: - def __init__(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None) -> None: - self.root = self - self.params = CallToolParams(tool_name, arguments) - - def model_dump( - self, by_alias: bool = True, mode: str = "json", exclude_none: bool = True - ) -> Dict[str, Any]: - result = {"method": "call_tool", "params": {"name": self.params.name}} - if self.params.arguments: - result["params"]["arguments"] = self.params.arguments - # Include _meta if it exists (trace context injection point) - if hasattr(self.params, "_meta") and self.params._meta: - result["params"]["_meta"] = self.params._meta - return result - - # converting raw dictionary data back into an instance of this class - @classmethod - def model_validate(cls, data: Dict[str, Any]) -> "CallToolRequest": - instance = cls(data["params"]["name"], data["params"].get("arguments")) - # Restore _meta field if present - if "_meta" in data["params"]: - instance.params._meta = data["params"]["_meta"] - return instance - - class CallToolParams: - def __init__(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> None: - self.name = name - self.arguments = arguments - self._meta: Optional[Dict[str, Any]] = None # Will hold trace context - - # Client creates original request - client_request = CallToolRequest("create_metric", {"metric_name": "response_time", "value": 250}) - - # Client injects trace context using ACTUAL mcpinstrumentor method - original_trace_context = SimpleSpanContext(trace_id=98765, span_id=43210) - request_data = client_request.model_dump() - - # This is the actual mcpinstrumentor method we're testing - self.instrumentor._inject_trace_context(request_data, original_trace_context) - - # Create modified request with trace context - modified_request = CallToolRequest.model_validate(request_data) - - # Verify the actual mcpinstrumentor method worked correctly - client_data = modified_request.model_dump() - self.assertIn("_meta", client_data["params"]) - self.assertIn("traceparent", client_data["params"]["_meta"]) - - # Verify traceparent format contains correct trace/span IDs - traceparent = client_data["params"]["_meta"]["traceparent"] - parts = traceparent.split("-") - self.assertEqual(int(parts[1], 16), 98765) # trace_id - self.assertEqual(int(parts[2], 16), 43210) # span_id - - # Verify the tool call data is also preserved - self.assertEqual(client_data["params"]["name"], "create_metric") - self.assertEqual(client_data["params"]["arguments"]["metric_name"], "response_time") - - -class TestInstrumentedMCPServer(unittest.TestCase): - """Test mcpinstrumentor with a mock MCP server to verify end-to-end functionality""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - # Initialize tracer so the instrumentor can work - mock_tracer = MagicMock() - self.instrumentor.tracer = mock_tracer - - def test_no_trace_context_fallback(self) -> None: - """Test graceful handling when no trace context is present on server side""" - - class MockServerNoTrace: - @staticmethod - async def _handle_request(session: Any, request: Any) -> Dict[str, Any]: - return {"success": True, "handled_without_trace": True} - - class MockServerRequestNoTrace: - def __init__(self, tool_name: str) -> None: - self.params = MockServerRequestParamsNoTrace(tool_name) - - class MockServerRequestParamsNoTrace: - def __init__(self, name: str) -> None: - self.name = name - self.meta: Optional[Any] = None # No trace context - - mock_server = MockServerNoTrace() - server_request = MockServerRequestNoTrace("create_metric") - - # Setup mocks - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span - mock_tracer.start_as_current_span.return_value.__exit__.return_value = None - - # Test server handling without trace context (fallback scenario) - with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( - "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} - ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( - self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" - ): - - result = asyncio.run( - self.instrumentor._wrap_handle_request(mock_server._handle_request, None, (None, server_request), {}) - ) - - # Verify graceful fallback - no tracing spans should be created when no trace context - # The wrapper should call the original function without creating distributed trace spans - self.assertEqual(result["success"], True) - self.assertEqual(result["handled_without_trace"], True) - - # Should not create traced spans when no trace context is present - mock_tracer.start_as_current_span.assert_not_called() - - # pylint: disable=too-many-locals,too-many-statements - def test_end_to_end_client_server_communication( - self, - ) -> None: - """Test where server actually receives what client sends (including injected trace context)""" - - # Create realistic request/response classes - class MCPRequest: - def __init__( - self, tool_name: str, arguments: Optional[Dict[str, Any]] = None, method: str = "call_tool" - ) -> None: - self.root = self - self.params = MCPRequestParams(tool_name, arguments) - self.method = method - - def model_dump( - self, by_alias: bool = True, mode: str = "json", exclude_none: bool = True - ) -> Dict[str, Any]: - result = {"method": self.method, "params": {"name": self.params.name}} - if self.params.arguments: - result["params"]["arguments"] = self.params.arguments - # Include _meta if it exists (for trace context) - if hasattr(self.params, "_meta") and self.params._meta: - result["params"]["_meta"] = self.params._meta - return result - - @classmethod - def model_validate(cls, data: Dict[str, Any]) -> "MCPRequest": - method = data.get("method", "call_tool") - instance = cls(data["params"]["name"], data["params"].get("arguments"), method) - # Restore _meta field if present - if "_meta" in data["params"]: - instance.params._meta = data["params"]["_meta"] - return instance - - class MCPRequestParams: - def __init__(self, name: str, arguments: Optional[Dict[str, Any]] = None) -> None: - self.name = name - self.arguments = arguments - self._meta: Optional[Dict[str, Any]] = None - - class MCPServerRequest: - def __init__(self, client_request_data: Dict[str, Any]) -> None: - """Server request created from client's serialized data""" - self.method = client_request_data.get("method", "call_tool") - self.params = MCPServerRequestParams(client_request_data["params"]) - - class MCPServerRequestParams: - def __init__(self, params_data: Dict[str, Any]) -> None: - self.name = params_data["name"] - self.arguments = params_data.get("arguments") - # Extract traceparent from _meta if present - if "_meta" in params_data and "traceparent" in params_data["_meta"]: - self.meta = MCPServerRequestMeta(params_data["_meta"]["traceparent"]) - else: - self.meta = None - - class MCPServerRequestMeta: - def __init__(self, traceparent: str) -> None: - self.traceparent = traceparent - - # Mock client and server that actually communicate - class EndToEndMCPSystem: - def __init__(self) -> None: - self.communication_log: List[str] = [] - self.last_sent_request: Optional[Any] = None - - async def client_send_request(self, request: Any) -> Dict[str, Any]: - """Client sends request - captures what gets sent""" - self.communication_log.append("CLIENT: Preparing to send request") - self.last_sent_request = request # Capture the modified request - - # Simulate sending over network - serialize the request - serialized_request = request.model_dump() - self.communication_log.append(f"CLIENT: Sent {serialized_request}") - - # Return client response - return {"success": True, "client_response": "Request sent successfully"} - - async def server_handle_request(self, session: Any, server_request: Any) -> Dict[str, Any]: - """Server handles the request it received""" - self.communication_log.append(f"SERVER: Received request for {server_request.params.name}") - - # Check if traceparent was received - if server_request.params.meta and server_request.params.meta.traceparent: - traceparent = server_request.params.meta.traceparent - # Parse traceparent to extract trace_id and span_id - parts = traceparent.split("-") - if len(parts) == 4: - trace_id = int(parts[1], 16) - span_id = int(parts[2], 16) - self.communication_log.append( - f"SERVER: Found trace context - trace_id: {trace_id}, " f"span_id: {span_id}" - ) - else: - self.communication_log.append("SERVER: Invalid traceparent format") - else: - self.communication_log.append("SERVER: No trace context found") - - return {"success": True, "server_response": f"Handled {server_request.params.name}"} - - # Create the end-to-end system - e2e_system = EndToEndMCPSystem() - - # Create original client request - original_request = MCPRequest("create_metric", {"name": "cpu_usage", "value": 85}) - - # Setup OpenTelemetry mocks - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_span_context = MagicMock() - mock_span_context.trace_id = 12345 - mock_span_context.span_id = 67890 - mock_span.get_span_context.return_value = mock_span_context - mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span - mock_tracer.start_as_current_span.return_value.__exit__.return_value = None - - # STEP 1: Client sends request through instrumentation - with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( - "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} - ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): - # Override the setup tracer with the properly mocked one - self.instrumentor.tracer = mock_tracer - - client_result = asyncio.run( - self.instrumentor._wrap_send_request(e2e_system.client_send_request, None, (original_request,), {}) - ) - - # Verify client side worked - self.assertEqual(client_result["success"], True) - self.assertIn("CLIENT: Preparing to send request", e2e_system.communication_log) - - # Get the request that was actually sent (with trace context injected) - sent_request = e2e_system.last_sent_request - sent_request_data = sent_request.model_dump() - - # Verify traceparent was injected by client instrumentation - self.assertIn("_meta", sent_request_data["params"]) - self.assertIn("traceparent", sent_request_data["params"]["_meta"]) - - # Parse and verify traceparent contains correct trace/span IDs - traceparent = sent_request_data["params"]["_meta"]["traceparent"] - parts = traceparent.split("-") - self.assertEqual(int(parts[1], 16), 12345) # trace_id - self.assertEqual(int(parts[2], 16), 67890) # span_id - - # STEP 2: Server receives the EXACT request that client sent - # Create server request from the client's serialized data - server_request = MCPServerRequest(sent_request_data) - - # Reset tracer mock for server side - mock_tracer.reset_mock() - - # Server processes the request it received - with unittest.mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer), unittest.mock.patch.dict( - "sys.modules", {"mcp.types": MagicMock(), "mcp": MagicMock()} - ), unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( - self.instrumentor, "_get_mcp_operation", return_value="tools/create_metric" - ): - - server_result = asyncio.run( - self.instrumentor._wrap_handle_request( - e2e_system.server_handle_request, None, (None, server_request), {} - ) - ) - - # Verify server side worked - self.assertEqual(server_result["success"], True) - - # Verify end-to-end trace context propagation - self.assertIn("SERVER: Found trace context - trace_id: 12345, span_id: 67890", e2e_system.communication_log) - - # Verify the server received the exact same data the client sent - self.assertEqual(server_request.params.name, "create_metric") - self.assertEqual(server_request.params.arguments["name"], "cpu_usage") - self.assertEqual(server_request.params.arguments["value"], 85) - - # Verify the traceparent made it through end-to-end - self.assertIsNotNone(server_request.params.meta) - self.assertIsNotNone(server_request.params.meta.traceparent) - - # Parse traceparent and verify trace/span IDs - traceparent = server_request.params.meta.traceparent - parts = traceparent.split("-") - self.assertEqual(int(parts[1], 16), 12345) # trace_id - self.assertEqual(int(parts[2], 16), 67890) # span_id - - # Verify complete communication flow - expected_log_entries = [ - "CLIENT: Preparing to send request", - "CLIENT: Sent", # Part of the serialized request log - "SERVER: Received request for create_metric", - "SERVER: Found trace context - trace_id: 12345, span_id: 67890", - ] - - for expected_entry in expected_log_entries: - self.assertTrue( - any(expected_entry in log_entry for log_entry in e2e_system.communication_log), - f"Expected log entry '{expected_entry}' not found in: {e2e_system.communication_log}", - ) - - -class TestMCPInstrumentorEdgeCases(unittest.TestCase): - """Test edge cases and error conditions for MCP instrumentor""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - - def test_invalid_traceparent_format(self) -> None: - """Test handling of malformed traceparent headers""" - invalid_formats = [ - "invalid-format", - "00-invalid-hex-01", - "00-12345-67890", # Missing part - "00-12345-67890-01-extra", # Too many parts - "", # Empty string - ] - - for invalid_format in invalid_formats: - with unittest.mock.patch.dict("sys.modules", {"mcp.types": MagicMock()}): - result = self.instrumentor._extract_span_context_from_traceparent(invalid_format) - self.assertIsNone(result, f"Should return None for invalid format: {invalid_format}") - - def test_version_import(self) -> None: - """Test that version can be imported""" - self.assertIsNotNone(version) - - def test_constants_import(self) -> None: - """Test that constants can be imported""" - self.assertIsNotNone(MCPEnvironmentVariables.SERVER_NAME) - - def test_add_client_attributes_default_server_name(self) -> None: - """Test _add_client_attributes uses default server name""" - mock_span = MagicMock() - - class MockRequest: - def __init__(self) -> None: - self.params = MockParams() - - class MockParams: - def __init__(self) -> None: - self.name = "test_tool" - - request = MockRequest() - self.instrumentor._add_client_attributes(mock_span, "test_operation", request) - - # Verify default server name is used - mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") - mock_span.set_attribute.assert_any_call("rpc.method", "test_operation") - mock_span.set_attribute.assert_any_call("mcp.tool.name", "test_tool") - - def test_add_client_attributes_without_tool_name(self) -> None: - """Test _add_client_attributes when request has no tool name""" - mock_span = MagicMock() - - class MockRequestNoTool: - def __init__(self) -> None: - self.params = None - - request = MockRequestNoTool() - self.instrumentor._add_client_attributes(mock_span, "test_operation", request) - - # Should still set service and method, but not tool name - mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") - mock_span.set_attribute.assert_any_call("rpc.method", "test_operation") - - def test_add_server_attributes_without_tool_name(self) -> None: - """Test _add_server_attributes when request has no tool name""" - mock_span = MagicMock() - - class MockRequestNoTool: - def __init__(self) -> None: - self.params = None - - request = MockRequestNoTool() - self.instrumentor._add_server_attributes(mock_span, "test_operation", request) - - # Should not set any attributes for server when no tool name - mock_span.set_attribute.assert_not_called() - - def test_inject_trace_context_empty_request(self) -> None: - """Test trace context injection with minimal request data""" - request_data = {} - span_ctx = SimpleSpanContext(trace_id=111, span_id=222) - - self.instrumentor._inject_trace_context(request_data, span_ctx) - - # Should create params and _meta structure - self.assertIn("params", request_data) - self.assertIn("_meta", request_data["params"]) - self.assertIn("traceparent", request_data["params"]["_meta"]) - - # Verify traceparent format - traceparent = request_data["params"]["_meta"]["traceparent"] - parts = traceparent.split("-") - self.assertEqual(len(parts), 4) - self.assertEqual(int(parts[1], 16), 111) # trace_id - self.assertEqual(int(parts[2], 16), 222) # span_id - - def test_uninstrument(self) -> None: - """Test _uninstrument method removes instrumentation""" - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap" - ) as mock_unwrap: - self.instrumentor._uninstrument() - self.assertEqual(mock_unwrap.call_count, 2) - mock_unwrap.assert_any_call("mcp.shared.session", "BaseSession.send_request") - mock_unwrap.assert_any_call("mcp.server.lowlevel.server", "Server._handle_request") - - def test_extract_span_context_valid_traceparent(self) -> None: - """Test _extract_span_context_from_traceparent with valid format""" - # Use correct hex values: 12345 = 0x3039, 67890 = 0x10932 - valid_traceparent = "00-0000000000003039-0000000000010932-01" - result = self.instrumentor._extract_span_context_from_traceparent(valid_traceparent) - self.assertIsNotNone(result) - self.assertEqual(result.trace_id, 12345) - self.assertEqual(result.span_id, 67890) - self.assertTrue(result.is_remote) - - def test_extract_span_context_value_error(self) -> None: - """Test _extract_span_context_from_traceparent with invalid hex values""" - invalid_hex_traceparent = "00-invalid-hex-values-01" - result = self.instrumentor._extract_span_context_from_traceparent(invalid_hex_traceparent) - self.assertIsNone(result) - - def test_instrument_method_coverage(self) -> None: - """Test _instrument method registers hooks""" - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.register_post_import_hook" - ) as mock_register: - self.instrumentor._instrument() - self.assertEqual(mock_register.call_count, 2) - - -class TestWrapSendRequestEdgeCases(unittest.TestCase): - """Test _wrap_send_request edge cases""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_span_context = MagicMock() - mock_span_context.trace_id = 12345 - mock_span_context.span_id = 67890 - mock_span.get_span_context.return_value = mock_span_context - mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span - mock_tracer.start_as_current_span.return_value.__exit__.return_value = None - self.instrumentor.tracer = mock_tracer - - def test_wrap_send_request_no_request_in_args(self) -> None: - """Test _wrap_send_request when no request in args""" - - async def mock_wrapped(): - return {"result": "no_request"} - - result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (), {})) - self.assertEqual(result["result"], "no_request") - - def test_wrap_send_request_request_in_kwargs(self) -> None: - """Test _wrap_send_request when request is in kwargs""" - - class MockRequest: - def __init__(self): - self.root = self - self.params = MockParams() - - @staticmethod - def model_dump(**kwargs): - return {"method": "test", "params": {"name": "test_tool"}} - - @classmethod - def model_validate(cls, data): - return cls() - - class MockParams: - def __init__(self): - self.name = "test_tool" - - async def mock_wrapped(*args, **kwargs): - return {"result": "kwargs_request"} - - request = MockRequest() - - with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): - result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (), {"request": request})) - self.assertEqual(result["result"], "kwargs_request") - - def test_wrap_send_request_no_root_attribute(self) -> None: - """Test _wrap_send_request when request has no root attribute""" - - class MockRequestNoRoot: - def __init__(self): - self.params = MockParams() - - @staticmethod - def model_dump(**kwargs): - return {"method": "test", "params": {"name": "test_tool"}} - - @classmethod - def model_validate(cls, data): - return cls() - - class MockParams: - def __init__(self): - self.name = "test_tool" - - async def mock_wrapped(*args, **kwargs): - return {"result": "no_root"} - - request = MockRequestNoRoot() - - with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"): - result = asyncio.run(self.instrumentor._wrap_send_request(mock_wrapped, None, (request,), {})) - self.assertEqual(result["result"], "no_root") - - -class TestWrapHandleRequestEdgeCases(unittest.TestCase): - """Test _wrap_handle_request edge cases""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span - mock_tracer.start_as_current_span.return_value.__exit__.return_value = None - self.instrumentor.tracer = mock_tracer - - def test_wrap_handle_request_no_request(self) -> None: - """Test _wrap_handle_request when no request in args""" - - async def mock_wrapped(*args, **kwargs): - return {"result": "no_request"} - - result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session",), {})) - self.assertEqual(result["result"], "no_request") - - def test_wrap_handle_request_no_params(self) -> None: - """Test _wrap_handle_request when request has no params""" - - class MockRequestNoParams: - def __init__(self): - self.params = None - - async def mock_wrapped(*args, **kwargs): - return {"result": "no_params"} - - request = MockRequestNoParams() - result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) - self.assertEqual(result["result"], "no_params") - - def test_wrap_handle_request_no_meta(self) -> None: - """Test _wrap_handle_request when request params has no meta""" - - class MockRequestNoMeta: - def __init__(self): - self.params = MockParamsNoMeta() - - class MockParamsNoMeta: - def __init__(self): - self.meta = None - - async def mock_wrapped(*args, **kwargs): - return {"result": "no_meta"} - - request = MockRequestNoMeta() - result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) - self.assertEqual(result["result"], "no_meta") - - def test_wrap_handle_request_with_valid_traceparent(self) -> None: - """Test _wrap_handle_request with valid traceparent""" - - class MockRequestWithTrace: - def __init__(self): - self.params = MockParamsWithTrace() - - class MockParamsWithTrace: - def __init__(self): - self.meta = MockMeta() - - class MockMeta: - def __init__(self): - self.traceparent = "00-0000000000003039-0000000000010932-01" - - async def mock_wrapped(*args, **kwargs): - return {"result": "with_trace"} - - request = MockRequestWithTrace() - - with unittest.mock.patch.object(self.instrumentor, "_generate_mcp_attributes"), unittest.mock.patch.object( - self.instrumentor, "_get_mcp_operation", return_value="tools/test" - ): - result = asyncio.run(self.instrumentor._wrap_handle_request(mock_wrapped, None, ("session", request), {})) - self.assertEqual(result["result"], "with_trace") - - -class TestInstrumentorStaticMethods(unittest.TestCase): - """Test static methods of MCPInstrumentor""" - - def test_instrumentation_dependencies_static(self) -> None: - """Test instrumentation_dependencies as static method""" - deps = MCPInstrumentor.instrumentation_dependencies() - self.assertIn("mcp >= 1.6.0", deps) - - def test_uninstrument_static(self) -> None: - """Test _uninstrument as static method""" - with unittest.mock.patch( - "amazon.opentelemetry.distro.instrumentation.mcp.mcp_instrumentor.unwrap" - ) as mock_unwrap: - MCPInstrumentor._uninstrument() - self.assertEqual(mock_unwrap.call_count, 2) - mock_unwrap.assert_any_call("mcp.shared.session", "BaseSession.send_request") - mock_unwrap.assert_any_call("mcp.server.lowlevel.server", "Server._handle_request") - - -class TestEnvironmentVariableHandling(unittest.TestCase): - """Test environment variable handling""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - - def test_server_name_environment_variable(self) -> None: - """Test that MCP_INSTRUMENTATION_SERVER_NAME environment variable is used""" - mock_span = MagicMock() - - class MockRequest: - def __init__(self): - self.params = MockParams() - - class MockParams: - def __init__(self): - self.name = "test_tool" - - # Test with environment variable set - with unittest.mock.patch.dict("os.environ", {"MCP_INSTRUMENTATION_SERVER_NAME": "my-custom-server"}): - request = MockRequest() - self.instrumentor._add_client_attributes(mock_span, "test_operation", request) - mock_span.set_attribute.assert_any_call("rpc.service", "my-custom-server") - - def test_server_name_default_value(self) -> None: - """Test that default server name is used when environment variable is not set""" - mock_span = MagicMock() - - class MockRequest: - def __init__(self): - self.params = MockParams() - - class MockParams: - def __init__(self): - self.name = "test_tool" - - # Test without environment variable - with unittest.mock.patch.dict("os.environ", {}, clear=True): - request = MockRequest() - self.instrumentor._add_client_attributes(mock_span, "test_operation", request) - mock_span.set_attribute.assert_any_call("rpc.service", "mcp server") - - -class TestTraceContextFormats(unittest.TestCase): - """Test trace context format handling""" - - def setUp(self) -> None: - self.instrumentor = MCPInstrumentor() - - def test_inject_trace_context_format(self) -> None: - """Test that injected trace context follows W3C format""" - request_data = {} - span_ctx = SimpleSpanContext(trace_id=0x12345678901234567890123456789012, span_id=0x1234567890123456) - self.instrumentor._inject_trace_context(request_data, span_ctx) - traceparent = request_data["params"]["_meta"]["traceparent"] - self.assertEqual(traceparent, "00-12345678901234567890123456789012-1234567890123456-01") - - -class FakeParams: - def __init__(self, name=None, meta=None): - self.name = name - self.meta = meta - - -class FakeRequest: - def __init__(self, params=None): - self.params = params - self.root = self - - def model_dump(self, **kwargs): - return {"method": "call_tool", "params": {"name": self.params.name if self.params else None}} - - @classmethod - def model_validate(cls, data): - return cls(params=FakeParams(name=data["params"].get("name"))) - - -class TestGetMCPOperation(unittest.TestCase): - """Test _get_mcp_operation function with patched types""" - - def setUp(self): - self.instrumentor = MCPInstrumentor() - - @unittest.mock.patch("mcp.types") - def test_get_mcp_operation_coverage(self, mock_types): - """Test _get_mcp_operation with all request types""" - - # Create actual classes for isinstance checks - class FakeListToolsRequest: - pass - - class FakeCallToolRequest: - def __init__(self): - self.params = type("Params", (), {"name": "test_tool"})() - - # Set up mock types - mock_types.ListToolsRequest = FakeListToolsRequest - mock_types.CallToolRequest = FakeCallToolRequest - - # Test ListToolsRequest path - result1 = self.instrumentor._get_mcp_operation(FakeListToolsRequest()) - self.assertEqual(result1, "tools/list") - - # Test CallToolRequest path - result2 = self.instrumentor._get_mcp_operation(FakeCallToolRequest()) - self.assertEqual(result2, "tools/test_tool") - - # Test unknown request type - result3 = self.instrumentor._get_mcp_operation(object()) - self.assertEqual(result3, "unknown") - - @unittest.mock.patch("mcp.types") - def test_generate_mcp_attributes_coverage(self, mock_types): - """Test _generate_mcp_attributes with all request types""" - - class FakeListToolsRequest: - pass - - class FakeCallToolRequest: - def __init__(self): - self.params = type("Params", (), {"name": "test_tool"})() - - class FakeInitializeRequest: - pass - - mock_types.ListToolsRequest = FakeListToolsRequest - mock_types.CallToolRequest = FakeCallToolRequest - mock_types.InitializeRequest = FakeInitializeRequest - - mock_span = MagicMock() - self.instrumentor._generate_mcp_attributes(mock_span, FakeListToolsRequest(), True) - self.instrumentor._generate_mcp_attributes(mock_span, FakeCallToolRequest(), True) - self.instrumentor._generate_mcp_attributes(mock_span, FakeInitializeRequest(), True) - mock_span.set_attribute.assert_any_call("mcp.list_tools", True) - mock_span.set_attribute.assert_any_call("mcp.call_tool", True) diff --git a/contract-tests/images/applications/mcp/client.py b/contract-tests/images/applications/mcp/client.py index cfc0a7f01..13eaceffb 100644 --- a/contract-tests/images/applications/mcp/client.py +++ b/contract-tests/images/applications/mcp/client.py @@ -6,20 +6,40 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.types import PromptReference +from pydantic import AnyUrl class MCPHandler(BaseHTTPRequestHandler): def do_GET(self): # pylint: disable=invalid-name - if self.path == "/mcp/echo": - asyncio.run(self._call_mcp_tool("echo", {"text": "Hello from HTTP request!"})) - self.send_response(200) - self.end_headers() + if "call_tool" in self.path: + asyncio.run(self._call_mcp_server("call_tool")) + elif "list_tools" in self.path: + asyncio.run(self._call_mcp_server("list_tools")) + elif "list_prompts" in self.path: + asyncio.run(self._call_mcp_server("list_prompts")) + elif "list_resources" in self.path: + asyncio.run(self._call_mcp_server("list_resources")) + elif "read_resource" in self.path: + asyncio.run(self._call_mcp_server("read_resource")) + elif "get_prompt" in self.path: + asyncio.run(self._call_mcp_server("get_prompt")) + elif "complete" in self.path: + asyncio.run(self._call_mcp_server("complete")) + elif "set_logging_level" in self.path: + asyncio.run(self._call_mcp_server("set_logging_level")) + elif "ping" in self.path: + asyncio.run(self._call_mcp_server("ping")) else: self.send_response(404) self.end_headers() + return + + self.send_response(200) + self.end_headers() @staticmethod - async def _call_mcp_tool(tool_name, arguments): + async def _call_mcp_server(action, *args): server_env = { "OTEL_PYTHON_DISTRO": "aws_distro", "OTEL_PYTHON_CONFIGURATOR": "aws_configurator", @@ -35,7 +55,27 @@ async def _call_mcp_tool(tool_name, arguments): async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: await session.initialize() - result = await session.call_tool(tool_name, arguments) + result = None + if action == "list_tools": + result = await session.list_tools() + elif action == "call_tool": + result = await session.call_tool("echo", {"text": "Hello from HTTP request!"}) + elif action == "list_prompts": + result = await session.list_prompts() + elif action == "list_resources": + result = await session.list_resources() + elif action == "read_resource": + result = await session.read_resource(AnyUrl("file://sample.txt")) + elif action == "get_prompt": + result = await session.get_prompt("greeting", {"name": "Test User"}) + elif action == "complete": + prompt_ref = PromptReference(type="ref/prompt", name="greeting") + result = await session.complete(ref=prompt_ref, argument={"name": "completion_test"}) + elif action == "set_logging_level": + result = await session.set_logging_level("info") + elif action == "ping": + result = await session.send_ping() + return result diff --git a/contract-tests/images/applications/mcp/mcp_server.py b/contract-tests/images/applications/mcp/mcp_server.py index 24e48bca5..d2a082a7a 100644 --- a/contract-tests/images/applications/mcp/mcp_server.py +++ b/contract-tests/images/applications/mcp/mcp_server.py @@ -12,5 +12,16 @@ def echo(text: str) -> str: return f"Echo: {text}" +@mcp.resource(uri="file://sample.txt", name="Sample Resource") +def sample_resource() -> str: + """Sample MCP resource""" + return "This is a sample resource content" + + +@mcp.prompt(name="greeting", description="Generate a greeting message") +def greeting_prompt(name: str = "World") -> str: + """Generate a personalized greeting""" + return f"Hello, {name}! Welcome to our MCP server." + if __name__ == "__main__": mcp.run() diff --git a/contract-tests/tests/test/amazon/mcp/mcp_test.py b/contract-tests/tests/test/amazon/mcp/mcp_test.py index 1533c9335..5887e5f10 100644 --- a/contract-tests/tests/test/amazon/mcp/mcp_test.py +++ b/contract-tests/tests/test/amazon/mcp/mcp_test.py @@ -15,7 +15,15 @@ def get_application_image_name() -> str: def test_mcp_echo_tool(self): """Test MCP echo tool call creates proper spans""" - self.do_test_requests("mcp/echo", "GET", 200, 0, 0, tool_name="echo") + self.do_test_requests("call_tool", "GET", 200, 0, 0) + self.do_test_requests("list_tools", "GET", 200, 0, 0) + self.do_test_requests("list_prompts", "GET", 200, 0, 0) + self.do_test_requests("list_resources", "GET", 200, 0, 0) + self.do_test_requests("read_resource", "GET", 200, 0, 0) + self.do_test_requests("get_prompt", "GET", 200, 0, 0) + self.do_test_requests("complete", "GET", 200, 0, 0) + self.do_test_requests("set_logging_level", "GET", 200, 0, 0) + self.do_test_requests("ping", "GET", 200, 0, 0) @override def _assert_aws_span_attributes(self, resource_scope_spans, path: str, **kwargs) -> None: @@ -26,85 +34,11 @@ def _assert_aws_span_attributes(self, resource_scope_spans, path: str, **kwargs) def _assert_semantic_conventions_span_attributes( self, resource_scope_spans, method: str, path: str, status_code: int, **kwargs ) -> None: - - tool_name = kwargs.get("tool_name", "echo") - initialize_client_span = None - list_tools_client_span = None - list_tools_server_span = None - call_tool_client_span = None - call_tool_server_span = None - + for resource_scope_span in resource_scope_spans: - span = resource_scope_span.span - - if span.name == "notifications/initialize" and span.kind == Span.SPAN_KIND_CLIENT: - initialize_client_span = span - elif span.name == "mcp.list_tools" and span.kind == Span.SPAN_KIND_CLIENT: - list_tools_client_span = span - elif span.name == f"mcp.call_tool.{tool_name}" and span.kind == Span.SPAN_KIND_CLIENT: - call_tool_client_span = span - elif span.name == "tools/list" and span.kind == Span.SPAN_KIND_SERVER: - list_tools_server_span = span - elif span.name == f"tools/{tool_name}" and span.kind == Span.SPAN_KIND_SERVER: - call_tool_server_span = span - - # Validate list tools client span - self.assertIsNotNone(list_tools_client_span, "List tools client span not found") - self.assertEqual(list_tools_client_span.kind, Span.SPAN_KIND_CLIENT) - - # Validate initialize client span (no server span expected) - self.assertIsNotNone(initialize_client_span, "Initialize client span not found") - self.assertEqual(initialize_client_span.kind, Span.SPAN_KIND_CLIENT) - - init_attributes = {attr.key: attr.value for attr in initialize_client_span.attributes} - self.assertIn("notifications/initialize", init_attributes) - self.assertTrue(init_attributes["notifications/initialize"].bool_value) - - # Validate list tools client span - self.assertIsNotNone(list_tools_client_span, "List tools client span not found") - self.assertEqual(list_tools_client_span.kind, Span.SPAN_KIND_CLIENT) - - list_client_attributes = {attr.key: attr.value for attr in list_tools_client_span.attributes} - self.assertIn("mcp.list_tools", list_client_attributes) - self.assertTrue(list_client_attributes["mcp.list_tools"].bool_value) - - # Validate list tools server span - self.assertIsNotNone(list_tools_server_span, "List tools server span not found") - self.assertEqual(list_tools_server_span.kind, Span.SPAN_KIND_SERVER) - - list_server_attributes = {attr.key: attr.value for attr in list_tools_server_span.attributes} - self.assertIn("mcp.list_tools", list_server_attributes) - self.assertTrue(list_server_attributes["mcp.list_tools"].bool_value) - - # Validate call tool client span - self.assertIsNotNone(call_tool_client_span, f"Call tool client span for {tool_name} not found") - self.assertEqual(call_tool_client_span.kind, Span.SPAN_KIND_CLIENT) - - call_client_attributes = {attr.key: attr.value for attr in call_tool_client_span.attributes} - self.assertIn("mcp.call_tool", call_client_attributes) - self.assertTrue(call_client_attributes["mcp.call_tool"].bool_value) - self.assertIn("aws.remote.operation", call_client_attributes) - self.assertEqual(call_client_attributes["aws.remote.operation"].string_value, tool_name) - - # Validate call tool server span - self.assertIsNotNone(call_tool_server_span, f"Call tool server span for {tool_name} not found") - self.assertEqual(call_tool_server_span.kind, Span.SPAN_KIND_SERVER) - - call_server_attributes = {attr.key: attr.value for attr in call_tool_server_span.attributes} - self.assertIn("mcp.call_tool", call_server_attributes) - self.assertTrue(call_server_attributes["mcp.call_tool"].bool_value) - - # Validate distributed tracing for paired spans - self.assertEqual( - list_tools_server_span.trace_id, - list_tools_client_span.trace_id, - "List tools client and server spans should have the same trace ID", - ) - self.assertEqual( - call_tool_server_span.trace_id, - call_tool_client_span.trace_id, - "Call tool client and server spans should have the same trace ID", - ) + for scope_span in resource_scope_span.scope_spans: + for span in scope_span.spans: + print(f"Span attributes: {span.attributes}") @override def _assert_metric_attributes(self, resource_scope_metrics, metric_name: str, expected_sum: int, **kwargs) -> None: From aab4e7ddcf2de050ff9ae4a4c2f348b5e6eb0c5d Mon Sep 17 00:00:00 2001 From: Steve Liu Date: Sat, 9 Aug 2025 17:09:13 -0700 Subject: [PATCH 39/41] add session id extraction logic --- .../instrumentation/mcp/instrumentation.py | 29 +++++++++++++++++-- .../distro/instrumentation/mcp/semconv.py | 6 ++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index 951918fc2..5d6ef47fb 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -28,6 +28,7 @@ class McpInstrumentor(BaseInstrumentor): _DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client" _DEFAULT_SERVER_SPAN_NAME = "span.mcp.server" + _MCP_SESSION_ID_HEADER = "mcp-session-id" def __init__(self, **kwargs): super().__init__() @@ -226,8 +227,12 @@ async def _wrap_server_message_handler( with self.tracer.start_as_current_span( self._DEFAULT_SERVER_SPAN_NAME, kind=SpanKind.SERVER, context=parent_ctx ) as server_span: + + # Extract session ID if available + session_id = self._extract_session_id(args) + if session_id: + server_span.set_attribute(MCPSpanAttributes.MCP_SESSION_ID, session_id) - server_span.set_attribute(SpanAttributes.RPC_SERVICE, instance.name) self._generate_mcp_message_attrs(server_span, incoming_msg, request_id) try: @@ -238,7 +243,27 @@ async def _wrap_server_message_handler( server_span.set_status(Status(StatusCode.ERROR, str(e))) server_span.record_exception(e) raise - + + def _extract_session_id(self, args: Tuple[Any, ...]) -> Optional[str]: + """ + Extract session ID from server method arguments. + """ + try: + from mcp.shared.session import RequestResponder # pylint: disable=import-outside-toplevel + from mcp.shared.message import ServerMessageMetadata # pylint: disable=import-outside-toplevel + + message = args[0] + if isinstance(message, RequestResponder): + if message.message_metadata and isinstance(message.message_metadata, ServerMessageMetadata): + request_context = message.message_metadata.request_context + if request_context: + headers = getattr(request_context, 'headers', None) + if headers: + return headers.get(self._MCP_SESSION_ID_HEADER) + return None + except Exception: + return None + @staticmethod def _generate_mcp_message_attrs(span: trace.Span, message, request_id: Optional[int]) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py index 611c5dcaa..8f89e9a2d 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/semconv.py @@ -48,6 +48,12 @@ class MCPSpanAttributes: The transport type used for MCP communication. Examples: stdio, streamable_http """ + MCP_SESSION_ID = "mcp.session.id" + """ + The session identifier for HTTP transport connections. + Only present for streamable_http transport, not available for stdio. + """ + class MCPMethodValue: From 5c1aeec9e62b1519ce5e49e4c82cffee76fd2424 Mon Sep 17 00:00:00 2001 From: liustve Date: Sun, 10 Aug 2025 05:17:47 +0000 Subject: [PATCH 40/41] add further trace propagation for responses --- .../instrumentation/mcp/instrumentation.py | 99 ++++++++++++++----- 1 file changed, 74 insertions(+), 25 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index 5d6ef47fb..37c503223 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -1,24 +1,23 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass import json -from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, Union, cast +from typing import Any, Callable, Collection, Dict, Optional, Tuple, Union from wrapt import register_post_import_hook, wrap_function_wrapper from opentelemetry import trace -from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap -from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.propagate import get_global_textmap - -from .version import __version__ +from opentelemetry.semconv.attributes.network_attributes import NetworkTransportValues +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.trace import INVALID_SPAN, SpanKind, Status, StatusCode from .semconv import ( - MCPSpanAttributes, MCPMethodValue, + MCPSpanAttributes, ) +from .version import __version__ class McpInstrumentor(BaseInstrumentor): @@ -29,6 +28,7 @@ class McpInstrumentor(BaseInstrumentor): _DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client" _DEFAULT_SERVER_SPAN_NAME = "span.mcp.server" _MCP_SESSION_ID_HEADER = "mcp-session-id" + _MCP_META_FIELD = "_meta" def __init__(self, **kwargs): super().__init__() @@ -55,6 +55,14 @@ def _instrument(self, **kwargs: Any) -> None: ), "mcp.shared.session", ) + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.shared.session", + "BaseSession._send_response", + self._wrap_send_response, + ), + "mcp.shared.session", + ) register_post_import_hook( lambda _: wrap_function_wrapper( "mcp.server.lowlevel.server", @@ -75,6 +83,7 @@ def _instrument(self, **kwargs: Any) -> None: def _uninstrument(self, **kwargs: Any) -> None: unwrap("mcp.shared.session", "BaseSession.send_request") unwrap("mcp.shared.session", "BaseSession.send_notification") + unwrap("mcp.shared.session", "BaseSession._send_response") unwrap("mcp.server.lowlevel.server", "Server._handle_request") unwrap("mcp.server.lowlevel.server", "Server._handle_notification") @@ -98,7 +107,7 @@ def _wrap_session_send( args: Positional arguments passed to the original send_request/send_notification method kwargs: Keyword arguments passed to the original send_request/send_notification method """ - from mcp.types import ClientRequest, ClientNotification, ServerRequest, ServerNotification + from mcp.types import ClientNotification, ClientRequest, ServerNotification, ServerRequest async def async_wrapper(): message: Optional[Union[ClientRequest, ClientNotification, ServerRequest, ServerNotification]] = ( @@ -120,14 +129,18 @@ async def async_wrapper(): if "params" not in message_json: message_json["params"] = {} - if "_meta" not in message_json["params"]: - message_json["params"]["_meta"] = {} + if self._MCP_META_FIELD not in message_json["params"]: + message_json["params"][self._MCP_META_FIELD] = {} + + parent_ctx = None + if message_json["params"][self._MCP_META_FIELD]: + parent_ctx = self.propagators.extract(message_json["params"][self._MCP_META_FIELD]) - with self.tracer.start_as_current_span(name=span_name, kind=span_kind) as span: + with self.tracer.start_as_current_span(name=span_name, kind=span_kind, context=parent_ctx) as span: ctx = trace.set_span_in_context(span) carrier = {} self.propagators.inject(carrier=carrier, context=ctx) - message_json["params"]["_meta"].update(carrier) + message_json["params"][self._MCP_META_FIELD].update(carrier) McpInstrumentor._generate_mcp_message_attrs(span, message, request_id) @@ -227,11 +240,14 @@ async def _wrap_server_message_handler( with self.tracer.start_as_current_span( self._DEFAULT_SERVER_SPAN_NAME, kind=SpanKind.SERVER, context=parent_ctx ) as server_span: - - # Extract session ID if available + + # session_id only exits if the transport protocol is Streamable HTTP session_id = self._extract_session_id(args) if session_id: server_span.set_attribute(MCPSpanAttributes.MCP_SESSION_ID, session_id) + server_span.set_attribute(SpanAttributes.NETWORK_TRANSPORT, NetworkTransportValues.PIPE.value) + else: + server_span.set_attribute(SpanAttributes.NETWORK_TRANSPORT, NetworkTransportValues.TCP.value) self._generate_mcp_message_attrs(server_span, incoming_msg, request_id) @@ -243,27 +259,60 @@ async def _wrap_server_message_handler( server_span.set_status(Status(StatusCode.ERROR, str(e))) server_span.record_exception(e) raise - + + async def _wrap_send_response( + self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: + """ + Instruments BaseSession._send_response to propagate trace context into response. + + Note: we do not need to generate another span for the reponse as it falls under + the _wrap_server_message_handler + """ + response = args[1] if len(args) > 1 else kwargs.get("response", None) + + if not response: + return await wrapped(*args, **kwargs) + + current_span = trace.get_current_span() + if current_span is not INVALID_SPAN: + # Inject trace context into response + carrier = {} + self.propagators.inject(carrier=carrier, context=trace.set_span_in_context(current_span)) + + response_json = response.model_dump(by_alias=True, mode="json", exclude_none=True) + + if self._MCP_META_FIELD not in response_json: + response_json[self._MCP_META_FIELD] = {} + response_json[self._MCP_META_FIELD].update(carrier) + + modified_response = response.model_validate(response_json) + + if len(args) > 1: + args = args[:1] + (modified_response,) + args[2:] + else: + kwargs["response"] = modified_response + + return await wrapped(*args, **kwargs) + def _extract_session_id(self, args: Tuple[Any, ...]) -> Optional[str]: """ Extract session ID from server method arguments. """ try: - from mcp.shared.session import RequestResponder # pylint: disable=import-outside-toplevel - from mcp.shared.message import ServerMessageMetadata # pylint: disable=import-outside-toplevel - message = args[0] - if isinstance(message, RequestResponder): - if message.message_metadata and isinstance(message.message_metadata, ServerMessageMetadata): - request_context = message.message_metadata.request_context - if request_context: - headers = getattr(request_context, 'headers', None) + if hasattr(message, "message_metadata"): + message_metadata = message.message_metadata + if message.message_metadata and hasattr(message_metadata, "request_context"): + request_context = message_metadata.request_context + if request_context and hasattr(request_context, "headers"): + headers = request_context.headers if headers: return headers.get(self._MCP_SESSION_ID_HEADER) return None except Exception: return None - + @staticmethod def _generate_mcp_message_attrs(span: trace.Span, message, request_id: Optional[int]) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import @@ -326,4 +375,4 @@ def serialize(args: dict[str, Any]) -> str: try: return json.dumps(args) except Exception: - return "" + return "unknown_args" From de1020e9d4d3685aeecbb6445fd26222a99837e9 Mon Sep 17 00:00:00 2001 From: liustve Date: Sun, 10 Aug 2025 21:08:32 +0000 Subject: [PATCH 41/41] Revert "add further trace propagation for responses" This reverts commit 5c1aeec9e62b1519ce5e49e4c82cffee76fd2424. --- .../instrumentation/mcp/instrumentation.py | 99 +++++-------------- 1 file changed, 25 insertions(+), 74 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py index 37c503223..5d6ef47fb 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py @@ -1,23 +1,24 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass import json -from typing import Any, Callable, Collection, Dict, Optional, Tuple, Union +from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, Union, cast from wrapt import register_post_import_hook, wrap_function_wrapper from opentelemetry import trace +from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap -from opentelemetry.propagate import get_global_textmap -from opentelemetry.semconv.attributes.network_attributes import NetworkTransportValues from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import INVALID_SPAN, SpanKind, Status, StatusCode +from opentelemetry.propagate import get_global_textmap + +from .version import __version__ from .semconv import ( - MCPMethodValue, MCPSpanAttributes, + MCPMethodValue, ) -from .version import __version__ class McpInstrumentor(BaseInstrumentor): @@ -28,7 +29,6 @@ class McpInstrumentor(BaseInstrumentor): _DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client" _DEFAULT_SERVER_SPAN_NAME = "span.mcp.server" _MCP_SESSION_ID_HEADER = "mcp-session-id" - _MCP_META_FIELD = "_meta" def __init__(self, **kwargs): super().__init__() @@ -55,14 +55,6 @@ def _instrument(self, **kwargs: Any) -> None: ), "mcp.shared.session", ) - register_post_import_hook( - lambda _: wrap_function_wrapper( - "mcp.shared.session", - "BaseSession._send_response", - self._wrap_send_response, - ), - "mcp.shared.session", - ) register_post_import_hook( lambda _: wrap_function_wrapper( "mcp.server.lowlevel.server", @@ -83,7 +75,6 @@ def _instrument(self, **kwargs: Any) -> None: def _uninstrument(self, **kwargs: Any) -> None: unwrap("mcp.shared.session", "BaseSession.send_request") unwrap("mcp.shared.session", "BaseSession.send_notification") - unwrap("mcp.shared.session", "BaseSession._send_response") unwrap("mcp.server.lowlevel.server", "Server._handle_request") unwrap("mcp.server.lowlevel.server", "Server._handle_notification") @@ -107,7 +98,7 @@ def _wrap_session_send( args: Positional arguments passed to the original send_request/send_notification method kwargs: Keyword arguments passed to the original send_request/send_notification method """ - from mcp.types import ClientNotification, ClientRequest, ServerNotification, ServerRequest + from mcp.types import ClientRequest, ClientNotification, ServerRequest, ServerNotification async def async_wrapper(): message: Optional[Union[ClientRequest, ClientNotification, ServerRequest, ServerNotification]] = ( @@ -129,18 +120,14 @@ async def async_wrapper(): if "params" not in message_json: message_json["params"] = {} - if self._MCP_META_FIELD not in message_json["params"]: - message_json["params"][self._MCP_META_FIELD] = {} - - parent_ctx = None - if message_json["params"][self._MCP_META_FIELD]: - parent_ctx = self.propagators.extract(message_json["params"][self._MCP_META_FIELD]) + if "_meta" not in message_json["params"]: + message_json["params"]["_meta"] = {} - with self.tracer.start_as_current_span(name=span_name, kind=span_kind, context=parent_ctx) as span: + with self.tracer.start_as_current_span(name=span_name, kind=span_kind) as span: ctx = trace.set_span_in_context(span) carrier = {} self.propagators.inject(carrier=carrier, context=ctx) - message_json["params"][self._MCP_META_FIELD].update(carrier) + message_json["params"]["_meta"].update(carrier) McpInstrumentor._generate_mcp_message_attrs(span, message, request_id) @@ -240,14 +227,11 @@ async def _wrap_server_message_handler( with self.tracer.start_as_current_span( self._DEFAULT_SERVER_SPAN_NAME, kind=SpanKind.SERVER, context=parent_ctx ) as server_span: - - # session_id only exits if the transport protocol is Streamable HTTP + + # Extract session ID if available session_id = self._extract_session_id(args) if session_id: server_span.set_attribute(MCPSpanAttributes.MCP_SESSION_ID, session_id) - server_span.set_attribute(SpanAttributes.NETWORK_TRANSPORT, NetworkTransportValues.PIPE.value) - else: - server_span.set_attribute(SpanAttributes.NETWORK_TRANSPORT, NetworkTransportValues.TCP.value) self._generate_mcp_message_attrs(server_span, incoming_msg, request_id) @@ -259,60 +243,27 @@ async def _wrap_server_message_handler( server_span.set_status(Status(StatusCode.ERROR, str(e))) server_span.record_exception(e) raise - - async def _wrap_send_response( - self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] - ) -> Any: - """ - Instruments BaseSession._send_response to propagate trace context into response. - - Note: we do not need to generate another span for the reponse as it falls under - the _wrap_server_message_handler - """ - response = args[1] if len(args) > 1 else kwargs.get("response", None) - - if not response: - return await wrapped(*args, **kwargs) - - current_span = trace.get_current_span() - if current_span is not INVALID_SPAN: - # Inject trace context into response - carrier = {} - self.propagators.inject(carrier=carrier, context=trace.set_span_in_context(current_span)) - - response_json = response.model_dump(by_alias=True, mode="json", exclude_none=True) - - if self._MCP_META_FIELD not in response_json: - response_json[self._MCP_META_FIELD] = {} - response_json[self._MCP_META_FIELD].update(carrier) - - modified_response = response.model_validate(response_json) - - if len(args) > 1: - args = args[:1] + (modified_response,) + args[2:] - else: - kwargs["response"] = modified_response - - return await wrapped(*args, **kwargs) - + def _extract_session_id(self, args: Tuple[Any, ...]) -> Optional[str]: """ Extract session ID from server method arguments. """ try: + from mcp.shared.session import RequestResponder # pylint: disable=import-outside-toplevel + from mcp.shared.message import ServerMessageMetadata # pylint: disable=import-outside-toplevel + message = args[0] - if hasattr(message, "message_metadata"): - message_metadata = message.message_metadata - if message.message_metadata and hasattr(message_metadata, "request_context"): - request_context = message_metadata.request_context - if request_context and hasattr(request_context, "headers"): - headers = request_context.headers + if isinstance(message, RequestResponder): + if message.message_metadata and isinstance(message.message_metadata, ServerMessageMetadata): + request_context = message.message_metadata.request_context + if request_context: + headers = getattr(request_context, 'headers', None) if headers: return headers.get(self._MCP_SESSION_ID_HEADER) return None except Exception: return None - + @staticmethod def _generate_mcp_message_attrs(span: trace.Span, message, request_id: Optional[int]) -> None: import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import @@ -375,4 +326,4 @@ def serialize(args: dict[str, Any]) -> str: try: return json.dumps(args) except Exception: - return "unknown_args" + return ""