Skip to content

Commit 51aaba9

Browse files
committed
add support for notifications + refactoring
1 parent 3902c60 commit 51aaba9

File tree

1 file changed

+121
-45
lines changed
  • aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp

1 file changed

+121
-45
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py

Lines changed: 121 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
# SPDX-License-Identifier: Apache-2.0
33
from dataclasses import dataclass
44
import json
5-
from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, cast
5+
from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, Union, cast
66

7-
from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper
7+
from wrapt import register_post_import_hook, wrap_function_wrapper
88

9-
from opentelemetry import context, trace
9+
from opentelemetry import trace
1010
from opentelemetry.trace import SpanKind, Status, StatusCode
1111
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
1212
from opentelemetry.instrumentation.utils import unwrap
@@ -22,13 +22,13 @@
2222

2323

2424
class McpInstrumentor(BaseInstrumentor):
25-
_DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client"
26-
_DEFAULT_SERVER_SPAN_NAME = "span.mcp.server"
27-
2825
"""
2926
An instrumentation class for MCP: https://modelcontextprotocol.io/overview
3027
"""
3128

29+
_DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client"
30+
_DEFAULT_SERVER_SPAN_NAME = "span.mcp.server"
31+
3232
def __init__(self, **kwargs):
3333
super().__init__()
3434
self.propagators = kwargs.get("propagators") or get_global_textmap()
@@ -62,29 +62,56 @@ def _instrument(self, **kwargs: Any) -> None:
6262
),
6363
"mcp.server.lowlevel.server",
6464
)
65+
register_post_import_hook(
66+
lambda _: wrap_function_wrapper(
67+
"mcp.server.lowlevel.server",
68+
"Server._handle_notification",
69+
self._wrap_server_handle_notification,
70+
),
71+
"mcp.server.lowlevel.server",
72+
)
6573

6674
def _uninstrument(self, **kwargs: Any) -> None:
6775
unwrap("mcp.shared.session", "BaseSession.send_request")
76+
unwrap("mcp.shared.session", "BaseSession.send_notification")
6877
unwrap("mcp.server.lowlevel.server", "Server._handle_request")
78+
unwrap("mcp.server.lowlevel.server", "Server._handle_notification")
6979

7080
def _wrap_session_send(
7181
self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
7282
) -> Callable:
73-
import mcp.types as types
83+
"""
84+
Instruments MCP client and server request/notification sending for both stdio and Streamable HTTP transport,
85+
see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports
86+
87+
See:
88+
- https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/shared/session.py#L220
89+
- https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/shared/session.py#L296
90+
91+
This instrumentation intercepts the requests/notification messages sent between client and server to obtain attributes for creating span, injects
92+
the current trace context, and embeds it into the request's params._meta field before forwarding the request to the MCP server.
93+
94+
Args:
95+
wrapped: The original BaseSession.send_request/send_notification method
96+
instance: The BaseSession instance
97+
args: Positional arguments passed to the original send_request/send_notification method
98+
kwargs: Keyword arguments passed to the original send_request/send_notification method
99+
"""
100+
from mcp.types import ClientRequest, ClientNotification, ServerRequest, ServerNotification
74101

75102
async def async_wrapper():
76-
message = args[0] if len(args) > 0 else None
103+
message: Optional[Union[ClientRequest, ClientNotification, ServerRequest, ServerNotification]] = (
104+
args[0] if len(args) > 0 else None
105+
)
106+
77107
if not message:
78108
return await wrapped(*args, **kwargs)
79109

80-
is_client = isinstance(message, (types.ClientRequest, types.ClientNotification))
81110
request_id: Optional[int] = getattr(instance, "_request_id", None)
82-
span_name = self._DEFAULT_SERVER_SPAN_NAME
83-
span_kind = SpanKind.SERVER
111+
span_name, span_kind = self._DEFAULT_SERVER_SPAN_NAME, SpanKind.SERVER
84112

85-
if is_client:
86-
span_name = self._DEFAULT_CLIENT_SPAN_NAME
87-
span_kind = SpanKind.CLIENT
113+
if isinstance(message, (ClientRequest, ClientNotification)):
114+
span_name, span_kind = self._DEFAULT_CLIENT_SPAN_NAME, SpanKind.CLIENT
88115

89116
message_json = message.model_dump(by_alias=True, mode="json", exclude_none=True)
90117

@@ -99,7 +126,7 @@ async def async_wrapper():
99126
self.propagators.inject(carrier=carrier, context=ctx)
100127
message_json["params"]["_meta"].update(carrier)
101128

102-
McpInstrumentor._generate_mcp_req_attrs(span, message, request_id)
129+
McpInstrumentor._generate_mcp_message_attrs(span, message, request_id)
103130

104131
modified_message = message.model_validate(message_json)
105132
new_args = (modified_message,) + args[1:]
@@ -126,38 +153,79 @@ async def _wrap_server_handle_request(
126153
See:
127154
https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616
128155
129-
The instrumented MCP server intercepts incoming requests to extract tracing context from
130-
the request's params._meta field, creates server-side spans linked to the originating client spans,
131-
and processes the request while maintaining trace continuity.
132-
133156
Args:
134157
wrapped: The original Server._handle_request method being instrumented
135158
instance: The MCP Server instance processing the stdio communication
136159
args: Positional arguments passed to the original _handle_request method, containing the incoming request
137160
kwargs: Keyword arguments passed to the original _handle_request method
138161
"""
139162
incoming_req = args[1] if len(args) > 1 else None
163+
return await self._wrap_server_message_handler(wrapped, instance, args, kwargs, incoming_msg=incoming_req)
164+
165+
async def _wrap_server_handle_notification(
166+
self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
167+
) -> Any:
168+
"""
169+
Instruments MCP server-side notification handling for both stdio and Streamable HTTP transport,
170+
This is the core function responsible for processing incoming notifications on the MCP server instance.
171+
See:
172+
https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616
173+
174+
Args:
175+
wrapped: The original Server._handle_notification method being instrumented
176+
instance: The MCP Server instance processing the stdio communication
177+
args: Positional arguments passed to the original _handle_request method, containing the incoming request
178+
kwargs: Keyword arguments passed to the original _handle_request method
179+
"""
180+
incoming_notif = args[0] if len(args) > 0 else None
181+
return await self._wrap_server_message_handler(wrapped, instance, args, kwargs, incoming_msg=incoming_notif)
182+
183+
async def _wrap_server_message_handler(
184+
self,
185+
wrapped: Callable,
186+
instance: Any,
187+
args: Tuple[Any, ...],
188+
kwargs: Dict[str, Any],
189+
incoming_msg: Optional[Any],
190+
) -> Any:
191+
"""
192+
Instruments MCP server-side request/notification handling for both stdio and Streamable HTTP transport,
193+
see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports
194+
195+
See:
196+
https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616
197+
198+
The instrumented MCP server intercepts incoming requests/notification messages from the client to extract tracing context from
199+
the messages's params._meta field and creates server-side spans linked to the originating client spans.
140200
141-
if not incoming_req:
201+
Args:
202+
wrapped: The original Server._handle_notification/_handle_request method being instrumented
203+
instance: The Server instance
204+
args: Positional arguments passed to the original _handle_request/ method, containing the incoming request
205+
kwargs: Keyword arguments passed to the original _handle_request method
206+
incoming_msg: The incoming message from the client, can be one of: ClientRequest or ClientNotification
207+
"""
208+
if not incoming_msg:
142209
return await wrapped(*args, **kwargs)
143210

144211
request_id = None
145212
carrier = {}
146213

147-
if hasattr(incoming_req, "id") and incoming_req.id:
148-
request_id = incoming_req.id
149-
if hasattr(incoming_req, "params") and hasattr(incoming_req.params, "meta") and incoming_req.meta:
150-
carrier = incoming_req.params.meta.model_dump()
214+
# Request IDs are only present in Request messages not Notifications.
215+
if hasattr(incoming_msg, "id") and incoming_msg.id:
216+
request_id = incoming_msg.id
217+
218+
# If the client is instrumented then params._meta field will contain the trace context.
219+
if hasattr(incoming_msg, "params") and hasattr(incoming_msg.params, "meta") and incoming_msg.params.meta:
220+
carrier = incoming_msg.params.meta.model_dump()
151221

152-
# If MCP client is instrumented then params._meta field will contain the
153-
# parent trace context.
154222
parent_ctx = self.propagators.extract(carrier=carrier)
155223

156224
with self.tracer.start_as_current_span(
157-
"span.mcp.server", kind=trace.SpanKind.SERVER, context=parent_ctx
225+
self._DEFAULT_SERVER_SPAN_NAME, kind=SpanKind.SERVER, context=parent_ctx
158226
) as server_span:
159227

160-
self._generate_mcp_req_attrs(server_span, incoming_req, request_id)
228+
self._generate_mcp_message_attrs(server_span, incoming_msg, request_id)
161229

162230
try:
163231
result = await wrapped(*args, **kwargs)
@@ -169,54 +237,62 @@ async def _wrap_server_handle_request(
169237
raise
170238

171239
@staticmethod
172-
def _generate_mcp_req_attrs(span: trace.Span, request, request_id: Optional[int]) -> None:
240+
def _generate_mcp_message_attrs(span: trace.Span, message, request_id: Optional[int]) -> None:
173241
import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
174242

175243
"""
176-
Populates the given span with MCP semantic convention attributes based on the request type.
244+
Populates the given span with MCP semantic convention attributes based on the message type.
177245
These semantic conventions are based off: https://github.com/open-telemetry/semantic-conventions/pull/2083
178246
which are currently in development and are considered unstable.
179247
180248
Args:
181249
span: The MCP span to be enriched with MCP attributes
182-
request: The MCP request object, from Client Side it is of type ClientRequestModel and from server side it's of type RootModel
183-
request_id: Unique identifier for the request. In theory, this should never be Optional since all requests made from MCP client to server will contain a request id.
250+
message: The MCP message object, from client side it is of type ClientRequestModel/ClientNotificationModel and from server side it gets passed as type RootModel
251+
request_id: Unique identifier for the request or None if the message is a notification.
184252
"""
185253

186254
# Client-side request type will be ClientRequest which has root as field
187255
# Server-side: request type will be the root object passed from ClientRequest
188256
# See: https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/types.py#L1220
189-
if hasattr(request, "root"):
190-
request = request.root
257+
if hasattr(message, "root"):
258+
message = message.root
191259

192260
if request_id:
193261
span.set_attribute(MCPSpanAttributes.MCP_REQUEST_ID, request_id)
194262

195-
span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, request.method)
263+
span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, message.method)
196264

197-
if isinstance(request, types.CallToolRequest):
198-
tool_name = request.params.name
265+
if isinstance(message, types.CallToolRequest):
266+
tool_name = message.params.name
199267
span.update_name(f"{MCPMethodValue.TOOLS_CALL} {tool_name}")
200268
span.set_attribute(MCPSpanAttributes.MCP_TOOL_NAME, tool_name)
201-
request.params.arguments
202-
if request.params.arguments:
203-
for arg_name, arg_val in request.params.arguments.items():
269+
message.params.arguments
270+
if message.params.arguments:
271+
for arg_name, arg_val in message.params.arguments.items():
204272
span.set_attribute(
205273
f"{MCPSpanAttributes.MCP_REQUEST_ARGUMENT}.{arg_name}", McpInstrumentor.serialize(arg_val)
206274
)
207275
return
208-
if isinstance(request, types.GetPromptRequest):
209-
prompt_name = request.params.name
276+
if isinstance(message, types.GetPromptRequest):
277+
prompt_name = message.params.name
210278
span.update_name(f"{MCPMethodValue.PROMPTS_GET} {prompt_name}")
211279
span.set_attribute(MCPSpanAttributes.MCP_PROMPT_NAME, prompt_name)
212280
return
213-
if isinstance(request, (types.ReadResourceRequest, types.SubscribeRequest, types.UnsubscribeRequest)):
214-
resource_uri = str(request.params.uri)
281+
if isinstance(
282+
message,
283+
(
284+
types.ReadResourceRequest,
285+
types.SubscribeRequest,
286+
types.UnsubscribeRequest,
287+
types.ResourceUpdatedNotification,
288+
),
289+
):
290+
resource_uri = str(message.params.uri)
215291
span.update_name(f"{MCPSpanAttributes.MCP_RESOURCE_URI} {resource_uri}")
216292
span.set_attribute(MCPSpanAttributes.MCP_RESOURCE_URI, resource_uri)
217293
return
218294

219-
span.update_name(request.method)
295+
span.update_name(message.method)
220296

221297
@staticmethod
222298
def serialize(args: dict[str, Any]) -> str:

0 commit comments

Comments
 (0)