Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def main():
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -215,7 +215,9 @@ def get_capabilities(
)

@property
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
def request_context(
self,
) -> RequestContext[ServerSession, LifespanResultT]:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()

Expand Down Expand Up @@ -555,6 +557,13 @@ async def _handle_request(

token = None
try:
# Extract request context from message metadata
request_data = None
if message.message_metadata is not None and isinstance(
message.message_metadata, ServerMessageMetadata
):
request_data = message.message_metadata.request_context

# Set our global state that can be retrieved via
# app.get_request_context()
token = request_ctx.set(
Expand All @@ -563,6 +572,7 @@ async def _handle_request(
message.request_meta,
session,
lifespan_context,
request=request_data,
)
)
response = await handler(req)
Expand Down
17 changes: 15 additions & 2 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ async def handle_sse(request):
from starlette.types import Receive, Scope, Send

import mcp.types as types
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.types import RequestData

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -203,7 +204,19 @@ async def handle_post_message(
await writer.send(err)
return

session_message = SessionMessage(message)
# Extract request headers and other context
request_context: RequestData = {
"headers": dict(request.headers),
"method": request.method,
"url": str(request.url),
"client": request.client,
"path_params": request.path_params,
"query_params": dict(request.query_params),
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we instead of passing an open dict, just pass the request object? This gives the full power to the consumer. They can still isinstance check to see what specific object they are getting. This way we are a bit more generic than specifiying an open dict.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only concern with that is that we are putting starlette.requests into the core and types. Not sure I'm super happy about that, but equally we use starlette everywhere in the SDK

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should pass the request.scope, which is a dictionary, and contains all the connection information the user will ever need.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latest changes use request, not the request.scope I mentioned tho 👀

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw request.scope only after pushed the update... will update in a few hours

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Kludex would that work for any transport? the problem with the general RequestData is that it's:

  1. Has no clear structure.
  2. Might de-facto depend on HTTP specifics

I was wondering if we can make this a generic and actually pass the origianl request objects. This could be in gRPC context something different than Websocket or HTTP.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly none of the options are great here. I'd love to use generics to make this properly typed, but that would break backwards compatibility for everyone using RequestContext[SessionT, LifespanContextT] - they'd get errors about missing type arguments.

So for now, i'd suggest to go with request.scope as a dict. It's not perfect but:

  • No breaking changes
  • No framework dependencies in core types
  • Gets us all the request data we need

For WebSocket, it'll work the same way - we can pass the WebSocket's ASGI scope which contains initial handshake headers etc

The scope structure is consistent across HTTP and WebSocket per the ASGI spec, so tools can handle both transports uniformly.

Not ideal, but it gets the job done without breaking existing code. I'd love to revisit this in sdk v2 though

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly none of the options are great here. I'd love to use generics to make this properly typed, but that would break backwards compatibility for everyone using RequestContext[SessionT, LifespanContextT] - they'd get errors about missing type arguments.

The way to do for this is to set a default value on the TypeVar - so people don't get errors.

from typing_extensions import TypeVar

T = TypeVar("T", default=...)`

I don't consider type changes as breaking changes btw.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, cool, thanks, didn't realise typing_extensions had TypeVar with default


# Create session message with request context
metadata = ServerMessageMetadata(request_context=request_context)
session_message = SessionMessage(message, metadata=metadata)
logger.debug(f"Sending session message to writer: {session_message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/shared/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing_extensions import TypeVar

from mcp.shared.session import BaseSession
from mcp.types import RequestId, RequestParams
from mcp.types import RequestData, RequestId, RequestParams

SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
Expand All @@ -16,3 +16,4 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
request: RequestData | None = None
4 changes: 3 additions & 1 deletion src/mcp/shared/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Awaitable, Callable
from dataclasses import dataclass

from mcp.types import JSONRPCMessage, RequestId
from mcp.types import JSONRPCMessage, RequestData, RequestId

ResumptionToken = str

Expand All @@ -30,6 +30,8 @@ class ServerMessageMetadata:
"""Metadata specific to server messages."""

related_request_id: RequestId | None = None
# Request-specific context (e.g., headers, auth info)
request_context: RequestData | None = None


MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
Expand Down
3 changes: 3 additions & 0 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ def __init__(
ReceiveNotificationT
]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
message_metadata: MessageMetadata = None,
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self.message_metadata = message_metadata
self._session = session
self._completed = False
self._cancel_scope = anyio.CancelScope()
Expand Down Expand Up @@ -364,6 +366,7 @@ async def _receive_loop(self) -> None:
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)

self._in_flight[responder.request_id] = responder
Expand Down
2 changes: 2 additions & 0 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
Role = Literal["user", "assistant"]
RequestId = str | int
AnyFunction: TypeAlias = Callable[..., Any]
# Http request-specific data (e.g., headers, auth info)
RequestData = dict[str, Any]


class RequestParams(BaseModel):
Expand Down
186 changes: 178 additions & 8 deletions tests/server/fastmcp/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
including with and without authentication.
"""

import json
import multiprocessing
import socket
import time
Expand All @@ -13,13 +14,15 @@
import pytest
import uvicorn
from pydantic import AnyUrl
from starlette.applications import Starlette

import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.resources import FunctionResource
from mcp.server.fastmcp.server import Context
from mcp.shared.context import RequestContext
from mcp.types import (
CreateMessageRequestParams,
Expand Down Expand Up @@ -78,8 +81,6 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str:
# Create a function to make the FastMCP server app
def make_fastmcp_app():
"""Create a FastMCP server without auth settings."""
from starlette.applications import Starlette

mcp = FastMCP(name="NoAuthServer")

# Add a simple tool
Expand Down Expand Up @@ -198,8 +199,6 @@ def complex_prompt(user_query: str, context: str = "general") -> str:

def make_everything_fastmcp_app():
"""Create a comprehensive FastMCP server with SSE transport."""
from starlette.applications import Starlette

mcp = make_everything_fastmcp()
# Create the SSE app
app: Starlette = mcp.sse_app()
Expand All @@ -208,7 +207,6 @@ def make_everything_fastmcp_app():

def make_fastmcp_streamable_http_app():
"""Create a FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette

mcp = FastMCP(name="NoAuthServer")

Expand All @@ -225,8 +223,6 @@ def echo(message: str) -> str:

def make_everything_fastmcp_streamable_http_app():
"""Create a comprehensive FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette

# Create a new instance with different name for HTTP transport
mcp = make_everything_fastmcp()
# We can't change the name after creation, so we'll use the same name
Expand All @@ -237,7 +233,6 @@ def make_everything_fastmcp_streamable_http_app():

def make_fastmcp_stateless_http_app():
"""Create a FastMCP server with stateless StreamableHTTP transport."""
from starlette.applications import Starlette

mcp = FastMCP(name="StatelessServer", stateless_http=True)

Expand Down Expand Up @@ -435,6 +430,181 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
assert tool_result.content[0].text == "Echo: hello"


def make_fastmcp_with_context_app():
"""Create a FastMCP server that can access request context."""

mcp = FastMCP(name="ContextServer")

# Tool that echoes request headers
@mcp.tool(description="Echo request headers from context")
def echo_headers(ctx: Context) -> str:
"""Returns the request headers as JSON."""
headers_info = {}
try:
if ctx.request_context.request:
headers_info = ctx.request_context.request.get("headers", {})
except Exception:
pass
return json.dumps(headers_info)

# Tool that returns full request context
@mcp.tool(description="Echo request context with custom data")
def echo_context(custom_request_id: str, ctx: Context) -> str:
"""Returns request context including headers and custom data."""
context_data = {
"custom_request_id": custom_request_id,
"headers": {},
"method": None,
"url": None,
}
try:
if ctx.request_context.request:
context_data["headers"] = ctx.request_context.request.get("headers", {})
context_data["method"] = ctx.request_context.request.get("method")
context_data["url"] = ctx.request_context.request.get("url")
except Exception:
pass
return json.dumps(context_data)

# Create the SSE app
app: Starlette = mcp.sse_app()
return mcp, app


def run_context_server(server_port: int) -> None:
"""Run the context-aware FastMCP server."""
_, app = make_fastmcp_with_context_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting context server on port {server_port}")
server.run()


@pytest.fixture()
def context_aware_server(server_port: int) -> Generator[None, None, None]:
"""Start the context-aware server in a separate process."""
proc = multiprocessing.Process(
target=run_context_server, args=(server_port,), daemon=True
)
print("Starting context-aware server process")
proc.start()

# Wait for server to be running
max_attempts = 20
attempt = 0
print("Waiting for context-aware server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context server failed to start after {max_attempts} attempts"
)

yield

print("Killing context-aware server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("Context server process failed to terminate")


@pytest.mark.anyio
async def test_fast_mcp_with_request_context(
context_aware_server: None, server_url: str
) -> None:
"""Test that FastMCP properly propagates request context to tools."""
# Test with custom headers
custom_headers = {
"Authorization": "Bearer fastmcp-test-token",
"X-Custom-Header": "fastmcp-value",
"X-Request-Id": "req-123",
}

async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
async with ClientSession(*streams) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == "ContextServer"

# Test 1: Call tool that echoes headers
headers_result = await session.call_tool("echo_headers", {})
assert len(headers_result.content) == 1
assert isinstance(headers_result.content[0], TextContent)

headers_data = json.loads(headers_result.content[0].text)
assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
assert headers_data.get("x-custom-header") == "fastmcp-value"
assert headers_data.get("x-request-id") == "req-123"

# Test 2: Call tool that returns full context
context_result = await session.call_tool(
"echo_context", {"custom_request_id": "test-123"}
)
assert len(context_result.content) == 1
assert isinstance(context_result.content[0], TextContent)

context_data = json.loads(context_result.content[0].text)
assert context_data["custom_request_id"] == "test-123"
assert (
context_data["headers"].get("authorization")
== "Bearer fastmcp-test-token"
)
assert context_data["method"] == "POST" # SSE messages are POSTed
assert (
"/messages/" in context_data["url"]
) # Should contain the messages endpoint


@pytest.mark.anyio
async def test_fast_mcp_request_context_isolation(
context_aware_server: None, server_url: str
) -> None:
"""Test that request contexts are isolated between different FastMCP clients."""
contexts = []

# Create multiple clients with different headers
for i in range(3):
headers = {
"Authorization": f"Bearer token-{i}",
"X-Request-Id": f"fastmcp-req-{i}",
"X-Custom-Value": f"value-{i}",
}

async with sse_client(server_url + "/sse", headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()

# Call the tool that returns context
tool_result = await session.call_tool(
"echo_context", {"custom_request_id": f"test-req-{i}"}
)

# Parse and store the result
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
context_data = json.loads(tool_result.content[0].text)
contexts.append(context_data)

# Verify each request had its own isolated context
assert len(contexts) == 3
for i, ctx in enumerate(contexts):
assert ctx["custom_request_id"] == f"test-req-{i}"
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"


@pytest.mark.anyio
async def test_fastmcp_streamable_http(
streamable_http_server: None, http_server_url: str
Expand Down
Loading
Loading