1
1
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
2
# SPDX-License-Identifier: Apache-2.0
3
- from typing import Any , Callable , Collection , Dict , Tuple
3
+ from typing import Any , Callable , Collection , Dict , Optional , Tuple
4
4
5
5
from wrapt import register_post_import_hook , wrap_function_wrapper
6
6
7
7
from opentelemetry import trace
8
8
from opentelemetry .instrumentation .instrumentor import BaseInstrumentor
9
9
from opentelemetry .instrumentation .utils import unwrap
10
10
from opentelemetry .semconv .trace import SpanAttributes
11
+ from opentelemetry .instrumentation .mcp .version import __version__
12
+ from opentelemetry .propagate import get_global_textmap
11
13
12
- from .constants import MCPEnvironmentVariables , MCPTraceContext
13
- from .semconv import MCPAttributes , MCPOperations , MCPSpanNames
14
+ from .semconv import CLIENT_INITIALIZED , MCP_METHOD_NAME , TOOLS_CALL , TOOLS_LIST , MCPAttributes , MCPOperations , MCPSpanNames
14
15
15
16
16
- class MCPInstrumentor (BaseInstrumentor ):
17
+ class McpInstrumentor (BaseInstrumentor ):
17
18
"""
18
19
An instrumenter for MCP.
19
20
"""
20
21
21
- def __init__ (self ):
22
+ def __init__ (self , ** kwargs ):
22
23
super ().__init__ ()
23
- self .tracer = None
24
+ self .propagators = kwargs .get ("propagators" ) or get_global_textmap ()
25
+ self .tracer = trace .get_tracer (__name__ , __version__ , tracer_provider = kwargs .get ("tracer_provider" , None ))
24
26
25
- @staticmethod
26
- def instrumentation_dependencies () -> Collection [str ]:
27
- return ("mcp >= 1.6.0" ,)
27
+ def instrumentation_dependencies (self ) -> Collection [str ]:
28
+ return "mcp >= 1.6.0"
28
29
29
30
def _instrument (self , ** kwargs : Any ) -> None :
30
- tracer_provider = kwargs .get ("tracer_provider" )
31
- if tracer_provider :
32
- self .tracer = tracer_provider .get_tracer ("instrumentation.mcp" )
33
- else :
34
- self .tracer = trace .get_tracer ("instrumentation.mcp" )
35
31
register_post_import_hook (
36
32
lambda _ : wrap_function_wrapper (
37
33
"mcp.shared.session" ,
@@ -49,48 +45,61 @@ def _instrument(self, **kwargs: Any) -> None:
49
45
"mcp.server.lowlevel.server" ,
50
46
)
51
47
52
- @staticmethod
53
- def _uninstrument (** kwargs : Any ) -> None :
48
+ def _uninstrument (self , ** kwargs : Any ) -> None :
54
49
unwrap ("mcp.shared.session" , "BaseSession.send_request" )
55
50
unwrap ("mcp.server.lowlevel.server" , "Server._handle_request" )
56
-
57
- # Send Request Wrapper
51
+
52
+
58
53
def _wrap_send_request (
59
54
self , wrapped : Callable , instance : Any , args : Tuple [Any , ...], kwargs : Dict [str , Any ]
60
55
) -> Callable :
61
- """
62
- Changes made:
63
- The wrapper intercepts the request before sending, injects distributed tracing context into the
64
- request's params._meta field and creates OpenTelemetry spans. The wrapper does not change anything
65
- else from the original function's behavior because it reconstructs the request object with the same
66
- type and calling the original function with identical parameters.
56
+ import mcp .types as types
57
+ """
58
+ Patches BaseSession.send_request which is responsible for sending requests from the client to the MCP server.
59
+ This patched MCP client intercepts the request to obtain attributes for creating client-side span, extracts
60
+ the current trace context, and embeds it into the request's params._meta.traceparent field
61
+ before forwarding the request to the MCP server.
62
+
63
+ Args:
64
+ wrapped: The original BaseSession.send_request function
65
+ instance: The BaseSession instance
66
+ args: Positional arguments, where args[0] is typically the request object
67
+ kwargs: Keyword arguments, may contain 'request' parameter
68
+
69
+ Returns:
70
+ Callable: Async wrapper function that handles trace context injection
67
71
"""
68
72
69
73
async def async_wrapper ():
74
+ request : Optional [types .ClientRequest ] = args [0 ] if len (args ) > 0 else None
75
+
76
+ if not request :
77
+ return await wrapped (* args , ** kwargs )
78
+
70
79
with self .tracer .start_as_current_span (
71
80
MCPSpanNames .CLIENT_SEND_REQUEST , kind = trace .SpanKind .CLIENT
72
81
) as span :
73
- span_ctx = span .get_span_context ()
74
- request = args [0 ] if len (args ) > 0 else kwargs .get ("request" )
82
+
75
83
if request :
76
- req_root = request .root if hasattr (request , "root" ) else request
77
-
78
- self ._generate_mcp_attributes (span , req_root , is_client = True )
84
+ span_ctx = trace .set_span_in_context (span )
85
+ parent_span = {}
86
+ self .propagators .inject (carrier = parent_span , context = span_ctx )
87
+
79
88
request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
80
- self ._inject_trace_context (request_data , span_ctx )
89
+
90
+ if "params" not in request_data :
91
+ request_data ["params" ] = {}
92
+ if "_meta" not in request_data ["params" ]:
93
+ request_data ["params" ]["_meta" ] = {}
94
+ request_data ["params" ]["_meta" ].update (parent_span )
95
+
81
96
# Reconstruct request object with injected trace context
82
- modified_request = type (request ).model_validate (request_data )
83
- if len (args ) > 0 :
84
- new_args = (modified_request ,) + args [1 :]
85
- result = await wrapped (* new_args , ** kwargs )
86
- else :
87
- kwargs ["request" ] = modified_request
88
- result = await wrapped (* args , ** kwargs )
89
- else :
90
- result = await wrapped (* args , ** kwargs )
91
- return result
97
+ modified_request = request .model_validate (request_data )
98
+ new_args = (modified_request ,) + args [1 :]
99
+
100
+ return await wrapped (* new_args , ** kwargs )
92
101
93
- return async_wrapper ()
102
+ return async_wrapper
94
103
95
104
# Handle Request Wrapper
96
105
async def _wrap_handle_request (
@@ -111,7 +120,7 @@ async def _wrap_handle_request(
111
120
traceparent = None
112
121
113
122
if req and hasattr (req , "params" ) and req .params and hasattr (req .params , "meta" ) and req .params .meta :
114
- traceparent = getattr ( req . params . meta , MCPTraceContext . TRACEPARENT_HEADER , None )
123
+ traceparent = None
115
124
span_context = self ._extract_span_context_from_traceparent (traceparent ) if traceparent else None
116
125
if span_context :
117
126
span_name = self ._get_mcp_operation (req )
@@ -130,40 +139,18 @@ async def _wrap_handle_request(
130
139
def _generate_mcp_attributes (span : trace .Span , request : Any , is_client : bool ) -> None :
131
140
import mcp .types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
132
141
133
- operation = MCPOperations .UNKNOWN_OPERATION
134
-
135
142
if isinstance (request , types .ListToolsRequest ):
136
- operation = MCPOperations .LIST_TOOL
137
- span .set_attribute (MCPAttributes .MCP_LIST_TOOLS , True )
143
+ span .set_attribute (MCP_METHOD_NAME , TOOLS_LIST )
138
144
if is_client :
139
145
span .update_name (MCPSpanNames .CLIENT_LIST_TOOLS )
140
146
elif isinstance (request , types .CallToolRequest ):
141
- operation = request .params .name
142
- span .set_attribute (MCPAttributes .MCP_CALL_TOOL , True )
147
+ span .set_attribute (MCP_METHOD_NAME , TOOLS_CALL )
143
148
if is_client :
144
149
span .update_name (MCPSpanNames .client_call_tool (request .params .name ))
145
150
elif isinstance (request , types .InitializeRequest ):
146
- operation = MCPOperations .INITIALIZE
147
- span .set_attribute (MCPAttributes .MCP_INITIALIZE , True )
148
- if is_client :
149
- span .update_name (MCPSpanNames .CLIENT_INITIALIZE )
151
+ span .set_attribute (MCP_METHOD_NAME , CLIENT_INITIALIZED )
150
152
151
- if is_client :
152
- MCPInstrumentor ._add_client_attributes (span , operation , request )
153
- else :
154
- MCPInstrumentor ._add_server_attributes (span , operation , request )
155
-
156
- @staticmethod
157
- def _inject_trace_context (request_data : Dict [str , Any ], span_ctx ) -> None :
158
- if "params" not in request_data :
159
- request_data ["params" ] = {}
160
- if "_meta" not in request_data ["params" ]:
161
- request_data ["params" ]["_meta" ] = {}
162
- trace_id_hex = f"{ span_ctx .trace_id :032x} "
163
- span_id_hex = f"{ span_ctx .span_id :016x} "
164
- trace_flags = MCPTraceContext .TRACE_FLAGS_SAMPLED
165
- traceparent = f"{ MCPTraceContext .TRACEPARENT_VERSION } -{ trace_id_hex } -{ span_id_hex } -{ trace_flags } "
166
- request_data ["params" ]["_meta" ][MCPTraceContext .TRACEPARENT_HEADER ] = traceparent
153
+ # Additional attributes can be added here if needed
167
154
168
155
@staticmethod
169
156
def _extract_span_context_from_traceparent (traceparent : str ):
@@ -195,18 +182,3 @@ def _get_mcp_operation(req: Any) -> str:
195
182
elif isinstance (req , types .CallToolRequest ):
196
183
span_name = MCPSpanNames .tools_call (req .params .name )
197
184
return span_name
198
-
199
- @staticmethod
200
- def _add_client_attributes (span : trace .Span , operation : str , request : Any ) -> None :
201
- import os # pylint: disable=import-outside-toplevel
202
-
203
- service_name = os .environ .get (MCPEnvironmentVariables .SERVER_NAME , "mcp server" )
204
- span .set_attribute (SpanAttributes .RPC_SERVICE , service_name )
205
- span .set_attribute (SpanAttributes .RPC_METHOD , operation )
206
- if hasattr (request , "params" ) and request .params and hasattr (request .params , "name" ):
207
- span .set_attribute (MCPAttributes .MCP_TOOL_NAME , request .params .name )
208
-
209
- @staticmethod
210
- def _add_server_attributes (span : trace .Span , operation : str , request : Any ) -> None :
211
- if hasattr (request , "params" ) and request .params and hasattr (request .params , "name" ):
212
- span .set_attribute (MCPAttributes .MCP_TOOL_NAME , request .params .name )
0 commit comments