Skip to content

Commit 5c1aeec

Browse files
committed
add further trace propagation for responses
1 parent aab4e7d commit 5c1aeec

File tree

1 file changed

+74
-25
lines changed
  • aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp

1 file changed

+74
-25
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/instrumentation/mcp/instrumentation.py

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
from dataclasses import dataclass
43
import json
5-
from typing import Any, AsyncGenerator, Callable, Collection, Dict, Optional, Tuple, Union, cast
4+
from typing import Any, Callable, Collection, Dict, Optional, Tuple, Union
65

76
from wrapt import register_post_import_hook, wrap_function_wrapper
87

98
from opentelemetry import trace
10-
from opentelemetry.trace import SpanKind, Status, StatusCode
119
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
1210
from opentelemetry.instrumentation.utils import unwrap
13-
from opentelemetry.semconv.trace import SpanAttributes
1411
from opentelemetry.propagate import get_global_textmap
15-
16-
from .version import __version__
12+
from opentelemetry.semconv.attributes.network_attributes import NetworkTransportValues
13+
from opentelemetry.semconv.trace import SpanAttributes
14+
from opentelemetry.trace import INVALID_SPAN, SpanKind, Status, StatusCode
1715

1816
from .semconv import (
19-
MCPSpanAttributes,
2017
MCPMethodValue,
18+
MCPSpanAttributes,
2119
)
20+
from .version import __version__
2221

2322

2423
class McpInstrumentor(BaseInstrumentor):
@@ -29,6 +28,7 @@ class McpInstrumentor(BaseInstrumentor):
2928
_DEFAULT_CLIENT_SPAN_NAME = "span.mcp.client"
3029
_DEFAULT_SERVER_SPAN_NAME = "span.mcp.server"
3130
_MCP_SESSION_ID_HEADER = "mcp-session-id"
31+
_MCP_META_FIELD = "_meta"
3232

3333
def __init__(self, **kwargs):
3434
super().__init__()
@@ -55,6 +55,14 @@ def _instrument(self, **kwargs: Any) -> None:
5555
),
5656
"mcp.shared.session",
5757
)
58+
register_post_import_hook(
59+
lambda _: wrap_function_wrapper(
60+
"mcp.shared.session",
61+
"BaseSession._send_response",
62+
self._wrap_send_response,
63+
),
64+
"mcp.shared.session",
65+
)
5866
register_post_import_hook(
5967
lambda _: wrap_function_wrapper(
6068
"mcp.server.lowlevel.server",
@@ -75,6 +83,7 @@ def _instrument(self, **kwargs: Any) -> None:
7583
def _uninstrument(self, **kwargs: Any) -> None:
7684
unwrap("mcp.shared.session", "BaseSession.send_request")
7785
unwrap("mcp.shared.session", "BaseSession.send_notification")
86+
unwrap("mcp.shared.session", "BaseSession._send_response")
7887
unwrap("mcp.server.lowlevel.server", "Server._handle_request")
7988
unwrap("mcp.server.lowlevel.server", "Server._handle_notification")
8089

@@ -98,7 +107,7 @@ def _wrap_session_send(
98107
args: Positional arguments passed to the original send_request/send_notification method
99108
kwargs: Keyword arguments passed to the original send_request/send_notification method
100109
"""
101-
from mcp.types import ClientRequest, ClientNotification, ServerRequest, ServerNotification
110+
from mcp.types import ClientNotification, ClientRequest, ServerNotification, ServerRequest
102111

103112
async def async_wrapper():
104113
message: Optional[Union[ClientRequest, ClientNotification, ServerRequest, ServerNotification]] = (
@@ -120,14 +129,18 @@ async def async_wrapper():
120129

121130
if "params" not in message_json:
122131
message_json["params"] = {}
123-
if "_meta" not in message_json["params"]:
124-
message_json["params"]["_meta"] = {}
132+
if self._MCP_META_FIELD not in message_json["params"]:
133+
message_json["params"][self._MCP_META_FIELD] = {}
134+
135+
parent_ctx = None
136+
if message_json["params"][self._MCP_META_FIELD]:
137+
parent_ctx = self.propagators.extract(message_json["params"][self._MCP_META_FIELD])
125138

126-
with self.tracer.start_as_current_span(name=span_name, kind=span_kind) as span:
139+
with self.tracer.start_as_current_span(name=span_name, kind=span_kind, context=parent_ctx) as span:
127140
ctx = trace.set_span_in_context(span)
128141
carrier = {}
129142
self.propagators.inject(carrier=carrier, context=ctx)
130-
message_json["params"]["_meta"].update(carrier)
143+
message_json["params"][self._MCP_META_FIELD].update(carrier)
131144

132145
McpInstrumentor._generate_mcp_message_attrs(span, message, request_id)
133146

@@ -227,11 +240,14 @@ async def _wrap_server_message_handler(
227240
with self.tracer.start_as_current_span(
228241
self._DEFAULT_SERVER_SPAN_NAME, kind=SpanKind.SERVER, context=parent_ctx
229242
) as server_span:
230-
231-
# Extract session ID if available
243+
244+
# session_id only exits if the transport protocol is Streamable HTTP
232245
session_id = self._extract_session_id(args)
233246
if session_id:
234247
server_span.set_attribute(MCPSpanAttributes.MCP_SESSION_ID, session_id)
248+
server_span.set_attribute(SpanAttributes.NETWORK_TRANSPORT, NetworkTransportValues.PIPE.value)
249+
else:
250+
server_span.set_attribute(SpanAttributes.NETWORK_TRANSPORT, NetworkTransportValues.TCP.value)
235251

236252
self._generate_mcp_message_attrs(server_span, incoming_msg, request_id)
237253

@@ -243,27 +259,60 @@ async def _wrap_server_message_handler(
243259
server_span.set_status(Status(StatusCode.ERROR, str(e)))
244260
server_span.record_exception(e)
245261
raise
246-
262+
263+
async def _wrap_send_response(
264+
self, wrapped: Callable, instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
265+
) -> Any:
266+
"""
267+
Instruments BaseSession._send_response to propagate trace context into response.
268+
269+
Note: we do not need to generate another span for the reponse as it falls under
270+
the _wrap_server_message_handler
271+
"""
272+
response = args[1] if len(args) > 1 else kwargs.get("response", None)
273+
274+
if not response:
275+
return await wrapped(*args, **kwargs)
276+
277+
current_span = trace.get_current_span()
278+
if current_span is not INVALID_SPAN:
279+
# Inject trace context into response
280+
carrier = {}
281+
self.propagators.inject(carrier=carrier, context=trace.set_span_in_context(current_span))
282+
283+
response_json = response.model_dump(by_alias=True, mode="json", exclude_none=True)
284+
285+
if self._MCP_META_FIELD not in response_json:
286+
response_json[self._MCP_META_FIELD] = {}
287+
response_json[self._MCP_META_FIELD].update(carrier)
288+
289+
modified_response = response.model_validate(response_json)
290+
291+
if len(args) > 1:
292+
args = args[:1] + (modified_response,) + args[2:]
293+
else:
294+
kwargs["response"] = modified_response
295+
296+
return await wrapped(*args, **kwargs)
297+
247298
def _extract_session_id(self, args: Tuple[Any, ...]) -> Optional[str]:
248299
"""
249300
Extract session ID from server method arguments.
250301
"""
251302
try:
252-
from mcp.shared.session import RequestResponder # pylint: disable=import-outside-toplevel
253-
from mcp.shared.message import ServerMessageMetadata # pylint: disable=import-outside-toplevel
254-
255303
message = args[0]
256-
if isinstance(message, RequestResponder):
257-
if message.message_metadata and isinstance(message.message_metadata, ServerMessageMetadata):
258-
request_context = message.message_metadata.request_context
259-
if request_context:
260-
headers = getattr(request_context, 'headers', None)
304+
if hasattr(message, "message_metadata"):
305+
message_metadata = message.message_metadata
306+
if message.message_metadata and hasattr(message_metadata, "request_context"):
307+
request_context = message_metadata.request_context
308+
if request_context and hasattr(request_context, "headers"):
309+
headers = request_context.headers
261310
if headers:
262311
return headers.get(self._MCP_SESSION_ID_HEADER)
263312
return None
264313
except Exception:
265314
return None
266-
315+
267316
@staticmethod
268317
def _generate_mcp_message_attrs(span: trace.Span, message, request_id: Optional[int]) -> None:
269318
import mcp.types as types # pylint: disable=import-outside-toplevel,consider-using-from-import
@@ -326,4 +375,4 @@ def serialize(args: dict[str, Any]) -> str:
326375
try:
327376
return json.dumps(args)
328377
except Exception:
329-
return ""
378+
return "unknown_args"

0 commit comments

Comments
 (0)