Skip to content

Commit e0e24e1

Browse files
committed
fix: retrieval of session ID during a request
1 parent ae9d325 commit e0e24e1

File tree

3 files changed

+69
-69
lines changed

3 files changed

+69
-69
lines changed

sentry_sdk/integrations/mcp/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,9 @@ def setup_once():
2727
Patches MCP server classes to instrument handler execution.
2828
"""
2929
from sentry_sdk.integrations.mcp.lowlevel import patch_lowlevel_server
30-
from sentry_sdk.integrations.mcp.transport import (
31-
patch_streamable_http_transport,
32-
)
3330

3431
# Patch server classes to instrument handlers
3532
patch_lowlevel_server()
3633

37-
# Patch HTTP transport to track session IDs
38-
patch_streamable_http_transport()
39-
4034

4135
__all__ = ["MCPIntegration"]

sentry_sdk/integrations/mcp/lowlevel.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sentry_sdk.integrations.mcp import MCPIntegration
1313
from sentry_sdk.integrations.mcp.transport import (
1414
detect_mcp_transport_from_context,
15-
mcp_session_id_ctx,
15+
get_session_id_from_context,
1616
)
1717
from sentry_sdk.utils import safe_serialize
1818

@@ -24,6 +24,28 @@
2424
from typing import Any, Callable
2525

2626

27+
def _get_request_context_data():
28+
# type: () -> tuple[str | None, str | None]
29+
"""
30+
Extract request ID and session ID from the MCP request context.
31+
32+
Returns:
33+
Tuple of (request_id, session_id). Either value may be None if not available.
34+
"""
35+
request_id = None # type: str | None
36+
session_id = None # type: str | None
37+
38+
try:
39+
ctx = request_ctx.get()
40+
request_id = ctx.request_id
41+
session_id = get_session_id_from_context(ctx)
42+
except LookupError:
43+
# No request context available
44+
pass
45+
46+
return request_id, session_id
47+
48+
2749
def _get_span_config(handler_type, item_name):
2850
# type: (str, str) -> tuple[str, str, str, str | None]
2951
"""
@@ -51,9 +73,15 @@ def _get_span_config(handler_type, item_name):
5173

5274

5375
def _set_span_input_data(
54-
span, handler_name, span_data_key, mcp_method_name, arguments, request_id=None
76+
span,
77+
handler_name,
78+
span_data_key,
79+
mcp_method_name,
80+
arguments,
81+
request_id=None,
82+
session_id=None,
5583
):
56-
# type: (Any, str, str, str, dict[str, Any], str | None) -> None
84+
# type: (Any, str, str, str, dict[str, Any], str | None, str | None) -> None
5785
"""Set input span data for MCP handlers."""
5886
# Set handler identifier
5987
span.set_data(span_data_key, handler_name)
@@ -69,19 +97,14 @@ def _set_span_input_data(
6997
# No request context available - likely stdio
7098
span.set_data(SPANDATA.MCP_TRANSPORT, "pipe")
7199

72-
# Extract session ID from context variable if available (HTTP transport)
73-
try:
74-
session_id = mcp_session_id_ctx.get()
75-
if session_id:
76-
span.set_data(SPANDATA.MCP_SESSION_ID, session_id)
77-
except Exception:
78-
# Session ID not available or transport module not imported
79-
pass
80-
81100
# Set request_id if provided
82101
if request_id:
83102
span.set_data(SPANDATA.MCP_REQUEST_ID, request_id)
84103

104+
# Set session_id if provided
105+
if session_id:
106+
span.set_data(SPANDATA.MCP_SESSION_ID, session_id)
107+
85108
# Set request arguments (excluding common request context objects)
86109
for k, v in arguments.items():
87110
span.set_data(f"mcp.request.argument.{k}", safe_serialize(v))
@@ -262,13 +285,8 @@ async def _async_handler_wrapper(handler_type, func, original_args):
262285
name=span_name,
263286
origin=MCPIntegration.origin,
264287
) as span:
265-
# Get request ID from context
266-
request_id = None
267-
try:
268-
ctx = request_ctx.get()
269-
request_id = ctx.request_id
270-
except LookupError:
271-
pass
288+
# Get request ID and session ID from context
289+
request_id, session_id = _get_request_context_data()
272290

273291
# Set input span data
274292
_set_span_input_data(
@@ -278,6 +296,7 @@ async def _async_handler_wrapper(handler_type, func, original_args):
278296
mcp_method_name,
279297
arguments,
280298
request_id,
299+
session_id,
281300
)
282301

283302
# For resources, extract and set protocol
@@ -329,13 +348,8 @@ def _sync_handler_wrapper(handler_type, func, original_args):
329348
name=span_name,
330349
origin=MCPIntegration.origin,
331350
) as span:
332-
# Get request ID from context
333-
request_id = None
334-
try:
335-
ctx = request_ctx.get()
336-
request_id = ctx.request_id
337-
except LookupError:
338-
pass
351+
# Get request ID and session ID from context
352+
request_id, session_id = _get_request_context_data()
339353

340354
# Set input span data
341355
_set_span_input_data(
@@ -345,6 +359,7 @@ def _sync_handler_wrapper(handler_type, func, original_args):
345359
mcp_method_name,
346360
arguments,
347361
request_id,
362+
session_id,
348363
)
349364

350365
# For resources, extract and set protocol

sentry_sdk/integrations/mcp/transport.py

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,15 @@
99
- If there's an HTTP request context (SSE/WebSocket), transport is "tcp"
1010
- If there's no request context (stdio), transport is "pipe"
1111
12-
Session ID tracking is done by patching the StreamableHTTPServerTransport to store
13-
the session ID in a context variable that can be accessed during handler execution.
12+
Session ID tracking is done by patching Server._run_request_handler to store a
13+
reference to the server instance in the request context. The session ID can then
14+
be retrieved from the server's transport when needed during handler execution.
1415
"""
1516

16-
import contextvars
1717
from typing import TYPE_CHECKING
1818

1919
if TYPE_CHECKING:
2020
from typing import Optional, Any
21-
from starlette.types import Receive, Scope, Send
22-
23-
24-
# Context variable to store the current MCP session ID
25-
mcp_session_id_ctx = contextvars.ContextVar("mcp_session_id", default=None) # type: contextvars.ContextVar[Optional[str]]
2621

2722

2823
def detect_mcp_transport_from_context(request_ctx):
@@ -58,36 +53,32 @@ def detect_mcp_transport_from_context(request_ctx):
5853
return None
5954

6055

61-
def patch_streamable_http_transport():
62-
# type: () -> None
56+
def get_session_id_from_context(request_ctx):
57+
# type: (Any) -> Optional[str]
6358
"""
64-
Patches the StreamableHTTPServerTransport to store session IDs in context.
59+
Extract session ID from the request context.
60+
61+
The session ID is sent by the client in the MCP-Session-Id header and is
62+
available in the Starlette Request object stored in ctx.request.
6563
66-
This allows handler code to access the session ID via the mcp_session_id_ctx
67-
context variable, regardless of whether it's the first request (where the
68-
session ID hasn't been sent to the client yet) or a subsequent request.
64+
Args:
65+
request_ctx: The MCP request context object
66+
67+
Returns:
68+
Session ID string if available, None otherwise
6969
"""
7070
try:
71-
from mcp.server.streamable_http import StreamableHTTPServerTransport
72-
except ImportError:
73-
# StreamableHTTP transport not available
74-
return
75-
76-
original_handle_request = StreamableHTTPServerTransport.handle_request
77-
78-
async def patched_handle_request(self, scope, receive, send):
79-
# type: (Any, Scope, Receive, Send) -> None
80-
"""Wrap handle_request to set session ID in context."""
81-
# Store session ID in context variable before handling request
82-
token = None
83-
if hasattr(self, "mcp_session_id") and self.mcp_session_id:
84-
token = mcp_session_id_ctx.set(self.mcp_session_id)
85-
86-
try:
87-
await original_handle_request(self, scope, receive, send)
88-
finally:
89-
# Reset context after request
90-
if token is not None:
91-
mcp_session_id_ctx.reset(token)
92-
93-
StreamableHTTPServerTransport.handle_request = patched_handle_request # type: ignore
71+
# The Starlette Request object is stored in ctx.request
72+
if hasattr(request_ctx, "request") and request_ctx.request is not None:
73+
request = request_ctx.request
74+
75+
# Check if it's a Starlette Request with headers
76+
if hasattr(request, "headers"):
77+
# The session ID is sent in the mcp-session-id header
78+
# MCP_SESSION_ID_HEADER = "mcp-session-id"
79+
return request.headers.get("mcp-session-id")
80+
81+
except Exception:
82+
pass
83+
84+
return None

0 commit comments

Comments
 (0)