diff --git a/src/mcp/server/fastmcp/middlewares/__init__.py b/src/mcp/server/fastmcp/middlewares/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/mcp/server/fastmcp/middlewares/__init__.py @@ -0,0 +1 @@ + diff --git a/src/mcp/server/fastmcp/middlewares/cors_middleware.py b/src/mcp/server/fastmcp/middlewares/cors_middleware.py new file mode 100644 index 000000000..17ea4a261 --- /dev/null +++ b/src/mcp/server/fastmcp/middlewares/cors_middleware.py @@ -0,0 +1,30 @@ +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import ASGIApp + + +class CORSMiddleware(BaseHTTPMiddleware): + def __init__(self, app: ASGIApp, allow_origins=None, allow_methods=None, allow_headers=None, max_age=None): + super().__init__(app) + self.allow_origins = allow_origins or ["*"] + self.allow_methods = allow_methods or ["GET", "POST", "OPTIONS"] + self.allow_headers = allow_headers or ["*"] + self.max_age = max_age or "3600" + + async def dispatch(self, request: Request, call_next): + # Handle OPTIONS method for CORS preflight requests + if request.method == "OPTIONS": + response = Response() + response.headers["Access-Control-Allow-Origin"] = ",".join(self.allow_origins) + response.headers["Access-Control-Allow-Methods"] = ",".join(self.allow_methods) + response.headers["Access-Control-Allow-Headers"] = ",".join(self.allow_headers) + response.headers["Access-Control-Max-Age"] = self.max_age + return response + + # Process the request normally and then add CORS headers to the response + response = await call_next(request) + response.headers["Access-Control-Allow-Origin"] = ",".join(self.allow_origins) + response.headers["Access-Control-Allow-Methods"] = ",".join(self.allow_methods) + response.headers["Access-Control-Allow-Headers"] = ",".join(self.allow_headers) + return response diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index bf0ce880a..e7885ad7f 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -29,6 +29,7 @@ from mcp.server.fastmcp.tools import ToolManager from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger from mcp.server.fastmcp.utilities.types import Image +from mcp.server.fastmcp.middlewares.cors_middleware import CORSMiddleware from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer @@ -49,7 +50,6 @@ from mcp.types import Resource as MCPResource from mcp.types import ResourceTemplate as MCPResourceTemplate from mcp.types import Tool as MCPTool - logger = get_logger(__name__) @@ -75,6 +75,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]): port: int = 8000 sse_path: str = "/sse" message_path: str = "/messages/" + cors_enabled: bool = False + cors_allow_origins: list[str] = ["*"] + cors_allow_methods: list[str] = ["GET", "POST", "OPTIONS"] + cors_allow_headers: list[str] = ["*"] + cors_max_age: str = "3600" # resource settings warn_on_duplicate_resources: bool = True @@ -467,7 +472,13 @@ async def run_stdio_async(self) -> None: async def run_sse_async(self) -> None: """Run the server using SSE transport.""" starlette_app = self.sse_app() - + starlette_app.add_middleware( + CORSMiddleware, + allow_origins=self.settings.cors_allow_origins, + allow_methods=self.settings.cors_allow_methods, + allow_headers=self.settings.cors_allow_headers, + max_age=self.settings.cors_max_age, + ) config = uvicorn.Config( starlette_app, host=self.settings.host,