Skip to content

Commit 76c2405

Browse files
committed
cleanup code
1 parent dc2c730 commit 76c2405

File tree

2 files changed

+153
-148
lines changed

2 files changed

+153
-148
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py

Lines changed: 112 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper
88

99
from opentelemetry import context, trace
10+
from opentelemetry.trace.status import Status, StatusCode
1011
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
1112
from opentelemetry.instrumentation.utils import unwrap
1213
from opentelemetry.semconv.trace import SpanAttributes
@@ -15,20 +16,14 @@
1516
from .version import __version__
1617

1718
from .semconv import (
18-
CLIENT_INITIALIZED,
19-
MCP_METHOD_NAME,
20-
MCP_REQUEST_ARGUMENT,
21-
TOOLS_CALL,
22-
TOOLS_LIST,
23-
MCPAttributes,
24-
MCPOperations,
25-
MCPSpanNames,
19+
MCPSpanAttributes,
20+
MCPMethodNameValue,
2621
)
2722

2823

2924
class McpInstrumentor(BaseInstrumentor):
3025
"""
31-
An instrumentor class for MCP.
26+
An instrumentation class for MCP: https://modelcontextprotocol.io/overview.
3227
"""
3328

3429
def __init__(self, **kwargs):
@@ -40,20 +35,22 @@ def instrumentation_dependencies(self) -> Collection[str]:
4035
return ("mcp >= 1.8.1",)
4136

4237
def _instrument(self, **kwargs: Any) -> None:
43-
38+
# TODO: add instrumentation for Streamable Http transport
39+
# See: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports
40+
4441
register_post_import_hook(
4542
lambda _: wrap_function_wrapper(
4643
"mcp.shared.session",
4744
"BaseSession.send_request",
48-
self._wrap_send_request,
45+
self._wrap_session_send_request,
4946
),
5047
"mcp.shared.session",
5148
)
5249
register_post_import_hook(
5350
lambda _: wrap_function_wrapper(
5451
"mcp.server.lowlevel.server",
5552
"Server._handle_request",
56-
self._wrap_handle_request,
53+
self._wrap_stdio_handle_request,
5754
),
5855
"mcp.server.lowlevel.server",
5956
)
@@ -62,17 +59,26 @@ def _uninstrument(self, **kwargs: Any) -> None:
6259
unwrap("mcp.shared.session", "BaseSession.send_request")
6360
unwrap("mcp.server.lowlevel.server", "Server._handle_request")
6461

65-
def _wrap_send_request(
62+
def _wrap_session_send_request(
6663
self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
6764
) -> Callable:
6865
import mcp.types as types
69-
7066

7167
"""
72-
Patches BaseSession.send_request which is responsible for sending requests from the client to the MCP server.
73-
This patched MCP client intercepts the request to obtain attributes for creating client-side span, extracts
74-
the current trace context, and embeds it into the request's params._meta.traceparent field
68+
Instruments MCP client-side stdio request sending, see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports
69+
70+
This is the master function responsible for sending requests from the client to the MCP server. See:
71+
https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/shared/session.py#L220
72+
73+
The instrumented MCP client intercepts the request to obtain attributes for creating client-side span, extracts
74+
the current trace context, and embeds it into the request's params._meta field
7575
before forwarding the request to the MCP server.
76+
77+
Args:
78+
wrapped: The original BaseSession.send_request method being instrumented
79+
instance: The BaseSession instance handling the stdio communication
80+
args: Positional arguments passed to the original send_request method, containing the ClientRequest
81+
kwargs: Keyword arguments passed to the original send_request method
7682
"""
7783

7884
async def async_wrapper():
@@ -81,88 +87,131 @@ async def async_wrapper():
8187
if not request:
8288
return await wrapped(*args, **kwargs)
8389

90+
request_id = None
91+
92+
if hasattr(instance, "_request_id"):
93+
request_id = instance._request_id
94+
8495
request_as_json = request.model_dump(by_alias=True, mode="json", exclude_none=True)
8596

8697
if "params" not in request_as_json:
8798
request_as_json["params"] = {}
88-
8999
if "_meta" not in request_as_json["params"]:
90100
request_as_json["params"]["_meta"] = {}
91101

92-
with self.tracer.start_as_current_span(
93-
MCPSpanNames.SPAN_MCP_CLIENT, kind=trace.SpanKind.CLIENT
94-
) as client_span:
102+
with self.tracer.start_as_current_span("span.mcp.client", kind=trace.SpanKind.CLIENT) as client_span:
95103

96-
if request:
97-
span_ctx = trace.set_span_in_context(client_span)
98-
parent_span = {}
99-
self.propagators.inject(carrier=parent_span, context=span_ctx)
104+
span_ctx = trace.set_span_in_context(client_span)
105+
parent_span = {}
106+
self.propagators.inject(carrier=parent_span, context=span_ctx)
100107

101-
McpInstrumentor._set_mcp_client_attributes(client_span, request)
108+
McpInstrumentor._configure_mcp_span(client_span, request, request_id)
109+
request_as_json["params"]["_meta"].update(parent_span)
102110

103-
request_as_json["params"]["_meta"].update(parent_span)
111+
# Reconstruct request object with injected trace context
112+
modified_request = request.model_validate(request_as_json)
113+
new_args = (modified_request,) + args[1:]
104114

105-
# Reconstruct request object with injected trace context
106-
modified_request = request.model_validate(request_as_json)
107-
new_args = (modified_request,) + args[1:]
115+
try:
116+
result = await wrapped(*new_args, **kwargs)
117+
client_span.set_status(Status(StatusCode.OK))
118+
return result
119+
except Exception as e:
120+
client_span.set_status(Status(StatusCode.ERROR, str(e)))
121+
client_span.record_exception(e)
122+
raise
108123

109-
return await wrapped(*new_args, **kwargs)
110124
return async_wrapper()
111125

112-
async def _wrap_handle_request(
126+
async def _wrap_stdio_handle_request(
113127
self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
114128
) -> Any:
115-
req = args[1] if len(args) > 1 else None
129+
"""
130+
Instruments MCP server-side stdio request handling, see: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports
131+
132+
This is the core function responsible for processing incoming requests on the MCP server. See:
133+
https://github.com/modelcontextprotocol/python-sdk/blob/e68e513b428243057f9c4693e10162eb3bb52897/src/mcp/server/lowlevel/server.py#L616
134+
135+
The instrumented MCP server intercepts incoming requests to extract tracing context from
136+
the request's params._meta field, creates server-side spans linked to the originating client spans,
137+
and processes the request while maintaining trace continuity.
138+
139+
Args:
140+
wrapped: The original Server._handle_request method being instrumented
141+
instance: The MCP Server instance processing the stdio communication
142+
args: Positional arguments passed to the original _handle_request method, containing the incoming request
143+
kwargs: Keyword arguments passed to the original _handle_request method
144+
"""
145+
incoming_req = args[1] if len(args) > 1 else None
146+
request_id = None
116147
carrier = {}
117148

118-
if req and hasattr(req, "params") and req.params and hasattr(req.params, "meta") and req.params.meta:
119-
carrier = req.params.meta.model_dump()
149+
if incoming_req and hasattr(incoming_req, "id"):
150+
request_id = incoming_req.id
151+
if incoming_req and hasattr(incoming_req, "params") and hasattr(incoming_req.params, "meta"):
152+
carrier = incoming_req.params.meta.model_dump()
120153

121154
parent_ctx = self.propagators.extract(carrier=carrier)
122155

123156
if parent_ctx:
124157
with self.tracer.start_as_current_span(
125-
MCPSpanNames.SPAN_MCP_SERVER, kind=trace.SpanKind.SERVER, context=parent_ctx
126-
) as mcp_server_span:
127-
self._set_mcp_server_attributes(mcp_server_span, req)
128-
return await wrapped(*args, **kwargs)
158+
"span.mcp.server", kind=trace.SpanKind.SERVER, context=parent_ctx
159+
) as server_span:
129160

130-
@staticmethod
131-
def _set_mcp_client_attributes(span: trace.Span, request: Any) -> None:
132-
import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
161+
self._configure_mcp_span(server_span, incoming_req, request_id)
133162

134-
if isinstance(request, types.ListToolsRequest):
135-
span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST)
136-
if isinstance(request, types.CallToolRequest):
137-
tool_name = request.params.name
138-
tool_arguments = request.params.arguments
139-
if tool_arguments:
140-
for arg_name, arg_val in tool_arguments.items():
141-
span.set_attribute(f"{MCP_REQUEST_ARGUMENT}.{arg_name}", McpInstrumentor.serialize(arg_val))
142-
span.update_name(f"{TOOLS_CALL} {tool_name}")
143-
span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL)
144-
span.set_attribute(MCPAttributes.MCP_TOOL_NAME, tool_name)
145-
if isinstance(request, types.InitializeRequest):
146-
span.set_attribute(MCP_METHOD_NAME, CLIENT_INITIALIZED)
163+
try:
164+
result = await wrapped(*args, **kwargs)
165+
server_span.set_status(Status(StatusCode.OK))
166+
return result
167+
except Exception as e:
168+
server_span.set_status(Status(StatusCode.ERROR, str(e)))
169+
server_span.record_exception(e)
170+
raise
147171

148172
@staticmethod
149-
def _set_mcp_server_attributes(span: trace.Span, request: Any) -> None:
173+
def _configure_mcp_span(span: trace.Span, request, request_id: Optional[str]) -> None:
150174
import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
151175

176+
if hasattr(request, "root"):
177+
request = request.root
178+
179+
if request_id:
180+
span.set_attribute(MCPSpanAttributes.MCP_REQUEST_ID, request_id)
181+
152182
if isinstance(request, types.ListToolsRequest):
153-
span.set_attribute(MCP_METHOD_NAME, TOOLS_LIST)
183+
span.update_name(MCPMethodNameValue.TOOLS_LIST)
184+
span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.TOOLS_LIST)
185+
return
186+
154187
if isinstance(request, types.CallToolRequest):
155188
tool_name = request.params.name
156-
span.update_name(f"{TOOLS_CALL} {tool_name}")
157-
span.set_attribute(MCP_METHOD_NAME, TOOLS_CALL)
158-
span.set_attribute(MCPAttributes.MCP_TOOL_NAME, tool_name)
159189

190+
span.update_name(f"{MCPMethodNameValue.TOOLS_CALL} {request.params.name}")
191+
span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.TOOLS_CALL)
192+
span.set_attribute(MCPSpanAttributes.MCP_TOOL_NAME, tool_name)
160193

161-
@staticmethod
194+
if hasattr(request.params, "arguments"):
195+
for arg_name, arg_val in request.params.arguments.items():
196+
span.set_attribute(
197+
f"{MCPSpanAttributes.MCP_REQUEST_ARGUMENT}.{arg_name}", McpInstrumentor.serialize(arg_val)
198+
)
199+
200+
if isinstance(request, types.InitializeRequest):
201+
span.update_name(MCPMethodNameValue.INITIALIZED)
202+
span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.INITIALIZED)
203+
204+
if isinstance(request, types.CancelledNotification):
205+
span.update_name(MCPMethodNameValue.NOTIFICATIONS_CANCELLED)
206+
span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.NOTIFICATIONS_CANCELLED)
207+
208+
if isinstance(request, types.ToolListChangedNotification):
209+
span.update_name(MCPMethodNameValue.NOTIFICATIONS_CANCELLED)
210+
span.set_attribute(MCPSpanAttributes.MCP_METHOD_NAME, MCPMethodNameValue.NOTIFICATIONS_CANCELLED)
211+
212+
@staticmethod
162213
def serialize(args):
163214
try:
164215
return json.dumps(args)
165216
except Exception:
166217
return str(args)
167-
168-

0 commit comments

Comments
 (0)