Skip to content

Commit 81123c6

Browse files
committed
fast mcp tests
1 parent 567875e commit 81123c6

File tree

5 files changed

+183
-15
lines changed

5 files changed

+183
-15
lines changed

src/mcp/server/sse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ async def handle_sse(request):
5252
from starlette.types import Receive, Scope, Send
5353

5454
import mcp.types as types
55-
from mcp.shared.context import RequestData
5655
from mcp.shared.message import ServerMessageMetadata, SessionMessage
56+
from mcp.types import RequestData
5757

5858
logger = logging.getLogger(__name__)
5959

src/mcp/shared/context.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,11 @@
44
from typing_extensions import TypeVar
55

66
from mcp.shared.session import BaseSession
7-
from mcp.types import RequestId, RequestParams
7+
from mcp.types import RequestData, RequestId, RequestParams
88

99
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
1010
LifespanContextT = TypeVar("LifespanContextT")
1111

12-
# Type alias for request-specific data (e.g., headers, auth info)
13-
RequestData = dict[str, Any]
14-
1512

1613
@dataclass
1714
class RequestContext(Generic[SessionT, LifespanContextT]):

src/mcp/shared/message.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from collections.abc import Awaitable, Callable
99
from dataclasses import dataclass
1010

11-
from mcp.shared.context import RequestData
12-
from mcp.types import JSONRPCMessage, RequestId
11+
from mcp.types import JSONRPCMessage, RequestData, RequestId
1312

1413
ResumptionToken = str
1514

src/mcp/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
Role = Literal["user", "assistant"]
3737
RequestId = str | int
3838
AnyFunction: TypeAlias = Callable[..., Any]
39+
# Http request-specific data (e.g., headers, auth info)
40+
RequestData = dict[str, Any]
3941

4042

4143
class RequestParams(BaseModel):

tests/server/fastmcp/test_integration.py

Lines changed: 178 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
including with and without authentication.
66
"""
77

8+
import json
89
import multiprocessing
910
import socket
1011
import time
@@ -13,13 +14,15 @@
1314
import pytest
1415
import uvicorn
1516
from pydantic import AnyUrl
17+
from starlette.applications import Starlette
1618

1719
import mcp.types as types
1820
from mcp.client.session import ClientSession
1921
from mcp.client.sse import sse_client
2022
from mcp.client.streamable_http import streamablehttp_client
2123
from mcp.server.fastmcp import FastMCP
2224
from mcp.server.fastmcp.resources import FunctionResource
25+
from mcp.server.fastmcp.server import Context
2326
from mcp.shared.context import RequestContext
2427
from mcp.types import (
2528
CreateMessageRequestParams,
@@ -78,8 +81,6 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str:
7881
# Create a function to make the FastMCP server app
7982
def make_fastmcp_app():
8083
"""Create a FastMCP server without auth settings."""
81-
from starlette.applications import Starlette
82-
8384
mcp = FastMCP(name="NoAuthServer")
8485

8586
# Add a simple tool
@@ -198,8 +199,6 @@ def complex_prompt(user_query: str, context: str = "general") -> str:
198199

199200
def make_everything_fastmcp_app():
200201
"""Create a comprehensive FastMCP server with SSE transport."""
201-
from starlette.applications import Starlette
202-
203202
mcp = make_everything_fastmcp()
204203
# Create the SSE app
205204
app: Starlette = mcp.sse_app()
@@ -208,7 +207,6 @@ def make_everything_fastmcp_app():
208207

209208
def make_fastmcp_streamable_http_app():
210209
"""Create a FastMCP server with StreamableHTTP transport."""
211-
from starlette.applications import Starlette
212210

213211
mcp = FastMCP(name="NoAuthServer")
214212

@@ -225,8 +223,6 @@ def echo(message: str) -> str:
225223

226224
def make_everything_fastmcp_streamable_http_app():
227225
"""Create a comprehensive FastMCP server with StreamableHTTP transport."""
228-
from starlette.applications import Starlette
229-
230226
# Create a new instance with different name for HTTP transport
231227
mcp = make_everything_fastmcp()
232228
# We can't change the name after creation, so we'll use the same name
@@ -237,7 +233,6 @@ def make_everything_fastmcp_streamable_http_app():
237233

238234
def make_fastmcp_stateless_http_app():
239235
"""Create a FastMCP server with stateless StreamableHTTP transport."""
240-
from starlette.applications import Starlette
241236

242237
mcp = FastMCP(name="StatelessServer", stateless_http=True)
243238

@@ -435,6 +430,181 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
435430
assert tool_result.content[0].text == "Echo: hello"
436431

437432

433+
def make_fastmcp_with_context_app():
434+
"""Create a FastMCP server that can access request context."""
435+
436+
mcp = FastMCP(name="ContextServer")
437+
438+
# Tool that echoes request headers
439+
@mcp.tool(description="Echo request headers from context")
440+
def echo_headers(ctx: Context) -> str:
441+
"""Returns the request headers as JSON."""
442+
headers_info = {}
443+
try:
444+
if ctx.request_context.request:
445+
headers_info = ctx.request_context.request.get("headers", {})
446+
except Exception:
447+
pass
448+
return json.dumps(headers_info)
449+
450+
# Tool that returns full request context
451+
@mcp.tool(description="Echo request context with custom data")
452+
def echo_context(custom_request_id: str, ctx: Context) -> str:
453+
"""Returns request context including headers and custom data."""
454+
context_data = {
455+
"custom_request_id": custom_request_id,
456+
"headers": {},
457+
"method": None,
458+
"url": None,
459+
}
460+
try:
461+
if ctx.request_context.request:
462+
context_data["headers"] = ctx.request_context.request.get("headers", {})
463+
context_data["method"] = ctx.request_context.request.get("method")
464+
context_data["url"] = ctx.request_context.request.get("url")
465+
except Exception:
466+
pass
467+
return json.dumps(context_data)
468+
469+
# Create the SSE app
470+
app: Starlette = mcp.sse_app()
471+
return mcp, app
472+
473+
474+
def run_context_server(server_port: int) -> None:
475+
"""Run the context-aware FastMCP server."""
476+
_, app = make_fastmcp_with_context_app()
477+
server = uvicorn.Server(
478+
config=uvicorn.Config(
479+
app=app, host="127.0.0.1", port=server_port, log_level="error"
480+
)
481+
)
482+
print(f"Starting context server on port {server_port}")
483+
server.run()
484+
485+
486+
@pytest.fixture()
487+
def context_aware_server(server_port: int) -> Generator[None, None, None]:
488+
"""Start the context-aware server in a separate process."""
489+
proc = multiprocessing.Process(
490+
target=run_context_server, args=(server_port,), daemon=True
491+
)
492+
print("Starting context-aware server process")
493+
proc.start()
494+
495+
# Wait for server to be running
496+
max_attempts = 20
497+
attempt = 0
498+
print("Waiting for context-aware server to start")
499+
while attempt < max_attempts:
500+
try:
501+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
502+
s.connect(("127.0.0.1", server_port))
503+
break
504+
except ConnectionRefusedError:
505+
time.sleep(0.1)
506+
attempt += 1
507+
else:
508+
raise RuntimeError(
509+
f"Context server failed to start after {max_attempts} attempts"
510+
)
511+
512+
yield
513+
514+
print("Killing context-aware server")
515+
proc.kill()
516+
proc.join(timeout=2)
517+
if proc.is_alive():
518+
print("Context server process failed to terminate")
519+
520+
521+
@pytest.mark.anyio
522+
async def test_fast_mcp_with_request_context(
523+
context_aware_server: None, server_url: str
524+
) -> None:
525+
"""Test that FastMCP properly propagates request context to tools."""
526+
# Test with custom headers
527+
custom_headers = {
528+
"Authorization": "Bearer fastmcp-test-token",
529+
"X-Custom-Header": "fastmcp-value",
530+
"X-Request-Id": "req-123",
531+
}
532+
533+
async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
534+
async with ClientSession(*streams) as session:
535+
# Initialize the session
536+
result = await session.initialize()
537+
assert isinstance(result, InitializeResult)
538+
assert result.serverInfo.name == "ContextServer"
539+
540+
# Test 1: Call tool that echoes headers
541+
headers_result = await session.call_tool("echo_headers", {})
542+
assert len(headers_result.content) == 1
543+
assert isinstance(headers_result.content[0], TextContent)
544+
545+
headers_data = json.loads(headers_result.content[0].text)
546+
assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
547+
assert headers_data.get("x-custom-header") == "fastmcp-value"
548+
assert headers_data.get("x-request-id") == "req-123"
549+
550+
# Test 2: Call tool that returns full context
551+
context_result = await session.call_tool(
552+
"echo_context", {"custom_request_id": "test-123"}
553+
)
554+
assert len(context_result.content) == 1
555+
assert isinstance(context_result.content[0], TextContent)
556+
557+
context_data = json.loads(context_result.content[0].text)
558+
assert context_data["custom_request_id"] == "test-123"
559+
assert (
560+
context_data["headers"].get("authorization")
561+
== "Bearer fastmcp-test-token"
562+
)
563+
assert context_data["method"] == "POST" # SSE messages are POSTed
564+
assert (
565+
"/messages/" in context_data["url"]
566+
) # Should contain the messages endpoint
567+
568+
569+
@pytest.mark.anyio
570+
async def test_fast_mcp_request_context_isolation(
571+
context_aware_server: None, server_url: str
572+
) -> None:
573+
"""Test that request contexts are isolated between different FastMCP clients."""
574+
contexts = []
575+
576+
# Create multiple clients with different headers
577+
for i in range(3):
578+
headers = {
579+
"Authorization": f"Bearer token-{i}",
580+
"X-Request-Id": f"fastmcp-req-{i}",
581+
"X-Custom-Value": f"value-{i}",
582+
}
583+
584+
async with sse_client(server_url + "/sse", headers=headers) as streams:
585+
async with ClientSession(*streams) as session:
586+
await session.initialize()
587+
588+
# Call the tool that returns context
589+
tool_result = await session.call_tool(
590+
"echo_context", {"custom_request_id": f"test-req-{i}"}
591+
)
592+
593+
# Parse and store the result
594+
assert len(tool_result.content) == 1
595+
assert isinstance(tool_result.content[0], TextContent)
596+
context_data = json.loads(tool_result.content[0].text)
597+
contexts.append(context_data)
598+
599+
# Verify each request had its own isolated context
600+
assert len(contexts) == 3
601+
for i, ctx in enumerate(contexts):
602+
assert ctx["custom_request_id"] == f"test-req-{i}"
603+
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
604+
assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
605+
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
606+
607+
438608
@pytest.mark.anyio
439609
async def test_fastmcp_streamable_http(
440610
streamable_http_server: None, http_server_url: str

0 commit comments

Comments
 (0)