|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, cast |
| 4 | + |
| 5 | +from wrapt import register_post_import_hook, wrap_function_wrapper |
| 6 | + |
| 7 | +from opentelemetry import trace |
| 8 | +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor |
| 9 | +from opentelemetry.instrumentation.utils import unwrap |
| 10 | +from opentelemetry.semconv.trace import SpanAttributes |
| 11 | +from opentelemetry.propagate import get_global_textmap |
| 12 | + |
| 13 | +from .version import __version__ |
| 14 | + |
| 15 | +from .semconv import ( |
| 16 | + CLIENT_INITIALIZED, |
| 17 | + MCP_METHOD_NAME, |
| 18 | + TOOLS_CALL, |
| 19 | + TOOLS_LIST, |
| 20 | + MCPAttributes, |
| 21 | + MCPOperations, |
| 22 | + MCPSpanNames, |
| 23 | +) |
| 24 | + |
| 25 | + |
| 26 | +class McpInstrumentor(BaseInstrumentor): |
| 27 | + """ |
| 28 | + An instrumentor class for MCP. |
| 29 | + """ |
| 30 | + |
| 31 | + def __init__(self, **kwargs): |
| 32 | + super().__init__() |
| 33 | + self.propagators = kwargs.get("propagators") or get_global_textmap() |
| 34 | + self.tracer = trace.get_tracer(__name__, __version__, tracer_provider=kwargs.get("tracer_provider", None)) |
| 35 | + |
| 36 | + def instrumentation_dependencies(self) -> Collection[str]: |
| 37 | + return "mcp >= 1.6.0" |
| 38 | + |
| 39 | + def _instrument(self, **kwargs: Any) -> None: |
| 40 | + |
| 41 | + register_post_import_hook( |
| 42 | + lambda _: wrap_function_wrapper( |
| 43 | + "mcp.shared.session", |
| 44 | + "BaseSession.send_request", |
| 45 | + self._wrap_send_request, |
| 46 | + ), |
| 47 | + "mcp.shared.session", |
| 48 | + ) |
| 49 | + register_post_import_hook( |
| 50 | + lambda _: wrap_function_wrapper( |
| 51 | + "mcp.server.lowlevel.server", |
| 52 | + "Server._handle_request", |
| 53 | + self._wrap_handle_request, |
| 54 | + ), |
| 55 | + "mcp.server.lowlevel.server", |
| 56 | + ) |
| 57 | + |
| 58 | + def _uninstrument(self, **kwargs: Any) -> None: |
| 59 | + unwrap("mcp.shared.session", "BaseSession.send_request") |
| 60 | + unwrap("mcp.server.lowlevel.server", "Server._handle_request") |
| 61 | + |
| 62 | + def _wrap_send_request( |
| 63 | + self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] |
| 64 | + ) -> Callable: |
| 65 | + import mcp.types as types |
| 66 | + |
| 67 | + """ |
| 68 | + Patches BaseSession.send_request which is responsible for sending requests from the client to the MCP server. |
| 69 | + This patched MCP client intercepts the request to obtain attributes for creating client-side span, extracts |
| 70 | + the current trace context, and embeds it into the request's params._meta.traceparent field |
| 71 | + before forwarding the request to the MCP server. |
| 72 | + """ |
| 73 | + |
| 74 | + async def async_wrapper(): |
| 75 | + request: Optional[types.ClientRequest] = args[0] if len(args) > 0 else None |
| 76 | + |
| 77 | + if not request: |
| 78 | + return await wrapped(*args, **kwargs) |
| 79 | + |
| 80 | + request_as_json = request.model_dump(by_alias=True, mode="json", exclude_none=True) |
| 81 | + |
| 82 | + if "params" not in request_as_json: |
| 83 | + request_as_json["params"] = {} |
| 84 | + |
| 85 | + if "_meta" not in request_as_json["params"]: |
| 86 | + request_as_json["params"]["_meta"] = {} |
| 87 | + |
| 88 | + with self.tracer.start_as_current_span( |
| 89 | + MCPSpanNames.SPAN_MCP_CLIENT, kind=trace.SpanKind.CLIENT |
| 90 | + ) as mcp_client_span: |
| 91 | + |
| 92 | + if request: |
| 93 | + span_ctx = trace.set_span_in_context(mcp_client_span) |
| 94 | + parent_span = {} |
| 95 | + self.propagators.inject(carrier=parent_span, context=span_ctx) |
| 96 | + |
| 97 | + McpInstrumentor._set_mcp_client_attributes(mcp_client_span, request) |
| 98 | + |
| 99 | + request_as_json["params"]["_meta"].update(parent_span) |
| 100 | + |
| 101 | + # Reconstruct request object with injected trace context |
| 102 | + modified_request = request.model_validate(request_as_json) |
| 103 | + new_args = (modified_request,) + args[1:] |
| 104 | + |
| 105 | + return await wrapped(*new_args, **kwargs) |
| 106 | + |
| 107 | + return async_wrapper |
| 108 | + |
| 109 | + # Handle Request Wrapper |
| 110 | + async def _wrap_handle_request( |
| 111 | + self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] |
| 112 | + ) -> Any: |
| 113 | + """ |
| 114 | + Patches Server._handle_request which is responsible for processing requests on the MCP server. |
| 115 | + This patched MCP server intercepts incoming requests to extract tracing context from |
| 116 | + the request's params._meta field and creates server-side spans linked to the client spans. |
| 117 | + """ |
| 118 | + req = args[1] if len(args) > 1 else None |
| 119 | + carrier = {} |
| 120 | + |
| 121 | + if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta: |
| 122 | + carrier = req.params.meta.__dict__ |
| 123 | + |
| 124 | + parent_ctx = self.propagators.extract(carrier=carrier) |
| 125 | + |
| 126 | + if parent_ctx: |
| 127 | + with self.tracer.start_as_current_span( |
| 128 | + MCPSpanNames.SPAN_MCP_SERVER, kind=trace.SpanKind.SERVER, context=parent_ctx |
| 129 | + ) as mcp_server_span: |
| 130 | + self._set_mcp_server_attributes(mcp_server_span, req) |
| 131 | + |
| 132 | + return await wrapped(*args, **kwargs) |
| 133 | + |
| 134 | + @staticmethod |
| 135 | + def _set_mcp_client_attributes(span: trace.Span, request: Any) -> None: |
| 136 | + import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import |
| 137 | + |
| 138 | + if isinstance(request, types.ListToolsRequest): |
| 139 | + span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST) |
| 140 | + if isinstance(request, types.CallToolRequest): |
| 141 | + tool_name = request.params.name |
| 142 | + span.update_name(f"{TOOLS_CALL} {tool_name}") |
| 143 | + span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) |
| 144 | + span.set_attribute(MCPAttributes.MCP_TOOL_NAME, tool_name) |
| 145 | + if isinstance(request, types.InitializeRequest): |
| 146 | + span.set_attribute(MCP_METHOD_NAME, CLIENT_INITIALIZED) |
| 147 | + |
| 148 | + @staticmethod |
| 149 | + def _set_mcp_server_attributes(span: trace.Span, request: Any) -> None: |
| 150 | + import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import |
| 151 | + |
| 152 | + if isinstance(span, types.ListToolsRequest): |
| 153 | + span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST) |
| 154 | + if isinstance(span, types.CallToolRequest): |
| 155 | + tool_name = request.params.name |
| 156 | + span.update_name(f"{TOOLS_CALL} {tool_name}") |
| 157 | + span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL) |
| 158 | + span.set_attribute(MCPAttributes.MCP_TOOL_NAME, tool_name) |
0 commit comments