Skip to content

Commit de1020e

Browse files
committed
Revert "add further trace propagation for responses"
This reverts commit 5c1aeec.
1 parent 5c1aeec commit de1020e

File tree

1 file changed

+25
-74
lines changed
  • aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp

1 file changed

+25
-74
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py

Lines changed: 25 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
from dataclasses import dataclass
34
import json
4-
from typing import Any, Callable, Collection, Dict, Optional, Tuple, Union
5+
from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, Union, cast
56

67
from wrapt import register_post_import_hook, wrap_function_wrapper
78

89
from opentelemetry import trace
10+
from opentelemetry.trace import SpanKind, Status, StatusCode
911
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
1012
from opentelemetry.instrumentation.utils import unwrap
11-
from opentelemetry.propagate import get_global_textmap
12-
from opentelemetry.semconv.attributes.network_attributes import NetworkTransportValues
1313
from opentelemetry.semconv.trace import SpanAttributes
14-
from opentelemetry.trace import INVALID_SPAN, SpanKind, Status, StatusCode
14+
from opentelemetry.propagate import get_global_textmap
15+
16+
from .version import __version__
1517

1618
from .semconv import (
17-
MCPMethodValue,
1819
MCPSpanAttributes,
20+
MCPMethodValue,
1921
)
20-
from .version import __version__
2122

2223

2324
class McpInstrumentor(BaseInstrumentor):
@@ -28,7 +29,6 @@ class McpInstrumentor(BaseInstrumentor):
2829
_DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client"
2930
_DEFAULT_SERVER_SPAN_NAME = "span.mcp.server"
3031
_MCP_SESSION_ID_HEADER = "mcp-session-id"
31-
_MCP_META_FIELD = "_meta"
3232

3333
def __init__(self, **kwargs):
3434
super().__init__()
@@ -55,14 +55,6 @@ def _instrument(self, **kwargs: Any) -> None:
5555
),
5656
"mcp.shared.session",
5757
)
58-
register_post_import_hook(
59-
lambda _: wrap_function_wrapper(
60-
"mcp.shared.session",
61-
"BaseSession._send_response",
62-
self._wrap_send_response,
63-
),
64-
"mcp.shared.session",
65-
)
6658
register_post_import_hook(
6759
lambda _: wrap_function_wrapper(
6860
"mcp.server.lowlevel.server",
@@ -83,7 +75,6 @@ def _instrument(self, **kwargs: Any) -> None:
8375
def _uninstrument(self, **kwargs: Any) -> None:
8476
unwrap("mcp.shared.session", "BaseSession.send_request")
8577
unwrap("mcp.shared.session", "BaseSession.send_notification")
86-
unwrap("mcp.shared.session", "BaseSession._send_response")
8778
unwrap("mcp.server.lowlevel.server", "Server._handle_request")
8879
unwrap("mcp.server.lowlevel.server", "Server._handle_notification")
8980

@@ -107,7 +98,7 @@ def _wrap_session_send(
10798
args: Positional arguments passed to the original send_request/send_notification method
10899
kwargs: Keyword arguments passed to the original send_request/send_notification method
109100
"""
110-
from mcp.types import ClientNotification, ClientRequest, ServerNotification, ServerRequest
101+
from mcp.types import ClientRequest, ClientNotification, ServerRequest, ServerNotification
111102

112103
async def async_wrapper():
113104
message: Optional[Union[ClientRequest, ClientNotification, ServerRequest, ServerNotification]] = (
@@ -129,18 +120,14 @@ async def async_wrapper():
129120

130121
if "params" not in message_json:
131122
message_json["params"] = {}
132-
if self._MCP_META_FIELD not in message_json["params"]:
133-
message_json["params"][self._MCP_META_FIELD] = {}
134-
135-
parent_ctx = None
136-
if message_json["params"][self._MCP_META_FIELD]:
137-
parent_ctx = self.propagators.extract(message_json["params"][self._MCP_META_FIELD])
123+
if "_meta" not in message_json["params"]:
124+
message_json["params"]["_meta"] = {}
138125

139-
with self.tracer.start_as_current_span(name=span_name, kind=span_kind, context=parent_ctx) as span:
126+
with self.tracer.start_as_current_span(name=span_name, kind=span_kind) as span:
140127
ctx = trace.set_span_in_context(span)
141128
carrier = {}
142129
self.propagators.inject(carrier=carrier, context=ctx)
143-
message_json["params"][self._MCP_META_FIELD].update(carrier)
130+
message_json["params"]["_meta"].update(carrier)
144131

145132
McpInstrumentor._generate_mcp_message_attrs(span, message, request_id)
146133

@@ -240,14 +227,11 @@ async def _wrap_server_message_handler(
240227
with self.tracer.start_as_current_span(
241228
self._DEFAULT_SERVER_SPAN_NAME, kind=SpanKind.SERVER, context=parent_ctx
242229
) as server_span:
243-
244-
# session_id only exits if the transport protocol is Streamable HTTP
230+
231+
# Extract session ID if available
245232
session_id = self._extract_session_id(args)
246233
if session_id:
247234
server_span.set_attribute(MCPSpanAttributes.MCP_SESSION_ID, session_id)
248-
server_span.set_attribute(SpanAttributes.NETWORK_TRANSPORT, NetworkTransportValues.PIPE.value)
249-
else:
250-
server_span.set_attribute(SpanAttributes.NETWORK_TRANSPORT, NetworkTransportValues.TCP.value)
251235

252236
self._generate_mcp_message_attrs(server_span, incoming_msg, request_id)
253237

@@ -259,60 +243,27 @@ async def _wrap_server_message_handler(
259243
server_span.set_status(Status(StatusCode.ERROR, str(e)))
260244
server_span.record_exception(e)
261245
raise
262-
263-
async def _wrap_send_response(
264-
self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
265-
) -> Any:
266-
"""
267-
Instruments BaseSession._send_response to propagate trace context into response.
268-
269-
Note: we do not need to generate another span for the reponse as it falls under
270-
the _wrap_server_message_handler
271-
"""
272-
response = args[1] if len(args) > 1 else kwargs.get("response", None)
273-
274-
if not response:
275-
return await wrapped(*args, **kwargs)
276-
277-
current_span = trace.get_current_span()
278-
if current_span is not INVALID_SPAN:
279-
# Inject trace context into response
280-
carrier = {}
281-
self.propagators.inject(carrier=carrier, context=trace.set_span_in_context(current_span))
282-
283-
response_json = response.model_dump(by_alias=True, mode="json", exclude_none=True)
284-
285-
if self._MCP_META_FIELD not in response_json:
286-
response_json[self._MCP_META_FIELD] = {}
287-
response_json[self._MCP_META_FIELD].update(carrier)
288-
289-
modified_response = response.model_validate(response_json)
290-
291-
if len(args) > 1:
292-
args = args[:1] + (modified_response,) + args[2:]
293-
else:
294-
kwargs["response"] = modified_response
295-
296-
return await wrapped(*args, **kwargs)
297-
246+
298247
def _extract_session_id(self, args: Tuple[Any, ...]) -> Optional[str]:
299248
"""
300249
Extract session ID from server method arguments.
301250
"""
302251
try:
252+
from mcp.shared.session import RequestResponder # pylint: disable=import-outside-toplevel
253+
from mcp.shared.message import ServerMessageMetadata # pylint: disable=import-outside-toplevel
254+
303255
message = args[0]
304-
if hasattr(message, "message_metadata"):
305-
message_metadata = message.message_metadata
306-
if message.message_metadata and hasattr(message_metadata, "request_context"):
307-
request_context = message_metadata.request_context
308-
if request_context and hasattr(request_context, "headers"):
309-
headers = request_context.headers
256+
if isinstance(message, RequestResponder):
257+
if message.message_metadata and isinstance(message.message_metadata, ServerMessageMetadata):
258+
request_context = message.message_metadata.request_context
259+
if request_context:
260+
headers = getattr(request_context, 'headers', None)
310261
if headers:
311262
return headers.get(self._MCP_SESSION_ID_HEADER)
312263
return None
313264
except Exception:
314265
return None
315-
266+
316267
@staticmethod
317268
def _generate_mcp_message_attrs(span: trace.Span, message, request_id: Optional[int]) -> None:
318269
import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
@@ -375,4 +326,4 @@ def serialize(args: dict[str, Any]) -> str:
375326
try:
376327
return json.dumps(args)
377328
except Exception:
378-
return "unknown_args"
329+
return ""

0 commit comments

Comments
 (0)