1
1
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
2
# SPDX-License-Identifier: Apache-2.0
3
- from dataclasses import dataclass
4
3
import json
5
- from typing import Any , AsyncGenerator , Callable , Collection , Dict , Optional , Tuple , Union , cast
4
+ from typing import Any , Callable , Collection , Dict , Optional , Tuple , Union
6
5
7
6
from wrapt import register_post_import_hook , wrap_function_wrapper
8
7
9
8
from opentelemetry import trace
10
- from opentelemetry .trace import SpanKind , Status , StatusCode
11
9
from opentelemetry .instrumentation .instrumentor import BaseInstrumentor
12
10
from opentelemetry .instrumentation .utils import unwrap
13
- from opentelemetry .semconv .trace import SpanAttributes
14
11
from opentelemetry .propagate import get_global_textmap
15
-
16
- from .version import __version__
12
+ from opentelemetry .semconv .attributes .network_attributes import NetworkTransportValues
13
+ from opentelemetry .semconv .trace import SpanAttributes
14
+ from opentelemetry .trace import INVALID_SPAN , SpanKind , Status , StatusCode
17
15
18
16
from .semconv import (
19
- MCPSpanAttributes ,
20
17
MCPMethodValue ,
18
+ MCPSpanAttributes ,
21
19
)
20
+ from .version import __version__
22
21
23
22
24
23
class McpInstrumentor (BaseInstrumentor ):
@@ -29,6 +28,7 @@ class McpInstrumentor(BaseInstrumentor):
29
28
_DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client"
30
29
_DEFAULT_SERVER_SPAN_NAME = "span.mcp.server"
31
30
_MCP_SESSION_ID_HEADER = "mcp-session-id"
31
+ _MCP_META_FIELD = "_meta"
32
32
33
33
def __init__ (self , ** kwargs ):
34
34
super ().__init__ ()
@@ -55,6 +55,14 @@ def _instrument(self, **kwargs: Any) -> None:
55
55
),
56
56
"mcp.shared.session" ,
57
57
)
58
+ register_post_import_hook (
59
+ lambda _ : wrap_function_wrapper (
60
+ "mcp.shared.session" ,
61
+ "BaseSession._send_response" ,
62
+ self ._wrap_send_response ,
63
+ ),
64
+ "mcp.shared.session" ,
65
+ )
58
66
register_post_import_hook (
59
67
lambda _ : wrap_function_wrapper (
60
68
"mcp.server.lowlevel.server" ,
@@ -75,6 +83,7 @@ def _instrument(self, **kwargs: Any) -> None:
75
83
def _uninstrument (self , ** kwargs : Any ) -> None :
76
84
unwrap ("mcp.shared.session" , "BaseSession.send_request" )
77
85
unwrap ("mcp.shared.session" , "BaseSession.send_notification" )
86
+ unwrap ("mcp.shared.session" , "BaseSession._send_response" )
78
87
unwrap ("mcp.server.lowlevel.server" , "Server._handle_request" )
79
88
unwrap ("mcp.server.lowlevel.server" , "Server._handle_notification" )
80
89
@@ -98,7 +107,7 @@ def _wrap_session_send(
98
107
args: Positional arguments passed to the original send_request/send_notification method
99
108
kwargs: Keyword arguments passed to the original send_request/send_notification method
100
109
"""
101
- from mcp .types import ClientRequest , ClientNotification , ServerRequest , ServerNotification
110
+ from mcp .types import ClientNotification , ClientRequest , ServerNotification , ServerRequest
102
111
103
112
async def async_wrapper ():
104
113
message : Optional [Union [ClientRequest , ClientNotification , ServerRequest , ServerNotification ]] = (
@@ -120,14 +129,18 @@ async def async_wrapper():
120
129
121
130
if "params" not in message_json :
122
131
message_json ["params" ] = {}
123
- if "_meta" not in message_json ["params" ]:
124
- message_json ["params" ]["_meta" ] = {}
132
+ if self ._MCP_META_FIELD not in message_json ["params" ]:
133
+ message_json ["params" ][self ._MCP_META_FIELD ] = {}
134
+
135
+ parent_ctx = None
136
+ if message_json ["params" ][self ._MCP_META_FIELD ]:
137
+ parent_ctx = self .propagators .extract (message_json ["params" ][self ._MCP_META_FIELD ])
125
138
126
- with self .tracer .start_as_current_span (name = span_name , kind = span_kind ) as span :
139
+ with self .tracer .start_as_current_span (name = span_name , kind = span_kind , context = parent_ctx ) as span :
127
140
ctx = trace .set_span_in_context (span )
128
141
carrier = {}
129
142
self .propagators .inject (carrier = carrier , context = ctx )
130
- message_json ["params" ]["_meta" ].update (carrier )
143
+ message_json ["params" ][self . _MCP_META_FIELD ].update (carrier )
131
144
132
145
McpInstrumentor ._generate_mcp_message_attrs (span , message , request_id )
133
146
@@ -227,11 +240,14 @@ async def _wrap_server_message_handler(
227
240
with self .tracer .start_as_current_span (
228
241
self ._DEFAULT_SERVER_SPAN_NAME , kind = SpanKind .SERVER , context = parent_ctx
229
242
) as server_span :
230
-
231
- # Extract session ID if available
243
+
244
+ # session_id only exits if the transport protocol is Streamable HTTP
232
245
session_id = self ._extract_session_id (args )
233
246
if session_id :
234
247
server_span .set_attribute (MCPSpanAttributes .MCP_SESSION_ID , session_id )
248
+ server_span .set_attribute (SpanAttributes .NETWORK_TRANSPORT , NetworkTransportValues .PIPE .value )
249
+ else :
250
+ server_span .set_attribute (SpanAttributes .NETWORK_TRANSPORT , NetworkTransportValues .TCP .value )
235
251
236
252
self ._generate_mcp_message_attrs (server_span , incoming_msg , request_id )
237
253
@@ -243,27 +259,60 @@ async def _wrap_server_message_handler(
243
259
server_span .set_status (Status (StatusCode .ERROR , str (e )))
244
260
server_span .record_exception (e )
245
261
raise
246
-
262
+
263
+ async def _wrap_send_response (
264
+ self , wrapped : Callable , instance : Any , args : Tuple [Any , ...], kwargs : Dict [str , Any ]
265
+ ) -> Any :
266
+ """
267
+ Instruments BaseSession._send_response to propagate trace context into response.
268
+
269
+ Note: we do not need to generate another span for the reponse as it falls under
270
+ the _wrap_server_message_handler
271
+ """
272
+ response = args [1 ] if len (args ) > 1 else kwargs .get ("response" , None )
273
+
274
+ if not response :
275
+ return await wrapped (* args , ** kwargs )
276
+
277
+ current_span = trace .get_current_span ()
278
+ if current_span is not INVALID_SPAN :
279
+ # Inject trace context into response
280
+ carrier = {}
281
+ self .propagators .inject (carrier = carrier , context = trace .set_span_in_context (current_span ))
282
+
283
+ response_json = response .model_dump (by_alias = True , mode = "json" , exclude_none = True )
284
+
285
+ if self ._MCP_META_FIELD not in response_json :
286
+ response_json [self ._MCP_META_FIELD ] = {}
287
+ response_json [self ._MCP_META_FIELD ].update (carrier )
288
+
289
+ modified_response = response .model_validate (response_json )
290
+
291
+ if len (args ) > 1 :
292
+ args = args [:1 ] + (modified_response ,) + args [2 :]
293
+ else :
294
+ kwargs ["response" ] = modified_response
295
+
296
+ return await wrapped (* args , ** kwargs )
297
+
247
298
def _extract_session_id (self , args : Tuple [Any , ...]) -> Optional [str ]:
248
299
"""
249
300
Extract session ID from server method arguments.
250
301
"""
251
302
try :
252
- from mcp .shared .session import RequestResponder # pylint: disable=import-outside-toplevel
253
- from mcp .shared .message import ServerMessageMetadata # pylint: disable=import-outside-toplevel
254
-
255
303
message = args [0 ]
256
- if isinstance (message , RequestResponder ):
257
- if message .message_metadata and isinstance (message .message_metadata , ServerMessageMetadata ):
258
- request_context = message .message_metadata .request_context
259
- if request_context :
260
- headers = getattr (request_context , 'headers' , None )
304
+ if hasattr (message , "message_metadata" ):
305
+ message_metadata = message .message_metadata
306
+ if message .message_metadata and hasattr (message_metadata , "request_context" ):
307
+ request_context = message_metadata .request_context
308
+ if request_context and hasattr (request_context , "headers" ):
309
+ headers = request_context .headers
261
310
if headers :
262
311
return headers .get (self ._MCP_SESSION_ID_HEADER )
263
312
return None
264
313
except Exception :
265
314
return None
266
-
315
+
267
316
@staticmethod
268
317
def _generate_mcp_message_attrs (span : trace .Span , message , request_id : Optional [int ]) -> None :
269
318
import mcp .types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
@@ -326,4 +375,4 @@ def serialize(args: dict[str, Any]) -> str:
326
375
try :
327
376
return json .dumps (args )
328
377
except Exception :
329
- return ""
378
+ return "unknown_args "
0 commit comments