diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index da8158a98..f718df801 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -1,37 +1,17 @@ import contextlib import logging +from collections.abc import AsyncIterator import anyio import click import mcp.types as types from mcp.server.lowlevel import Server -from mcp.server.streamableHttp import ( - StreamableHTTPServerTransport, -) +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.routing import Mount +from starlette.types import Receive, Scope, Send logger = logging.getLogger(__name__) -# Global task group that will be initialized in the lifespan -task_group = None - - -@contextlib.asynccontextmanager -async def lifespan(app): - """Application lifespan context manager for managing task group.""" - global task_group - - async with anyio.create_task_group() as tg: - task_group = tg - logger.info("Application started, task group initialized!") - try: - yield - finally: - logger.info("Application shutting down, cleaning up resources...") - if task_group: - tg.cancel_scope.cancel() - task_group = None - logger.info("Resources cleaned up successfully.") @click.command() @@ -122,35 +102,28 @@ async def list_tools() -> list[types.Tool]: ) ] - # ASGI handler for stateless HTTP connections - async def handle_streamable_http(scope, receive, send): - logger.debug("Creating new transport") - # Use lock to prevent race conditions when creating new sessions - http_transport = StreamableHTTPServerTransport( - mcp_session_id=None, - is_json_response_enabled=json_response, - ) - async with http_transport.connect() as streams: - read_stream, write_stream = streams - - if not task_group: - raise RuntimeError("Task group is not initialized") - - async def run_server(): - await app.run( - read_stream, - write_stream, - app.create_initialization_options(), - # Runs in standalone mode for stateless deployments - # where clients perform initialization with any node - standalone_mode=True, - ) - - # Start server task - task_group.start_soon(run_server) - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) + # Create the session manager with true stateless mode + session_manager = StreamableHTTPSessionManager( + app=app, + event_store=None, + json_response=json_response, + stateless=True, + ) + + async def handle_streamable_http( + scope: Scope, receive: Receive, send: Send + ) -> None: + await session_manager.handle_request(scope, receive, send) + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + """Context manager for session manager.""" + async with session_manager.run(): + logger.info("Application started with StreamableHTTP session manager!") + try: + yield + finally: + logger.info("Application shutting down...") # Create an ASGI application using the transport starlette_app = Starlette( diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index d36686720..1a76097b5 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -1,58 +1,22 @@ import contextlib import logging -from http import HTTPStatus -from uuid import uuid4 +from collections.abc import AsyncIterator import anyio import click import mcp.types as types from mcp.server.lowlevel import Server -from mcp.server.streamable_http import ( - MCP_SESSION_ID_HEADER, - StreamableHTTPServerTransport, -) +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from pydantic import AnyUrl from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Mount +from starlette.types import Receive, Scope, Send from .event_store import InMemoryEventStore # Configure logging logger = logging.getLogger(__name__) -# Global task group that will be initialized in the lifespan -task_group = None - -# Event store for resumability -# The InMemoryEventStore enables resumability support for StreamableHTTP transport. -# It stores SSE events with unique IDs, allowing clients to: -# 1. Receive event IDs for each SSE message -# 2. Resume streams by sending Last-Event-ID in GET requests -# 3. Replay missed events after reconnection -# Note: This in-memory implementation is for demonstration ONLY. -# For production, use a persistent storage solution. -event_store = InMemoryEventStore() - - -@contextlib.asynccontextmanager -async def lifespan(app): - """Application lifespan context manager for managing task group.""" - global task_group - - async with anyio.create_task_group() as tg: - task_group = tg - logger.info("Application started, task group initialized!") - try: - yield - finally: - logger.info("Application shutting down, cleaning up resources...") - if task_group: - tg.cancel_scope.cancel() - task_group = None - logger.info("Resources cleaned up successfully.") - @click.command() @click.option("--port", default=3000, help="Port to listen on for HTTP") @@ -156,60 +120,38 @@ async def list_tools() -> list[types.Tool]: ) ] - # We need to store the server instances between requests - server_instances = {} - # Lock to prevent race conditions when creating new sessions - session_creation_lock = anyio.Lock() + # Create event store for resumability + # The InMemoryEventStore enables resumability support for StreamableHTTP transport. + # It stores SSE events with unique IDs, allowing clients to: + # 1. Receive event IDs for each SSE message + # 2. Resume streams by sending Last-Event-ID in GET requests + # 3. Replay missed events after reconnection + # Note: This in-memory implementation is for demonstration ONLY. + # For production, use a persistent storage solution. + event_store = InMemoryEventStore() + + # Create the session manager with our app and event store + session_manager = StreamableHTTPSessionManager( + app=app, + event_store=event_store, # Enable resumability + json_response=json_response, + ) # ASGI handler for streamable HTTP connections - async def handle_streamable_http(scope, receive, send): - request = Request(scope, receive) - request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) - if ( - request_mcp_session_id is not None - and request_mcp_session_id in server_instances - ): - transport = server_instances[request_mcp_session_id] - logger.debug("Session already exists, handling request directly") - await transport.handle_request(scope, receive, send) - elif request_mcp_session_id is None: - # try to establish new session - logger.debug("Creating new transport") - # Use lock to prevent race conditions when creating new sessions - async with session_creation_lock: - new_session_id = uuid4().hex - http_transport = StreamableHTTPServerTransport( - mcp_session_id=new_session_id, - is_json_response_enabled=json_response, - event_store=event_store, # Enable resumability - ) - server_instances[http_transport.mcp_session_id] = http_transport - logger.info(f"Created new transport with session ID: {new_session_id}") - - async def run_server(task_status=None): - async with http_transport.connect() as streams: - read_stream, write_stream = streams - if task_status: - task_status.started() - await app.run( - read_stream, - write_stream, - app.create_initialization_options(), - ) - - if not task_group: - raise RuntimeError("Task group is not initialized") - - await task_group.start(run_server) - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) - else: - response = Response( - "Bad Request: No valid session ID provided", - status_code=HTTPStatus.BAD_REQUEST, - ) - await response(scope, receive, send) + async def handle_streamable_http( + scope: Scope, receive: Receive, send: Send + ) -> None: + await session_manager.handle_request(scope, receive, send) + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + """Context manager for managing session manager lifecycle.""" + async with session_manager.run(): + logger.info("Application started with StreamableHTTP session manager!") + try: + yield + finally: + logger.info("Application shutting down...") # Create an ASGI application using the transport starlette_app = Starlette( diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index ea0214f0f..c31f29d4c 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -47,6 +47,8 @@ from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server +from mcp.server.streamable_http import EventStore +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import ( AnyFunction, @@ -90,6 +92,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]): mount_path: str = "/" # Mount path (e.g. "/github", defaults to root path) sse_path: str = "/sse" message_path: str = "/messages/" + streamable_http_path: str = "/mcp" + + # StreamableHTTP settings + json_response: bool = False + stateless_http: bool = ( + False # If True, uses true stateless mode (new transport per request) + ) # resource settings warn_on_duplicate_resources: bool = True @@ -131,6 +140,7 @@ def __init__( instructions: str | None = None, auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, + event_store: EventStore | None = None, **settings: Any, ): self.settings = Settings(**settings) @@ -162,8 +172,10 @@ def __init__( "is specified" ) self._auth_server_provider = auth_server_provider + self._event_store = event_store self._custom_starlette_routes: list[Route] = [] self.dependencies = self.settings.dependencies + self._session_manager: StreamableHTTPSessionManager | None = None # Set up MCP protocol handlers self._setup_handlers() @@ -179,25 +191,47 @@ def name(self) -> str: def instructions(self) -> str | None: return self._mcp_server.instructions + @property + def session_manager(self) -> StreamableHTTPSessionManager: + """Get the StreamableHTTP session manager. + + This is exposed to enable advanced use cases like mounting multiple + FastMCP servers in a single FastAPI application. + + Raises: + RuntimeError: If called before streamable_http_app() has been called. + """ + if self._session_manager is None: + raise RuntimeError( + "Session manager can only be accessed after" + "calling streamable_http_app()." + "The session manager is created lazily" + "to avoid unnecessary initialization." + ) + return self._session_manager + def run( self, - transport: Literal["stdio", "sse"] = "stdio", + transport: Literal["stdio", "sse", "streamable-http"] = "stdio", mount_path: str | None = None, ) -> None: """Run the FastMCP server. Note this is a synchronous function. Args: - transport: Transport protocol to use ("stdio" or "sse") + transport: Transport protocol to use ("stdio", "sse", or "streamable-http") mount_path: Optional mount path for SSE transport """ - TRANSPORTS = Literal["stdio", "sse"] + TRANSPORTS = Literal["stdio", "sse", "streamable-http"] if transport not in TRANSPORTS.__args__: # type: ignore raise ValueError(f"Unknown transport: {transport}") - if transport == "stdio": - anyio.run(self.run_stdio_async) - else: # transport == "sse" - anyio.run(lambda: self.run_sse_async(mount_path)) + match transport: + case "stdio": + anyio.run(self.run_stdio_async) + case "sse": + anyio.run(lambda: self.run_sse_async(mount_path)) + case "streamable-http": + anyio.run(self.run_streamable_http_async) def _setup_handlers(self) -> None: """Set up core MCP protocol handlers.""" @@ -573,6 +607,21 @@ async def run_sse_async(self, mount_path: str | None = None) -> None: server = uvicorn.Server(config) await server.serve() + async def run_streamable_http_async(self) -> None: + """Run the server using StreamableHTTP transport.""" + import uvicorn + + starlette_app = self.streamable_http_app() + + config = uvicorn.Config( + starlette_app, + host=self.settings.host, + port=self.settings.port, + log_level=self.settings.log_level.lower(), + ) + server = uvicorn.Server(config) + await server.serve() + def _normalize_path(self, mount_path: str, endpoint: str) -> str: """ Combine mount path and endpoint to return a normalized path. @@ -687,9 +736,9 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): else: # Auth is disabled, no need for RequireAuthMiddleware # Since handle_sse is an ASGI app, we need to create a compatible endpoint - async def sse_endpoint(request: Request) -> None: + async def sse_endpoint(request: Request) -> Response: # Convert the Starlette request to ASGI parameters - await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage] + return await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage] routes.append( Route( @@ -712,6 +761,80 @@ async def sse_endpoint(request: Request) -> None: debug=self.settings.debug, routes=routes, middleware=middleware ) + def streamable_http_app(self) -> Starlette: + """Return an instance of the StreamableHTTP server app.""" + from starlette.middleware import Middleware + from starlette.routing import Mount + + # Create session manager on first call (lazy initialization) + if self._session_manager is None: + self._session_manager = StreamableHTTPSessionManager( + app=self._mcp_server, + event_store=self._event_store, + json_response=self.settings.json_response, + stateless=self.settings.stateless_http, # Use the stateless setting + ) + + # Create the ASGI handler + async def handle_streamable_http( + scope: Scope, receive: Receive, send: Send + ) -> None: + await self.session_manager.handle_request(scope, receive, send) + + # Create routes + routes: list[Route | Mount] = [] + middleware: list[Middleware] = [] + required_scopes = [] + + # Add auth endpoints if auth provider is configured + if self._auth_server_provider: + assert self.settings.auth + from mcp.server.auth.routes import create_auth_routes + + required_scopes = self.settings.auth.required_scopes or [] + + middleware = [ + Middleware( + AuthenticationMiddleware, + backend=BearerAuthBackend( + provider=self._auth_server_provider, + ), + ), + Middleware(AuthContextMiddleware), + ] + routes.extend( + create_auth_routes( + provider=self._auth_server_provider, + issuer_url=self.settings.auth.issuer_url, + service_documentation_url=self.settings.auth.service_documentation_url, + client_registration_options=self.settings.auth.client_registration_options, + revocation_options=self.settings.auth.revocation_options, + ) + ) + routes.append( + Mount( + self.settings.streamable_http_path, + app=RequireAuthMiddleware(handle_streamable_http, required_scopes), + ) + ) + else: + # Auth is disabled, no wrapper needed + routes.append( + Mount( + self.settings.streamable_http_path, + app=handle_streamable_http, + ) + ) + + routes.extend(self._custom_starlette_routes) + + return Starlette( + debug=self.settings.debug, + routes=routes, + middleware=middleware, + lifespan=lambda app: self.session_manager.run(), + ) + async def list_prompts(self) -> list[MCPPrompt]: """List all available prompts.""" prompts = self._prompt_manager.list_prompts() diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py new file mode 100644 index 000000000..e5ef8b4aa --- /dev/null +++ b/src/mcp/server/streamable_http_manager.py @@ -0,0 +1,258 @@ +"""StreamableHTTP Session Manager for MCP servers.""" + +from __future__ import annotations + +import contextlib +import logging +import threading +from collections.abc import AsyncIterator +from http import HTTPStatus +from typing import Any +from uuid import uuid4 + +import anyio +from anyio.abc import TaskStatus +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from mcp.server.lowlevel.server import Server as MCPServer +from mcp.server.streamable_http import ( + MCP_SESSION_ID_HEADER, + EventStore, + StreamableHTTPServerTransport, +) + +logger = logging.getLogger(__name__) + + +class StreamableHTTPSessionManager: + """ + Manages StreamableHTTP sessions with optional resumability via event store. + + This class abstracts away the complexity of session management, event storage, + and request handling for StreamableHTTP transports. It handles: + + 1. Session tracking for clients + 2. Resumability via an optional event store + 3. Connection management and lifecycle + 4. Request handling and transport setup + + Important: Only one StreamableHTTPSessionManager instance should be created + per application. The instance cannot be reused after its run() context has + completed. If you need to restart the manager, create a new instance. + + Args: + app: The MCP server instance + event_store: Optional event store for resumability support. + If provided, enables resumable connections where clients + can reconnect and receive missed events. + If None, sessions are still tracked but not resumable. + json_response: Whether to use JSON responses instead of SSE streams + stateless: If True, creates a completely fresh transport for each request + with no session tracking or state persistence between requests. + + """ + + def __init__( + self, + app: MCPServer[Any], + event_store: EventStore | None = None, + json_response: bool = False, + stateless: bool = False, + ): + self.app = app + self.event_store = event_store + self.json_response = json_response + self.stateless = stateless + + # Session tracking (only used if not stateless) + self._session_creation_lock = anyio.Lock() + self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + + # The task group will be set during lifespan + self._task_group = None + # Thread-safe tracking of run() calls + self._run_lock = threading.Lock() + self._has_started = False + + @contextlib.asynccontextmanager + async def run(self) -> AsyncIterator[None]: + """ + Run the session manager with proper lifecycle management. + + This creates and manages the task group for all session operations. + + Important: This method can only be called once per instance. The same + StreamableHTTPSessionManager instance cannot be reused after this + context manager exits. Create a new instance if you need to restart. + + Use this in the lifespan context manager of your Starlette app: + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + """ + # Thread-safe check to ensure run() is only called once + with self._run_lock: + if self._has_started: + raise RuntimeError( + "StreamableHTTPSessionManager .run() can only be called " + "once per instance. Create a new instance if you need to run again." + ) + self._has_started = True + + async with anyio.create_task_group() as tg: + # Store the task group for later use + self._task_group = tg + logger.info("StreamableHTTP session manager started") + try: + yield # Let the application run + finally: + logger.info("StreamableHTTP session manager shutting down") + # Cancel task group to stop all spawned tasks + tg.cancel_scope.cancel() + self._task_group = None + # Clear any remaining server instances + self._server_instances.clear() + + async def handle_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """ + Process ASGI request with proper session handling and transport setup. + + Dispatches to the appropriate handler based on stateless mode. + + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + if self._task_group is None: + raise RuntimeError("Task group is not initialized. Make sure to use run().") + + # Dispatch to the appropriate handler + if self.stateless: + await self._handle_stateless_request(scope, receive, send) + else: + await self._handle_stateful_request(scope, receive, send) + + async def _handle_stateless_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """ + Process request in stateless mode - creating a new transport for each request. + + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + logger.debug("Stateless mode: Creating new transport for this request") + # No session ID needed in stateless mode + http_transport = StreamableHTTPServerTransport( + mcp_session_id=None, # No session tracking in stateless mode + is_json_response_enabled=self.json_response, + event_store=None, # No event store in stateless mode + ) + + # Start server in a new task + async def run_stateless_server( + *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED + ): + async with http_transport.connect() as streams: + read_stream, write_stream = streams + task_status.started() + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=True, + ) + + # Assert task group is not None for type checking + assert self._task_group is not None + # Start the server task + await self._task_group.start(run_stateless_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + + async def _handle_stateful_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """ + Process request in stateful mode - maintaining session state between requests. + + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + + # Existing session case + if ( + request_mcp_session_id is not None + and request_mcp_session_id in self._server_instances + ): + transport = self._server_instances[request_mcp_session_id] + logger.debug("Session already exists, handling request directly") + await transport.handle_request(scope, receive, send) + return + + if request_mcp_session_id is None: + # New session case + logger.debug("Creating new transport") + async with self._session_creation_lock: + new_session_id = uuid4().hex + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + is_json_response_enabled=self.json_response, + event_store=self.event_store, # May be None (no resumability) + ) + + assert http_transport.mcp_session_id is not None + self._server_instances[http_transport.mcp_session_id] = http_transport + logger.info(f"Created new transport with session ID: {new_session_id}") + + # Define the server runner + async def run_server( + *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED + ) -> None: + async with http_transport.connect() as streams: + read_stream, write_stream = streams + task_status.started() + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=False, # Stateful mode + ) + + # Assert task group is not None for type checking + assert self._task_group is not None + # Start the server task + await self._task_group.start(run_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + else: + # Invalid session ID + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 281db2dbc..67911e9e7 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -15,6 +15,7 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client from mcp.server.fastmcp import FastMCP from mcp.types import InitializeResult, TextContent @@ -33,6 +34,34 @@ def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" +@pytest.fixture +def http_server_port() -> int: + """Get a free port for testing the StreamableHTTP server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def http_server_url(http_server_port: int) -> str: + """Get the StreamableHTTP server URL for testing.""" + return f"http://127.0.0.1:{http_server_port}" + + +@pytest.fixture +def stateless_http_server_port() -> int: + """Get a free port for testing the stateless StreamableHTTP server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def stateless_http_server_url(stateless_http_server_port: int) -> str: + """Get the stateless StreamableHTTP server URL for testing.""" + return f"http://127.0.0.1:{stateless_http_server_port}" + + # Create a function to make the FastMCP server app def make_fastmcp_app(): """Create a FastMCP server without auth settings.""" @@ -51,6 +80,40 @@ def echo(message: str) -> str: return mcp, app +def make_fastmcp_streamable_http_app(): + """Create a FastMCP server with StreamableHTTP transport.""" + from starlette.applications import Starlette + + mcp = FastMCP(name="NoAuthServer") + + # Add a simple tool + @mcp.tool(description="A simple echo tool") + def echo(message: str) -> str: + return f"Echo: {message}" + + # Create the StreamableHTTP app + app: Starlette = mcp.streamable_http_app() + + return mcp, app + + +def make_fastmcp_stateless_http_app(): + """Create a FastMCP server with stateless StreamableHTTP transport.""" + from starlette.applications import Starlette + + mcp = FastMCP(name="StatelessServer", stateless_http=True) + + # Add a simple tool + @mcp.tool(description="A simple echo tool") + def echo(message: str) -> str: + return f"Echo: {message}" + + # Create the StreamableHTTP app + app: Starlette = mcp.streamable_http_app() + + return mcp, app + + def run_server(server_port: int) -> None: """Run the server.""" _, app = make_fastmcp_app() @@ -63,6 +126,30 @@ def run_server(server_port: int) -> None: server.run() +def run_streamable_http_server(server_port: int) -> None: + """Run the StreamableHTTP server.""" + _, app = make_fastmcp_streamable_http_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting StreamableHTTP server on port {server_port}") + server.run() + + +def run_stateless_http_server(server_port: int) -> None: + """Run the stateless StreamableHTTP server.""" + _, app = make_fastmcp_stateless_http_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting stateless StreamableHTTP server on port {server_port}") + server.run() + + @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: """Start the server in a separate process and clean up after the test.""" @@ -94,6 +181,80 @@ def server(server_port: int) -> Generator[None, None, None]: print("Server process failed to terminate") +@pytest.fixture() +def streamable_http_server(http_server_port: int) -> Generator[None, None, None]: + """Start the StreamableHTTP server in a separate process.""" + proc = multiprocessing.Process( + target=run_streamable_http_server, args=(http_server_port,), daemon=True + ) + print("Starting StreamableHTTP server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for StreamableHTTP server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", http_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"StreamableHTTP server failed to start after {max_attempts} attempts" + ) + + yield + + print("Killing StreamableHTTP server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("StreamableHTTP server process failed to terminate") + + +@pytest.fixture() +def stateless_http_server( + stateless_http_server_port: int, +) -> Generator[None, None, None]: + """Start the stateless StreamableHTTP server in a separate process.""" + proc = multiprocessing.Process( + target=run_stateless_http_server, + args=(stateless_http_server_port,), + daemon=True, + ) + print("Starting stateless StreamableHTTP server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for stateless StreamableHTTP server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", stateless_http_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Stateless server failed to start after {max_attempts} attempts" + ) + + yield + + print("Killing stateless StreamableHTTP server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Stateless StreamableHTTP server process failed to terminate") + + @pytest.mark.anyio async def test_fastmcp_without_auth(server: None, server_url: str) -> None: """Test that FastMCP works when auth settings are not provided.""" @@ -110,3 +271,55 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None: assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) assert tool_result.content[0].text == "Echo: hello" + + +@pytest.mark.anyio +async def test_fastmcp_streamable_http( + streamable_http_server: None, http_server_url: str +) -> None: + """Test that FastMCP works with StreamableHTTP transport.""" + # Connect to the server using StreamableHTTP + async with streamablehttp_client(http_server_url + "/mcp") as ( + read_stream, + write_stream, + _, + ): + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "NoAuthServer" + + # Test that we can call tools without authentication + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" + + +@pytest.mark.anyio +async def test_fastmcp_stateless_streamable_http( + stateless_http_server: None, stateless_http_server_url: str +) -> None: + """Test that FastMCP works with stateless StreamableHTTP transport.""" + # Connect to the server using StreamableHTTP + async with streamablehttp_client(stateless_http_server_url + "/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "StatelessServer" + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" + + for i in range(3): + tool_result = await session.call_tool("echo", {"message": f"test_{i}"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == f"Echo: test_{i}" diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py new file mode 100644 index 000000000..32782e458 --- /dev/null +++ b/tests/server/test_streamable_http_manager.py @@ -0,0 +1,81 @@ +"""Tests for StreamableHTTPSessionManager.""" + +import anyio +import pytest + +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + + +@pytest.mark.anyio +async def test_run_can_only_be_called_once(): + """Test that run() can only be called once per instance.""" + app = Server("test-server") + manager = StreamableHTTPSessionManager(app=app) + + # First call should succeed + async with manager.run(): + pass + + # Second call should raise RuntimeError + with pytest.raises(RuntimeError) as excinfo: + async with manager.run(): + pass + + assert ( + "StreamableHTTPSessionManager .run() can only be called once per instance" + in str(excinfo.value) + ) + + +@pytest.mark.anyio +async def test_run_prevents_concurrent_calls(): + """Test that concurrent calls to run() are prevented.""" + app = Server("test-server") + manager = StreamableHTTPSessionManager(app=app) + + errors = [] + + async def try_run(): + try: + async with manager.run(): + # Simulate some work + await anyio.sleep(0.1) + except RuntimeError as e: + errors.append(e) + + # Try to run concurrently + async with anyio.create_task_group() as tg: + tg.start_soon(try_run) + tg.start_soon(try_run) + + # One should succeed, one should fail + assert len(errors) == 1 + assert ( + "StreamableHTTPSessionManager .run() can only be called once per instance" + in str(errors[0]) + ) + + +@pytest.mark.anyio +async def test_handle_request_without_run_raises_error(): + """Test that handle_request raises error if run() hasn't been called.""" + app = Server("test-server") + manager = StreamableHTTPSessionManager(app=app) + + # Mock ASGI parameters + scope = {"type": "http", "method": "POST", "path": "/test"} + + async def receive(): + return {"type": "http.request", "body": b""} + + async def send(message): + pass + + # Should raise error because run() hasn't been called + with pytest.raises(RuntimeError) as excinfo: + await manager.handle_request(scope, receive, send) + + assert "Task group is not initialized. Make sure to use run()." in str( + excinfo.value + ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index b1dc7ea33..28d29ac23 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -4,13 +4,10 @@ Contains tests for both server and client sides of the StreamableHTTP transport. """ -import contextlib import multiprocessing import socket import time from collections.abc import Generator -from http import HTTPStatus -from uuid import uuid4 import anyio import httpx @@ -19,8 +16,6 @@ import uvicorn from pydantic import AnyUrl from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Mount import mcp.types as types @@ -37,6 +32,7 @@ StreamableHTTPServerTransport, StreamId, ) +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.shared.exceptions import McpError from mcp.shared.message import ( ClientMessageMetadata, @@ -184,7 +180,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: def create_app( is_json_response_enabled=False, event_store: EventStore | None = None ) -> Starlette: - """Create a Starlette application for testing that matches the example server. + """Create a Starlette application for testing using the session manager. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. @@ -193,85 +189,20 @@ def create_app( # Create server instance server = ServerTest() - server_instances = {} - # Lock to prevent race conditions when creating new sessions - session_creation_lock = anyio.Lock() - task_group = None - - @contextlib.asynccontextmanager - async def lifespan(app): - """Application lifespan context manager for managing task group.""" - nonlocal task_group - - async with anyio.create_task_group() as tg: - task_group = tg - try: - yield - finally: - if task_group: - tg.cancel_scope.cancel() - task_group = None - - async def handle_streamable_http(scope, receive, send): - request = Request(scope, receive) - request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) - - # Use existing transport if session ID matches - if ( - request_mcp_session_id is not None - and request_mcp_session_id in server_instances - ): - transport = server_instances[request_mcp_session_id] - - await transport.handle_request(scope, receive, send) - elif request_mcp_session_id is None: - async with session_creation_lock: - new_session_id = uuid4().hex - - http_transport = StreamableHTTPServerTransport( - mcp_session_id=new_session_id, - is_json_response_enabled=is_json_response_enabled, - event_store=event_store, - ) - - async def run_server(task_status=None): - async with http_transport.connect() as streams: - read_stream, write_stream = streams - if task_status: - task_status.started() - await server.run( - read_stream, - write_stream, - server.create_initialization_options(), - ) - - if task_group is None: - response = Response( - "Internal Server Error: Task group is not initialized", - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - ) - await response(scope, receive, send) - return - - # Store the instance before starting the task to prevent races - server_instances[http_transport.mcp_session_id] = http_transport - await task_group.start(run_server) - - await http_transport.handle_request(scope, receive, send) - else: - response = Response( - "Bad Request: No valid session ID provided", - status_code=HTTPStatus.BAD_REQUEST, - ) - await response(scope, receive, send) + # Create the session manager + session_manager = StreamableHTTPSessionManager( + app=server, + event_store=event_store, + json_response=is_json_response_enabled, + ) - # Create an ASGI application + # Create an ASGI application that uses the session manager app = Starlette( debug=True, routes=[ - Mount("/mcp", app=handle_streamable_http), + Mount("/mcp", app=session_manager.handle_request), ], - lifespan=lifespan, + lifespan=lambda app: session_manager.run(), ) return app