Skip to content

Commit 3902c60

Browse files
committed
add span support for notifications
1 parent 323b87a commit 3902c60

File tree

1 file changed

+84
-81
lines changed
  • aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp

1 file changed

+84
-81
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py

Lines changed: 84 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper
88

99
from opentelemetry import context, trace
10-
from opentelemetry.trace.status import Status, StatusCode
10+
from opentelemetry.trace import SpanKind, Status, StatusCode
1111
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
1212
from opentelemetry.instrumentation.utils import unwrap
1313
from opentelemetry.semconv.trace import SpanAttributes
@@ -22,6 +22,9 @@
2222

2323

2424
class McpInstrumentor(BaseInstrumentor):
25+
_DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client"
26+
_DEFAULT_SERVER_SPAN_NAME = "span.mcp.server"
27+
2528
"""
2629
An instrumentation class for MCP: https://modelcontextprotocol.io/overview
2730
"""
@@ -35,19 +38,22 @@ def instrumentation_dependencies(self) -> Collection[str]:
3538
return ("mcp >= 1.8.1",)
3639

3740
def _instrument(self, **kwargs: Any) -> None:
38-
# TODO: add instrumentation for Streamable Http transport
39-
# See: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports
40-
4141
register_post_import_hook(
4242
lambda _: wrap_function_wrapper(
4343
"mcp.shared.session",
4444
"BaseSession.send_request",
45-
self._wrap_session_send_request,
45+
self._wrap_session_send,
46+
),
47+
"mcp.shared.session",
48+
)
49+
register_post_import_hook(
50+
lambda _: wrap_function_wrapper(
51+
"mcp.shared.session",
52+
"BaseSession.send_notification",
53+
self._wrap_session_send,
4654
),
4755
"mcp.shared.session",
4856
)
49-
50-
5157
register_post_import_hook(
5258
lambda _: wrap_function_wrapper(
5359
"mcp.server.lowlevel.server",
@@ -61,69 +67,50 @@ def _uninstrument(self, **kwargs: Any) -> None:
6167
unwrap("mcp.shared.session", "BaseSession.send_request")
6268
unwrap("mcp.server.lowlevel.server", "Server._handle_request")
6369

64-
def _wrap_session_send_request(
70+
def _wrap_session_send(
6571
self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
6672
) -> Callable:
6773
import mcp.types as types
6874

69-
"""
70-
Instruments MCP client-side request sending for both stdio and Streamable HTTP transport,
71-
see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports
72-
73-
This is the master function responsible for sending requests from the client to the MCP server.
74-
See:
75-
- https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/shared/session.py#L220
76-
- https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/client/session_group.py#L233
77-
78-
The instrumented MCP client intercepts the request to obtain attributes for creating client-side span, extracts
79-
the current trace context, and embeds it into the request's params._meta field
80-
before forwarding the request to the MCP server.
81-
82-
Args:
83-
wrapped: The original BaseSession.send_request method being instrumented
84-
instance: The BaseSession instance handling the stdio communication
85-
args: Positional arguments passed to the original send_request method, containing the ClientRequest
86-
kwargs: Keyword arguments passed to the original send_request method
87-
"""
88-
8975
async def async_wrapper():
90-
request: Optional[types.ClientRequest] = args[0] if len(args) > 0 else None
91-
92-
if not request:
76+
message = args[0] if len(args) > 0 else None
77+
if not message:
9378
return await wrapped(*args, **kwargs)
9479

95-
request_id = None
80+
is_client = isinstance(message, (types.ClientRequest, types.ClientNotification))
81+
request_id: Optional[int] = getattr(instance, "_request_id", None)
82+
span_name = self._DEFAULT_SERVER_SPAN_NAME
83+
span_kind = SpanKind.SERVER
9684

97-
if hasattr(instance, "_request_id"):
98-
request_id = instance._request_id
85+
if is_client:
86+
span_name = self._DEFAULT_CLIENT_SPAN_NAME
87+
span_kind = SpanKind.CLIENT
9988

100-
request_as_json = request.model_dump(by_alias=True, mode="json", exclude_none=True)
89+
message_json = message.model_dump(by_alias=True, mode="json", exclude_none=True)
10190

102-
if "params" not in request_as_json:
103-
request_as_json["params"] = {}
104-
if "_meta" not in request_as_json["params"]:
105-
request_as_json["params"]["_meta"] = {}
91+
if "params" not in message_json:
92+
message_json["params"] = {}
93+
if "_meta" not in message_json["params"]:
94+
message_json["params"]["_meta"] = {}
10695

107-
with self.tracer.start_as_current_span("span.mcp.client", kind=trace.SpanKind.CLIENT) as client_span:
96+
with self.tracer.start_as_current_span(name=span_name, kind=span_kind) as span:
97+
ctx = trace.set_span_in_context(span)
98+
carrier = {}
99+
self.propagators.inject(carrier=carrier, context=ctx)
100+
message_json["params"]["_meta"].update(carrier)
108101

109-
span_ctx = trace.set_span_in_context(client_span)
110-
parent_span = {}
111-
self.propagators.inject(carrier=parent_span, context=span_ctx)
102+
McpInstrumentor._generate_mcp_req_attrs(span, message, request_id)
112103

113-
McpInstrumentor._generate_mcp_span_attrs(client_span, request, request_id)
114-
request_as_json["params"]["_meta"].update(parent_span)
115-
116-
# Reconstruct request object with injected trace context
117-
modified_request = request.model_validate(request_as_json)
118-
new_args = (modified_request,) + args[1:]
104+
modified_message = message.model_validate(message_json)
105+
new_args = (modified_message,) + args[1:]
119106

120107
try:
121108
result = await wrapped(*new_args, **kwargs)
122-
client_span.set_status(Status(StatusCode.OK))
109+
span.set_status(Status(StatusCode.OK))
123110
return result
124111
except Exception as e:
125-
client_span.set_status(Status(StatusCode.ERROR, str(e)))
126-
client_span.record_exception(e)
112+
span.set_status(Status(StatusCode.ERROR, str(e)))
113+
span.record_exception(e)
127114
raise
128115

129116
return async_wrapper()
@@ -132,10 +119,10 @@ async def _wrap_server_handle_request(
132119
self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
133120
) -> Any:
134121
"""
135-
Instruments MCP server-side request handling for both stdio and Streamable HTTP transport,
122+
Instruments MCP server-side request handling for both stdio and Streamable HTTP transport,
136123
see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports
137124
138-
This is the core function responsible for processing incoming requests on the MCP server.
125+
This is the core function responsible for processing incoming requests on the MCP server.
139126
See:
140127
https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616
141128
@@ -150,74 +137,90 @@ async def _wrap_server_handle_request(
150137
kwargs: Keyword arguments passed to the original _handle_request method
151138
"""
152139
incoming_req = args[1] if len(args) > 1 else None
140+
141+
if not incoming_req:
142+
return await wrapped(*args, **kwargs)
143+
153144
request_id = None
154145
carrier = {}
155146

156-
if incoming_req and hasattr(incoming_req, "id"):
147+
if hasattr(incoming_req, "id") and incoming_req.id:
157148
request_id = incoming_req.id
158-
if incoming_req and hasattr(incoming_req, "params") and hasattr(incoming_req.params, "meta"):
149+
if hasattr(incoming_req, "params") and hasattr(incoming_req.params, "meta") and incoming_req.meta:
159150
carrier = incoming_req.params.meta.model_dump()
160151

152+
# If MCP client is instrumented then params._meta field will contain the
153+
# parent trace context.
161154
parent_ctx = self.propagators.extract(carrier=carrier)
162155

163-
if parent_ctx:
164-
with self.tracer.start_as_current_span(
165-
"span.mcp.server", kind=trace.SpanKind.SERVER, context=parent_ctx
166-
) as server_span:
156+
with self.tracer.start_as_current_span(
157+
"span.mcp.server", kind=trace.SpanKind.SERVER, context=parent_ctx
158+
) as server_span:
167159

168-
self._generate_mcp_span_attrs(server_span, incoming_req, request_id)
160+
self._generate_mcp_req_attrs(server_span, incoming_req, request_id)
169161

170-
try:
171-
result = await wrapped(*args, **kwargs)
172-
server_span.set_status(Status(StatusCode.OK))
173-
return result
174-
except Exception as e:
175-
server_span.set_status(Status(StatusCode.ERROR, str(e)))
176-
server_span.record_exception(e)
177-
raise
162+
try:
163+
result = await wrapped(*args, **kwargs)
164+
server_span.set_status(Status(StatusCode.OK))
165+
return result
166+
except Exception as e:
167+
server_span.set_status(Status(StatusCode.ERROR, str(e)))
168+
server_span.record_exception(e)
169+
raise
178170

179171
@staticmethod
180-
def _generate_mcp_span_attrs(span: trace.Span, request, request_id: Optional[str]) -> None:
172+
def _generate_mcp_req_attrs(span: trace.Span, request, request_id: Optional[int]) -> None:
181173
import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
182174

183-
# Client-side: request is of type ClientRequest which contains the Union of different RootModel types
184-
# Server-side: request is passed the RootModel
175+
"""
176+
Populates the given span with MCP semantic convention attributes based on the request type.
177+
These semantic conventions are based off: https://github.com/open-telemetry/semantic-conventions/pull/2083
178+
which are currently in development and are considered unstable.
179+
180+
Args:
181+
span: The MCP span to be enriched with MCP attributes
182+
request: The MCP request object, from Client Side it is of type ClientRequestModel and from server side it's of type RootModel
183+
request_id: Unique identifier for the request. In theory, this should never be Optional since all requests made from MCP client to server will contain a request id.
184+
"""
185+
186+
# Client-side request type will be ClientRequest which has root as field
187+
# Server-side: request type will be the root object passed from ClientRequest
185188
# See: https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/types.py#L1220
186189
if hasattr(request, "root"):
187190
request = request.root
188191

189192
if request_id:
190193
span.set_attribute(MCPSpanAttributes.MCP_REQUEST_ID, request_id)
191-
194+
192195
span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, request.method)
193196

194197
if isinstance(request, types.CallToolRequest):
195198
tool_name = request.params.name
196199
span.update_name(f"{MCPMethodValue.TOOLS_CALL} {tool_name}")
197200
span.set_attribute(MCPSpanAttributes.MCP_TOOL_NAME, tool_name)
198-
201+
request.params.arguments
199202
if request.params.arguments:
200203
for arg_name, arg_val in request.params.arguments.items():
201204
span.set_attribute(
202205
f"{MCPSpanAttributes.MCP_REQUEST_ARGUMENT}.{arg_name}", McpInstrumentor.serialize(arg_val)
203206
)
204-
return
207+
return
205208
if isinstance(request, types.GetPromptRequest):
206209
prompt_name = request.params.name
207210
span.update_name(f"{MCPMethodValue.PROMPTS_GET} {prompt_name}")
208211
span.set_attribute(MCPSpanAttributes.MCP_PROMPT_NAME, prompt_name)
209-
return
212+
return
210213
if isinstance(request, (types.ReadResourceRequest, types.SubscribeRequest, types.UnsubscribeRequest)):
211214
resource_uri = str(request.params.uri)
212215
span.update_name(f"{MCPSpanAttributes.MCP_RESOURCE_URI} {resource_uri}")
213216
span.set_attribute(MCPSpanAttributes.MCP_RESOURCE_URI, resource_uri)
214-
return
215-
217+
return
218+
216219
span.update_name(request.method)
217-
220+
218221
@staticmethod
219-
def serialize(args):
222+
def serialize(args: dict[str, Any]) -> str:
220223
try:
221224
return json.dumps(args)
222225
except Exception:
223-
return str(args)
226+
return ""

0 commit comments

Comments
 (0)