1
1
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
2
# SPDX-License-Identifier: Apache-2.0
3
+ from dataclasses import dataclass
3
4
import json
4
- from typing import Any , Callable , Collection , Dict , Optional , Tuple , Union
5
+ from typing import Any , AsyncGenerator , Callable , Collection , Dict , Optional , Tuple , Union , cast
5
6
6
7
from wrapt import register_post_import_hook , wrap_function_wrapper
7
8
8
9
from opentelemetry import trace
10
+ from opentelemetry .trace import SpanKind , Status , StatusCode
9
11
from opentelemetry .instrumentation .instrumentor import BaseInstrumentor
10
12
from opentelemetry .instrumentation .utils import unwrap
11
- from opentelemetry .propagate import get_global_textmap
12
- from opentelemetry .semconv .attributes .network_attributes import NetworkTransportValues
13
13
from opentelemetry .semconv .trace import SpanAttributes
14
- from opentelemetry .trace import INVALID_SPAN , SpanKind , Status , StatusCode
14
+ from opentelemetry .propagate import get_global_textmap
15
+
16
+ from .version import __version__
15
17
16
18
from .semconv import (
17
- MCPMethodValue ,
18
19
MCPSpanAttributes ,
20
+ MCPMethodValue ,
19
21
)
20
- from .version import __version__
21
22
22
23
23
24
class McpInstrumentor (BaseInstrumentor ):
@@ -28,7 +29,6 @@ class McpInstrumentor(BaseInstrumentor):
28
29
_DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client"
29
30
_DEFAULT_SERVER_SPAN_NAME = "span.mcp.server"
30
31
_MCP_SESSION_ID_HEADER = "mcp-session-id"
31
- _MCP_META_FIELD = "_meta"
32
32
33
33
def __init__ (self , ** kwargs ):
34
34
super ().__init__ ()
@@ -55,14 +55,6 @@ def _instrument(self, **kwargs: Any) -> None:
55
55
),
56
56
"mcp.shared.session" ,
57
57
)
58
- register_post_import_hook (
59
- lambda _ : wrap_function_wrapper (
60
- "mcp.shared.session" ,
61
- "BaseSession._send_response" ,
62
- self ._wrap_send_response ,
63
- ),
64
- "mcp.shared.session" ,
65
- )
66
58
register_post_import_hook (
67
59
lambda _ : wrap_function_wrapper (
68
60
"mcp.server.lowlevel.server" ,
@@ -83,7 +75,6 @@ def _instrument(self, **kwargs: Any) -> None:
83
75
def _uninstrument (self , ** kwargs : Any ) -> None :
84
76
unwrap ("mcp.shared.session" , "BaseSession.send_request" )
85
77
unwrap ("mcp.shared.session" , "BaseSession.send_notification" )
86
- unwrap ("mcp.shared.session" , "BaseSession._send_response" )
87
78
unwrap ("mcp.server.lowlevel.server" , "Server._handle_request" )
88
79
unwrap ("mcp.server.lowlevel.server" , "Server._handle_notification" )
89
80
@@ -107,7 +98,7 @@ def _wrap_session_send(
107
98
args: Positional arguments passed to the original send_request/send_notification method
108
99
kwargs: Keyword arguments passed to the original send_request/send_notification method
109
100
"""
110
- from mcp .types import ClientNotification , ClientRequest , ServerNotification , ServerRequest
101
+ from mcp .types import ClientRequest , ClientNotification , ServerRequest , ServerNotification
111
102
112
103
async def async_wrapper ():
113
104
message : Optional [Union [ClientRequest , ClientNotification , ServerRequest , ServerNotification ]] = (
@@ -129,18 +120,14 @@ async def async_wrapper():
129
120
130
121
if "params" not in message_json :
131
122
message_json ["params" ] = {}
132
- if self ._MCP_META_FIELD not in message_json ["params" ]:
133
- message_json ["params" ][self ._MCP_META_FIELD ] = {}
134
-
135
- parent_ctx = None
136
- if message_json ["params" ][self ._MCP_META_FIELD ]:
137
- parent_ctx = self .propagators .extract (message_json ["params" ][self ._MCP_META_FIELD ])
123
+ if "_meta" not in message_json ["params" ]:
124
+ message_json ["params" ]["_meta" ] = {}
138
125
139
- with self .tracer .start_as_current_span (name = span_name , kind = span_kind , context = parent_ctx ) as span :
126
+ with self .tracer .start_as_current_span (name = span_name , kind = span_kind ) as span :
140
127
ctx = trace .set_span_in_context (span )
141
128
carrier = {}
142
129
self .propagators .inject (carrier = carrier , context = ctx )
143
- message_json ["params" ][self . _MCP_META_FIELD ].update (carrier )
130
+ message_json ["params" ]["_meta" ].update (carrier )
144
131
145
132
McpInstrumentor ._generate_mcp_message_attrs (span , message , request_id )
146
133
@@ -240,14 +227,11 @@ async def _wrap_server_message_handler(
240
227
with self .tracer .start_as_current_span (
241
228
self ._DEFAULT_SERVER_SPAN_NAME , kind = SpanKind .SERVER , context = parent_ctx
242
229
) as server_span :
243
-
244
- # session_id only exits if the transport protocol is Streamable HTTP
230
+
231
+ # Extract session ID if available
245
232
session_id = self ._extract_session_id (args )
246
233
if session_id :
247
234
server_span .set_attribute (MCPSpanAttributes .MCP_SESSION_ID , session_id )
248
- server_span .set_attribute (SpanAttributes .NETWORK_TRANSPORT , NetworkTransportValues .PIPE .value )
249
- else :
250
- server_span .set_attribute (SpanAttributes .NETWORK_TRANSPORT , NetworkTransportValues .TCP .value )
251
235
252
236
self ._generate_mcp_message_attrs (server_span , incoming_msg , request_id )
253
237
@@ -259,60 +243,27 @@ async def _wrap_server_message_handler(
259
243
server_span .set_status (Status (StatusCode .ERROR , str (e )))
260
244
server_span .record_exception (e )
261
245
raise
262
-
263
- async def _wrap_send_response (
264
- self , wrapped : Callable , instance : Any , args : Tuple [Any , ...], kwargs : Dict [str , Any ]
265
- ) -> Any :
266
- """
267
- Instruments BaseSession._send_response to propagate trace context into response.
268
-
269
- Note: we do not need to generate another span for the reponse as it falls under
270
- the _wrap_server_message_handler
271
- """
272
- response = args [1 ] if len (args ) > 1 else kwargs .get ("response" , None )
273
-
274
- if not response :
275
- return await wrapped (* args , ** kwargs )
276
-
277
- current_span = trace .get_current_span ()
278
- if current_span is not INVALID_SPAN :
279
- # Inject trace context into response
280
- carrier = {}
281
- self .propagators .inject (carrier = carrier , context = trace .set_span_in_context (current_span ))
282
-
283
- response_json = response .model_dump (by_alias = True , mode = "json" , exclude_none = True )
284
-
285
- if self ._MCP_META_FIELD not in response_json :
286
- response_json [self ._MCP_META_FIELD ] = {}
287
- response_json [self ._MCP_META_FIELD ].update (carrier )
288
-
289
- modified_response = response .model_validate (response_json )
290
-
291
- if len (args ) > 1 :
292
- args = args [:1 ] + (modified_response ,) + args [2 :]
293
- else :
294
- kwargs ["response" ] = modified_response
295
-
296
- return await wrapped (* args , ** kwargs )
297
-
246
+
298
247
def _extract_session_id (self , args : Tuple [Any , ...]) -> Optional [str ]:
299
248
"""
300
249
Extract session ID from server method arguments.
301
250
"""
302
251
try :
252
+ from mcp .shared .session import RequestResponder # pylint: disable=import-outside-toplevel
253
+ from mcp .shared .message import ServerMessageMetadata # pylint: disable=import-outside-toplevel
254
+
303
255
message = args [0 ]
304
- if hasattr (message , "message_metadata" ):
305
- message_metadata = message .message_metadata
306
- if message .message_metadata and hasattr (message_metadata , "request_context" ):
307
- request_context = message_metadata .request_context
308
- if request_context and hasattr (request_context , "headers" ):
309
- headers = request_context .headers
256
+ if isinstance (message , RequestResponder ):
257
+ if message .message_metadata and isinstance (message .message_metadata , ServerMessageMetadata ):
258
+ request_context = message .message_metadata .request_context
259
+ if request_context :
260
+ headers = getattr (request_context , 'headers' , None )
310
261
if headers :
311
262
return headers .get (self ._MCP_SESSION_ID_HEADER )
312
263
return None
313
264
except Exception :
314
265
return None
315
-
266
+
316
267
@staticmethod
317
268
def _generate_mcp_message_attrs (span : trace .Span , message , request_id : Optional [int ]) -> None :
318
269
import mcp .types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
@@ -375,4 +326,4 @@ def serialize(args: dict[str, Any]) -> str:
375
326
try :
376
327
return json .dumps (args )
377
328
except Exception :
378
- return "unknown_args "
329
+ return ""
0 commit comments