Skip to content

Commit a86f8f4

Browse files
committed
propagate request context though transport
1 parent 6e418e6 commit a86f8f4

File tree

8 files changed

+246
-8
lines changed

8 files changed

+246
-8
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
124124
def lifespan_wrapper(
125125
app: FastMCP,
126126
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
127-
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
127+
) -> Callable[
128+
[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]
129+
]:
128130
@asynccontextmanager
129131
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
130132
async with lifespan(app) as context:
@@ -933,7 +935,8 @@ def my_tool(x: int, ctx: Context) -> str:
933935
def __init__(
934936
self,
935937
*,
936-
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
938+
request_context: RequestContext[ServerSessionT, LifespanContextT]
939+
| None = None,
937940
fastmcp: FastMCP | None = None,
938941
**kwargs: Any,
939942
):
@@ -949,7 +952,9 @@ def fastmcp(self) -> FastMCP:
949952
return self._fastmcp
950953

951954
@property
952-
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
955+
def request_context(
956+
self,
957+
) -> RequestContext[ServerSessionT, LifespanContextT]:
953958
"""Access to the underlying request context."""
954959
if self._request_context is None:
955960
raise ValueError("Context is not available outside of a request")

src/mcp/server/lowlevel/server.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ async def main():
8585
from mcp.server.stdio import stdio_server as stdio_server
8686
from mcp.shared.context import RequestContext
8787
from mcp.shared.exceptions import McpError
88-
from mcp.shared.message import SessionMessage
88+
from mcp.shared.message import ServerMessageMetadata, SessionMessage
8989
from mcp.shared.session import RequestResponder
9090

9191
logger = logging.getLogger(__name__)
@@ -215,7 +215,9 @@ def get_capabilities(
215215
)
216216

217217
@property
218-
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
218+
def request_context(
219+
self,
220+
) -> RequestContext[ServerSession, LifespanResultT]:
219221
"""If called outside of a request context, this will raise a LookupError."""
220222
return request_ctx.get()
221223

@@ -555,6 +557,15 @@ async def _handle_request(
555557

556558
token = None
557559
try:
560+
# Extract request context from message metadata
561+
request_data = None
562+
if (
563+
hasattr(message, "message_metadata")
564+
and message.message_metadata
565+
and isinstance(message.message_metadata, ServerMessageMetadata)
566+
):
567+
request_data = message.message_metadata.request_context
568+
558569
# Set our global state that can be retrieved via
559570
# app.get_request_context()
560571
token = request_ctx.set(
@@ -563,6 +574,7 @@ async def _handle_request(
563574
message.request_meta,
564575
session,
565576
lifespan_context,
577+
request=request_data,
566578
)
567579
)
568580
response = await handler(req)

src/mcp/server/sse.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ async def handle_sse(request):
5252
from starlette.types import Receive, Scope, Send
5353

5454
import mcp.types as types
55-
from mcp.shared.message import SessionMessage
55+
from mcp.shared.context import RequestData
56+
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5657

5758
logger = logging.getLogger(__name__)
5859

@@ -203,7 +204,19 @@ async def handle_post_message(
203204
await writer.send(err)
204205
return
205206

206-
session_message = SessionMessage(message)
207+
# Extract request headers and other context
208+
request_context: RequestData = {
209+
"headers": dict(request.headers),
210+
"method": request.method,
211+
"url": str(request.url),
212+
"client": request.client,
213+
"path_params": request.path_params,
214+
"query_params": dict(request.query_params),
215+
}
216+
217+
# Create session message with request context
218+
metadata = ServerMessageMetadata(request_context=request_context)
219+
session_message = SessionMessage(message, metadata=metadata)
207220
logger.debug(f"Sending session message to writer: {session_message}")
208221
response = Response("Accepted", status_code=202)
209222
await response(scope, receive, send)

src/mcp/shared/context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
1010
LifespanContextT = TypeVar("LifespanContextT")
1111

12+
# Type alias for request-specific data (e.g., headers, auth info)
13+
RequestData = dict[str, Any]
14+
1215

1316
@dataclass
1417
class RequestContext(Generic[SessionT, LifespanContextT]):
1518
request_id: RequestId
1619
meta: RequestParams.Meta | None
1720
session: SessionT
1821
lifespan_context: LifespanContextT
22+
request: RequestData | None = None

src/mcp/shared/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def create_client_server_memory_streams() -> (
6060

6161
@asynccontextmanager
6262
async def create_connected_server_and_client_session(
63-
server: Server[Any],
63+
server: Server[Any, Any],
6464
read_timeout_seconds: timedelta | None = None,
6565
sampling_callback: SamplingFnT | None = None,
6666
list_roots_callback: ListRootsFnT | None = None,

src/mcp/shared/message.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Awaitable, Callable
99
from dataclasses import dataclass
1010

11+
from mcp.shared.context import RequestData
1112
from mcp.types import JSONRPCMessage, RequestId
1213

1314
ResumptionToken = str
@@ -30,6 +31,8 @@ class ServerMessageMetadata:
3031
"""Metadata specific to server messages."""
3132

3233
related_request_id: RequestId | None = None
34+
# Request-specific context (e.g., headers, auth info)
35+
request_context: RequestData | None = None
3336

3437

3538
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None

src/mcp/shared/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,12 @@ def __init__(
8080
ReceiveNotificationT
8181
]""",
8282
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
83+
message_metadata: MessageMetadata = None,
8384
) -> None:
8485
self.request_id = request_id
8586
self.request_meta = request_meta
8687
self.request = request
88+
self.message_metadata = message_metadata
8789
self._session = session
8890
self._completed = False
8991
self._cancel_scope = anyio.CancelScope()
@@ -364,6 +366,7 @@ async def _receive_loop(self) -> None:
364366
request=validated_request,
365367
session=self,
366368
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
369+
message_metadata=message.metadata,
367370
)
368371

369372
self._in_flight[responder.request_id] = responder

tests/shared/test_sse.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import multiprocessing
23
import socket
34
import time
@@ -318,3 +319,200 @@ async def test_sse_client_basic_connection_mounted_app(
318319
# Test ping
319320
ping_result = await session.send_ping()
320321
assert isinstance(ping_result, EmptyResult)
322+
323+
324+
# Test server with request context that returns headers in the response
325+
class RequestContextServer(Server):
326+
def __init__(self):
327+
super().__init__("request_context_server")
328+
329+
@self.call_tool()
330+
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
331+
# Capture request context if available and return it
332+
headers_info = {}
333+
try:
334+
context = self.request_context
335+
if context.request:
336+
headers_info = context.request.get("headers", {})
337+
except LookupError:
338+
pass # No request context available
339+
340+
if name == "echo_headers":
341+
# Return the headers as JSON in the response
342+
import json
343+
344+
return [TextContent(type="text", text=json.dumps(headers_info))]
345+
elif name == "echo_context":
346+
# Return context info with request ID
347+
import json
348+
349+
context_data = {
350+
"request_id": args.get("request_id"),
351+
"headers": headers_info,
352+
}
353+
return [TextContent(type="text", text=json.dumps(context_data))]
354+
355+
return [TextContent(type="text", text=f"Called {name}")]
356+
357+
@self.list_tools()
358+
async def handle_list_tools() -> list[Tool]:
359+
return [
360+
Tool(
361+
name="echo_headers",
362+
description="Echoes request headers",
363+
inputSchema={"type": "object", "properties": {}},
364+
),
365+
Tool(
366+
name="echo_context",
367+
description="Echoes request context",
368+
inputSchema={
369+
"type": "object",
370+
"properties": {"request_id": {"type": "string"}},
371+
"required": ["request_id"],
372+
},
373+
),
374+
]
375+
376+
377+
def run_context_server(server_port: int) -> None:
378+
"""Run a server that captures request context"""
379+
sse = SseServerTransport("/messages/")
380+
context_server = RequestContextServer()
381+
382+
async def handle_sse(request: Request) -> Response:
383+
async with sse.connect_sse(
384+
request.scope, request.receive, request._send
385+
) as streams:
386+
await context_server.run(
387+
streams[0], streams[1], context_server.create_initialization_options()
388+
)
389+
return Response()
390+
391+
app = Starlette(
392+
routes=[
393+
Route("/sse", endpoint=handle_sse),
394+
Mount("/messages/", app=sse.handle_post_message),
395+
]
396+
)
397+
398+
server = uvicorn.Server(
399+
config=uvicorn.Config(
400+
app=app, host="127.0.0.1", port=server_port, log_level="error"
401+
)
402+
)
403+
print(f"starting context server on {server_port}")
404+
server.run()
405+
406+
407+
@pytest.fixture()
408+
def context_server(server_port: int) -> Generator[None, None, None]:
409+
"""Fixture that provides a server with request context capture"""
410+
proc = multiprocessing.Process(
411+
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
412+
)
413+
print("starting context server process")
414+
proc.start()
415+
416+
# Wait for server to be running
417+
max_attempts = 20
418+
attempt = 0
419+
print("waiting for context server to start")
420+
while attempt < max_attempts:
421+
try:
422+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
423+
s.connect(("127.0.0.1", server_port))
424+
break
425+
except ConnectionRefusedError:
426+
time.sleep(0.1)
427+
attempt += 1
428+
else:
429+
raise RuntimeError(
430+
f"Context server failed to start after {max_attempts} attempts"
431+
)
432+
433+
yield
434+
435+
print("killing context server")
436+
proc.kill()
437+
proc.join(timeout=2)
438+
if proc.is_alive():
439+
print("context server process failed to terminate")
440+
441+
442+
@pytest.mark.anyio
443+
async def test_request_context_propagation(
444+
context_server: None, server_url: str
445+
) -> None:
446+
"""Test that request context is properly propagated through SSE transport."""
447+
# Test with custom headers
448+
custom_headers = {
449+
"Authorization": "Bearer test-token",
450+
"X-Custom-Header": "test-value",
451+
"X-Trace-Id": "trace-123",
452+
}
453+
454+
async with sse_client(server_url + "/sse", headers=custom_headers) as (
455+
read_stream,
456+
write_stream,
457+
):
458+
async with ClientSession(read_stream, write_stream) as session:
459+
# Initialize the session
460+
result = await session.initialize()
461+
assert isinstance(result, InitializeResult)
462+
463+
# Call the tool that echoes headers back
464+
tool_result = await session.call_tool("echo_headers", {})
465+
466+
# Parse the JSON response
467+
468+
assert len(tool_result.content) == 1
469+
headers_data = json.loads(
470+
tool_result.content[0].text
471+
if tool_result.content[0].type == "text"
472+
else "{}"
473+
)
474+
475+
# Verify headers were propagated
476+
assert headers_data.get("authorization") == "Bearer test-token"
477+
assert headers_data.get("x-custom-header") == "test-value"
478+
assert headers_data.get("x-trace-id") == "trace-123"
479+
480+
481+
@pytest.mark.anyio
482+
async def test_request_context_isolation(context_server: None, server_url: str) -> None:
483+
"""Test that request contexts are isolated between different SSE clients."""
484+
contexts = []
485+
486+
# Create multiple clients with different headers
487+
for i in range(3):
488+
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
489+
490+
async with sse_client(server_url + "/sse", headers=headers) as (
491+
read_stream,
492+
write_stream,
493+
):
494+
async with ClientSession(read_stream, write_stream) as session:
495+
await session.initialize()
496+
497+
# Call the tool that echoes context
498+
tool_result = await session.call_tool(
499+
"echo_context", {"request_id": f"request-{i}"}
500+
)
501+
502+
# Parse and store the result
503+
import json
504+
505+
assert len(tool_result.content) == 1
506+
context_data = json.loads(
507+
tool_result.content[0].text
508+
if tool_result.content[0].type == "text"
509+
else "{}"
510+
)
511+
contexts.append(context_data)
512+
513+
# Verify each request had its own context
514+
assert len(contexts) == 3
515+
for i, ctx in enumerate(contexts):
516+
assert ctx["request_id"] == f"request-{i}"
517+
assert ctx["headers"].get("x-request-id") == f"request-{i}"
518+
assert ctx["headers"].get("x-custom-value") == f"value-{i}"

0 commit comments

Comments
 (0)