11# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22# SPDX-License-Identifier: Apache-2.0
33import logging
4- from typing import Any , Collection
4+ from typing import Any , Callable , Collection , Dict , Tuple
55
6+ from mcp import ClientRequest
67from wrapt import register_post_import_hook , wrap_function_wrapper
78
89from opentelemetry import trace
@@ -29,12 +30,46 @@ class MCPInstrumentor(BaseInstrumentor):
2930 An instrumenter for MCP.
3031 """
3132
32- def __init__ (self ): # pylint: disable=no-self-use
33+ def __init__ (self ):
3334 super ().__init__ ()
3435 self .tracer = None
3536
37+ @staticmethod
38+ def instrumentation_dependencies () -> Collection [str ]:
39+ return _instruments
40+
41+ def _instrument (self , ** kwargs : Any ) -> None :
42+ tracer_provider = kwargs .get ("tracer_provider" )
43+ if tracer_provider :
44+ self .tracer = tracer_provider .get_tracer ("mcp" )
45+ else :
46+ self .tracer = trace .get_tracer ("mcp" )
47+ register_post_import_hook (
48+ lambda _ : wrap_function_wrapper (
49+ "mcp.shared.session" ,
50+ "BaseSession.send_request" ,
51+ self ._wrap_send_request ,
52+ ),
53+ "mcp.shared.session" ,
54+ )
55+ register_post_import_hook (
56+ lambda _ : wrap_function_wrapper (
57+ "mcp.server.lowlevel.server" ,
58+ "Server._handle_request" ,
59+ self ._wrap_handle_request ,
60+ ),
61+ "mcp.server.lowlevel.server" ,
62+ )
63+
64+ @staticmethod
65+ def _uninstrument (** kwargs : Any ) -> None :
66+ unwrap ("mcp.shared.session" , "BaseSession.send_request" )
67+ unwrap ("mcp.server.lowlevel.server" , "Server._handle_request" )
68+
3669 # Send Request Wrapper
37- def _wrap_send_request (self , wrapped , instance , args , kwargs ): # pylint: disable=no-self-use
70+ def _wrap_send_request (
71+ self , wrapped : Callable , instance : Any , args : Tuple [Any , ...], kwargs : Dict [str , Any ]
72+ ) -> Callable :
3873 """
3974 Changes made:
4075 The wrapper intercepts the request before sending, injects distributed tracing context into the
@@ -43,14 +78,14 @@ def _wrap_send_request(self, wrapped, instance, args, kwargs): # pylint: disabl
4378 type and calling the original function with identical parameters.
4479 """
4580
46- async def async_wrapper (): # pylint: disable=no-self-use
81+ async def async_wrapper ():
4782 with self .tracer .start_as_current_span ("client.send_request" , kind = trace .SpanKind .CLIENT ) as span :
4883 span_ctx = span .get_span_context ()
4984 request = args [0 ] if len (args ) > 0 else kwargs .get ("request" )
5085 if request :
5186 req_root = request .root if hasattr (request , "root" ) else request
5287
53- self .handle_attributes (span , req_root , True )
88+ self ._generate_mcp_attributes (span , req_root , is_client = True )
5489 request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
5590 self ._inject_trace_context (request_data , span_ctx )
5691 # Reconstruct request object with injected trace context
@@ -68,7 +103,9 @@ async def async_wrapper(): # pylint: disable=no-self-use
68103 return async_wrapper ()
69104
70105 # Handle Request Wrapper
71- async def _wrap_handle_request (self , wrapped , instance , args , kwargs ): # pylint: disable=no-self-use
106+ async def _wrap_handle_request (
107+ self , wrapped : Callable , instance : Any , args : Tuple [Any , ...], kwargs : Dict [str , Any ]
108+ ) -> Any :
72109 """
73110 Changes made:
74111 This wrapper intercepts requests before processing, extracts distributed tracing context from
@@ -87,19 +124,35 @@ async def _wrap_handle_request(self, wrapped, instance, args, kwargs): # pylint
87124 traceparent = getattr (req .params .meta , "traceparent" , None )
88125 span_context = self ._extract_span_context_from_traceparent (traceparent ) if traceparent else None
89126 if span_context :
90- span_name = self ._get_span_name (req )
127+ span_name = self ._get_mcp_operation (req )
91128 with self .tracer .start_as_current_span (
92129 span_name ,
93130 kind = trace .SpanKind .SERVER ,
94131 context = trace .set_span_in_context (trace .NonRecordingSpan (span_context )),
95132 ) as span :
96- self .handle_attributes (span , req , False )
133+ self ._generate_mcp_attributes (span , req , False )
97134 result = await wrapped (* args , ** kwargs )
98135 return result
99136 else :
100137 return await wrapped (* args , ** kwargs )
101138
102- def _inject_trace_context (self , request_data , span_ctx ): # pylint: disable=no-self-use
139+ def _generate_mcp_attributes (self , span : trace .Span , request : ClientRequest , is_client : bool ) -> None :
140+ import mcp .types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
141+
142+ operation = "UnknownOperation"
143+ if isinstance (request , types .ListToolsRequest ):
144+ operation = "ListTool"
145+ span .set_attribute ("mcp.list_tools" , True )
146+ elif isinstance (request , types .CallToolRequest ):
147+ operation = request .params .name
148+ span .set_attribute ("mcp.call_tool" , True )
149+ if is_client :
150+ self ._add_client_attributes (span , operation , request )
151+ else :
152+ self ._add_server_attributes (span , operation , request )
153+
154+ @staticmethod
155+ def _inject_trace_context (request_data : Dict [str , Any ], span_ctx ) -> None :
103156 if "params" not in request_data :
104157 request_data ["params" ] = {}
105158 if "_meta" not in request_data ["params" ]:
@@ -110,7 +163,8 @@ def _inject_trace_context(self, request_data, span_ctx): # pylint: disable=no-s
110163 traceparent = f"00-{ trace_id_hex } -{ span_id_hex } -{ trace_flags } "
111164 request_data ["params" ]["_meta" ]["traceparent" ] = traceparent
112165
113- def _extract_span_context_from_traceparent (self , traceparent ): # pylint: disable=no-self-use
166+ @staticmethod
167+ def _extract_span_context_from_traceparent (traceparent : str ):
114168 parts = traceparent .split ("-" )
115169 if len (parts ) == 4 :
116170 try :
@@ -127,72 +181,29 @@ def _extract_span_context_from_traceparent(self, traceparent): # pylint: disabl
127181 return None
128182 return None
129183
130- def _get_span_name ( self , req ): # pylint: disable=no-self-use
131- span_name = "unknown"
184+ @ staticmethod
185+ def _get_mcp_operation ( req : ClientRequest ) -> str :
132186 import mcp .types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
133187
188+ span_name = "unknown"
189+
134190 if isinstance (req , types .ListToolsRequest ):
135191 span_name = "tools/list"
136192 elif isinstance (req , types .CallToolRequest ):
137- if hasattr (req , "params" ) and hasattr (req .params , "name" ):
138- span_name = f"tools/{ req .params .name } "
139- else :
140- span_name = "unknown"
193+ span_name = f"tools/{ req .params .name } "
141194 return span_name
142195
143- def handle_attributes (self , span , request , is_client = True ): # pylint: disable=no-self-use
144- import mcp .types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
145-
146- operation = self ._get_span_name (request )
147- if isinstance (request , types .ListToolsRequest ):
148- operation = "ListTool"
149- span .set_attribute ("mcp.list_tools" , True )
150- elif isinstance (request , types .CallToolRequest ):
151- if hasattr (request , "params" ) and hasattr (request .params , "name" ):
152- operation = request .params .name
153- span .set_attribute ("mcp.call_tool" , True )
154- if is_client :
155- self ._add_client_attributes (span , operation , request )
156- else :
157- self ._add_server_attributes (span , operation , request )
196+ @staticmethod
197+ def _add_client_attributes (span : trace .Span , operation : str , request : ClientRequest ) -> None :
198+ import os # pylint: disable=import-outside-toplevel
158199
159- def _add_client_attributes ( self , span , operation , request ): # pylint: disable=no-self-use
160- span .set_attribute ("aws.remote.service" , "Appsignals MCP Server" )
200+ service_name = os . environ . get ( "MCP_SERVICE_NAME" , "Generic MCP Server" )
201+ span .set_attribute ("aws.remote.service" , service_name )
161202 span .set_attribute ("aws.remote.operation" , operation )
162- if hasattr (request , "params" ) and hasattr (request .params , "name" ):
203+ if hasattr (request , "params" ) and request . params and hasattr (request .params , "name" ):
163204 span .set_attribute ("tool.name" , request .params .name )
164205
165- def _add_server_attributes ( self , span , operation , request ): # pylint: disable=no-self-use
166- span . set_attribute ( "server_side" , True )
167- if hasattr (request , "params" ) and hasattr (request .params , "name" ):
206+ @ staticmethod
207+ def _add_server_attributes ( span : trace . Span , operation : str , request : ClientRequest ) -> None :
208+ if hasattr (request , "params" ) and request . params and hasattr (request .params , "name" ):
168209 span .set_attribute ("tool.name" , request .params .name )
169-
170- def instrumentation_dependencies (self ) -> Collection [str ]: # pylint: disable=no-self-use
171- return _instruments
172-
173- def _instrument (self , ** kwargs : Any ) -> None : # pylint: disable=no-self-use
174- tracer_provider = kwargs .get ("tracer_provider" )
175- if tracer_provider :
176- self .tracer = tracer_provider .get_tracer ("mcp" )
177- else :
178- self .tracer = trace .get_tracer ("mcp" )
179- register_post_import_hook (
180- lambda _ : wrap_function_wrapper (
181- "mcp.shared.session" ,
182- "BaseSession.send_request" ,
183- self ._wrap_send_request ,
184- ),
185- "mcp.shared.session" ,
186- )
187- register_post_import_hook (
188- lambda _ : wrap_function_wrapper (
189- "mcp.server.lowlevel.server" ,
190- "Server._handle_request" ,
191- self ._wrap_handle_request ,
192- ),
193- "mcp.server.lowlevel.server" ,
194- )
195-
196- def _uninstrument (self , ** kwargs : Any ) -> None : # pylint: disable=no-self-use
197- unwrap ("mcp.shared.session" , "BaseSession.send_request" )
198- unwrap ("mcp.server.lowlevel.server" , "Server._handle_request" )
0 commit comments