Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/mcp/server/fastmcp/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

30 changes: 30 additions & 0 deletions src/mcp/server/fastmcp/middlewares/cors_middleware.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Starlette has a CORSMiddleware class. You don't need to implement it by yourself...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks missed that

Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp


class CORSMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, allow_origins=None, allow_methods=None, allow_headers=None, max_age=None):
super().__init__(app)
self.allow_origins = allow_origins or ["*"]
self.allow_methods = allow_methods or ["GET", "POST", "OPTIONS"]
self.allow_headers = allow_headers or ["*"]
self.max_age = max_age or "3600"

async def dispatch(self, request: Request, call_next):
# Handle OPTIONS method for CORS preflight requests
if request.method == "OPTIONS":
response = Response()
response.headers["Access-Control-Allow-Origin"] = ",".join(self.allow_origins)
response.headers["Access-Control-Allow-Methods"] = ",".join(self.allow_methods)
response.headers["Access-Control-Allow-Headers"] = ",".join(self.allow_headers)
response.headers["Access-Control-Max-Age"] = self.max_age
return response

# Process the request normally and then add CORS headers to the response
response = await call_next(request)
response.headers["Access-Control-Allow-Origin"] = ",".join(self.allow_origins)
response.headers["Access-Control-Allow-Methods"] = ",".join(self.allow_methods)
response.headers["Access-Control-Allow-Headers"] = ",".join(self.allow_headers)
return response
15 changes: 13 additions & 2 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
from mcp.server.fastmcp.utilities.types import Image
from mcp.server.fastmcp.middlewares.cors_middleware import CORSMiddleware
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.lowlevel.server import LifespanResultT
from mcp.server.lowlevel.server import Server as MCPServer
Expand All @@ -49,7 +50,6 @@
from mcp.types import Resource as MCPResource
from mcp.types import ResourceTemplate as MCPResourceTemplate
from mcp.types import Tool as MCPTool

logger = get_logger(__name__)


Expand All @@ -75,6 +75,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
port: int = 8000
sse_path: str = "/sse"
message_path: str = "/messages/"
cors_enabled: bool = False
cors_allow_origins: list[str] = ["*"]
cors_allow_methods: list[str] = ["GET", "POST", "OPTIONS"]
cors_allow_headers: list[str] = ["*"]
cors_max_age: str = "3600"

# resource settings
warn_on_duplicate_resources: bool = True
Expand Down Expand Up @@ -467,7 +472,13 @@ async def run_stdio_async(self) -> None:
async def run_sse_async(self) -> None:
"""Run the server using SSE transport."""
starlette_app = self.sse_app()

starlette_app.add_middleware(
CORSMiddleware,
allow_origins=self.settings.cors_allow_origins,
allow_methods=self.settings.cors_allow_methods,
allow_headers=self.settings.cors_allow_headers,
max_age=self.settings.cors_max_age,
)
config = uvicorn.Config(
starlette_app,
host=self.settings.host,
Expand Down
Loading