Skip to content

Support for http request injection propagation to tools #816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def main():
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -555,6 +555,13 @@ async def _handle_request(

token = None
try:
# Extract request context from message metadata
request_data = None
if message.message_metadata is not None and isinstance(
message.message_metadata, ServerMessageMetadata
):
request_data = message.message_metadata.request_context

# Set our global state that can be retrieved via
# app.get_request_context()
token = request_ctx.set(
Expand All @@ -563,6 +570,7 @@ async def _handle_request(
message.request_meta,
session,
lifespan_context,
request=request_data,
)
)
response = await handler(req)
Expand Down
6 changes: 4 additions & 2 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def handle_sse(request):
from starlette.types import Receive, Scope, Send

import mcp.types as types
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -203,7 +203,9 @@ async def handle_post_message(
await writer.send(err)
return

session_message = SessionMessage(message)
# Pass the ASGI scope for framework-agnostic access to request data
metadata = ServerMessageMetadata(request_context=dict(request.scope))
session_message = SessionMessage(message, metadata=metadata)
logger.debug(f"Sending session message to writer: {session_message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/shared/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing_extensions import TypeVar

from mcp.shared.session import BaseSession
from mcp.types import RequestId, RequestParams
from mcp.types import RequestData, RequestId, RequestParams

SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
Expand All @@ -16,3 +16,4 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
request: RequestData | None = None
4 changes: 3 additions & 1 deletion src/mcp/shared/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Awaitable, Callable
from dataclasses import dataclass

from mcp.types import JSONRPCMessage, RequestId
from mcp.types import JSONRPCMessage, RequestData, RequestId

ResumptionToken = str

Expand All @@ -30,6 +30,8 @@ class ServerMessageMetadata:
"""Metadata specific to server messages."""

related_request_id: RequestId | None = None
# Request-specific context (e.g., headers, auth info)
request_context: RequestData | None = None


MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
Expand Down
3 changes: 3 additions & 0 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ def __init__(
ReceiveNotificationT
]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
message_metadata: MessageMetadata = None,
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self.message_metadata = message_metadata
self._session = session
self._completed = False
self._cancel_scope = anyio.CancelScope()
Expand Down Expand Up @@ -364,6 +366,7 @@ async def _receive_loop(self) -> None:
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)

self._in_flight[responder.request_id] = responder
Expand Down
3 changes: 3 additions & 0 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
Role = Literal["user", "assistant"]
RequestId = str | int
AnyFunction: TypeAlias = Callable[..., Any]
# Dictionary containing request metadata (headers, path, method, etc.
# based on ASGI scope for most of the trasport implementations)
RequestData: TypeAlias = dict[str, Any]


class RequestParams(BaseModel):
Expand Down
188 changes: 178 additions & 10 deletions tests/server/fastmcp/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
including with and without authentication.
"""

import json
import multiprocessing
import socket
import time
Expand All @@ -13,13 +14,15 @@
import pytest
import uvicorn
from pydantic import AnyUrl
from starlette.applications import Starlette

import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.resources import FunctionResource
from mcp.server.fastmcp.server import Context
from mcp.shared.context import RequestContext
from mcp.types import (
CreateMessageRequestParams,
Expand Down Expand Up @@ -78,8 +81,6 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str:
# Create a function to make the FastMCP server app
def make_fastmcp_app():
"""Create a FastMCP server without auth settings."""
from starlette.applications import Starlette

mcp = FastMCP(name="NoAuthServer")

# Add a simple tool
Expand All @@ -88,7 +89,7 @@ def echo(message: str) -> str:
return f"Echo: {message}"

# Create the SSE app
app: Starlette = mcp.sse_app()
app = mcp.sse_app()

return mcp, app

Expand Down Expand Up @@ -198,17 +199,14 @@ def complex_prompt(user_query: str, context: str = "general") -> str:

def make_everything_fastmcp_app():
"""Create a comprehensive FastMCP server with SSE transport."""
from starlette.applications import Starlette

mcp = make_everything_fastmcp()
# Create the SSE app
app: Starlette = mcp.sse_app()
app = mcp.sse_app()
return mcp, app


def make_fastmcp_streamable_http_app():
"""Create a FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette

mcp = FastMCP(name="NoAuthServer")

Expand All @@ -225,8 +223,6 @@ def echo(message: str) -> str:

def make_everything_fastmcp_streamable_http_app():
"""Create a comprehensive FastMCP server with StreamableHTTP transport."""
from starlette.applications import Starlette

# Create a new instance with different name for HTTP transport
mcp = make_everything_fastmcp()
# We can't change the name after creation, so we'll use the same name
Expand All @@ -237,7 +233,6 @@ def make_everything_fastmcp_streamable_http_app():

def make_fastmcp_stateless_http_app():
"""Create a FastMCP server with stateless StreamableHTTP transport."""
from starlette.applications import Starlette

mcp = FastMCP(name="StatelessServer", stateless_http=True)

Expand Down Expand Up @@ -435,6 +430,179 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
assert tool_result.content[0].text == "Echo: hello"


def make_fastmcp_with_context_app():
"""Create a FastMCP server that can access request context."""

mcp = FastMCP(name="ContextServer")

# Tool that echoes request headers
@mcp.tool(description="Echo request headers from context")
def echo_headers(ctx: Context) -> str:
"""Returns the request headers as JSON."""
headers_info = {}
if ctx.request_context.request:
# Extract headers from ASGI scope
headers_list = ctx.request_context.request.get("headers", [])
headers_info = {
k.decode("latin-1"): v.decode("latin-1") for k, v in headers_list
}
return json.dumps(headers_info)

# Tool that returns full request context
@mcp.tool(description="Echo request context with custom data")
def echo_context(custom_request_id: str, ctx: Context) -> str:
"""Returns request context including headers and custom data."""
context_data = {
"custom_request_id": custom_request_id,
"headers": {},
"method": None,
"path": None,
}
if ctx.request_context.request:
# Extract data from ASGI scope
headers_list = ctx.request_context.request.get("headers", [])
context_data["headers"] = {
k.decode("latin-1"): v.decode("latin-1") for k, v in headers_list
}
context_data["method"] = ctx.request_context.request.get("method")
return json.dumps(context_data)

# Create the SSE app
app = mcp.sse_app()
return mcp, app


def run_context_server(server_port: int) -> None:
"""Run the context-aware FastMCP server."""
_, app = make_fastmcp_with_context_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting context server on port {server_port}")
server.run()


@pytest.fixture()
def context_aware_server(server_port: int) -> Generator[None, None, None]:
"""Start the context-aware server in a separate process."""
proc = multiprocessing.Process(
target=run_context_server, args=(server_port,), daemon=True
)
print("Starting context-aware server process")
proc.start()

# Wait for server to be running
max_attempts = 20
attempt = 0
print("Waiting for context-aware server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Context server failed to start after {max_attempts} attempts"
)

yield

print("Killing context-aware server")
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("Context server process failed to terminate")


@pytest.mark.anyio
async def test_fast_mcp_with_request_context(
context_aware_server: None, server_url: str
) -> None:
"""Test that FastMCP properly propagates request context to tools."""
# Test with custom headers
custom_headers = {
"Authorization": "Bearer fastmcp-test-token",
"X-Custom-Header": "fastmcp-value",
"X-Request-Id": "req-123",
}

async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
async with ClientSession(*streams) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == "ContextServer"

# Test 1: Call tool that echoes headers
headers_result = await session.call_tool("echo_headers", {})
assert len(headers_result.content) == 1
assert isinstance(headers_result.content[0], TextContent)

headers_data = json.loads(headers_result.content[0].text)
assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
assert headers_data.get("x-custom-header") == "fastmcp-value"
assert headers_data.get("x-request-id") == "req-123"

# Test 2: Call tool that returns full context
context_result = await session.call_tool(
"echo_context", {"custom_request_id": "test-123"}
)
assert len(context_result.content) == 1
assert isinstance(context_result.content[0], TextContent)

context_data = json.loads(context_result.content[0].text)
assert context_data["custom_request_id"] == "test-123"
assert (
context_data["headers"].get("authorization")
== "Bearer fastmcp-test-token"
)
assert context_data["method"] == "POST" #


@pytest.mark.anyio
async def test_fast_mcp_request_context_isolation(
context_aware_server: None, server_url: str
) -> None:
"""Test that request contexts are isolated between different FastMCP clients."""
contexts = []

# Create multiple clients with different headers
for i in range(3):
headers = {
"Authorization": f"Bearer token-{i}",
"X-Request-Id": f"fastmcp-req-{i}",
"X-Custom-Value": f"value-{i}",
}

async with sse_client(server_url + "/sse", headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()

# Call the tool that returns context
tool_result = await session.call_tool(
"echo_context", {"custom_request_id": f"test-req-{i}"}
)

# Parse and store the result
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
context_data = json.loads(tool_result.content[0].text)
contexts.append(context_data)

# Verify each request had its own isolated context
assert len(contexts) == 3
for i, ctx in enumerate(contexts):
assert ctx["custom_request_id"] == f"test-req-{i}"
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"


@pytest.mark.anyio
async def test_fastmcp_streamable_http(
streamable_http_server: None, http_server_url: str
Expand Down
Loading
Loading