From 366b3c46f2d2296c7f93f2a47ffa9aebf1ef4c77 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Fri, 30 May 2025 09:13:09 -0700 Subject: [PATCH 1/7] Add support for DNS rebinding protections --- src/mcp/server/fastmcp/server.py | 6 + src/mcp/server/sse.py | 26 +- src/mcp/server/streamable_http.py | 16 + src/mcp/server/streamable_http_manager.py | 5 + src/mcp/server/transport_security.py | 135 +++++++ tests/server/fastmcp/test_integration.py | 216 ++++++++++- tests/server/test_sse_security.py | 342 ++++++++++++++++++ tests/server/test_streamable_http_security.py | 333 +++++++++++++++++ tests/shared/test_sse.py | 198 +++++++++- tests/shared/test_streamable_http.py | 14 +- 10 files changed, 1276 insertions(+), 15 deletions(-) create mode 100644 src/mcp/server/transport_security.py create mode 100644 tests/server/test_sse_security.py create mode 100644 tests/server/test_streamable_http_security.py diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 3282baae61..4cbc589281 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -49,6 +49,7 @@ from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import ( AnyFunction, @@ -119,6 +120,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]): ) = Field(None, description="Lifespan context manager") auth: AuthSettings | None = None + + # Transport security settings (DNS rebinding protection) + transport_security: TransportSecuritySettings | None = None def lifespan_wrapper( @@ -670,6 +674,7 @@ def sse_app(self, mount_path: str | None = None) -> Starlette: sse = SseServerTransport( normalized_message_endpoint, + security_settings=self.settings.transport_security, ) async def handle_sse(scope: Scope, receive: Receive, send: Send): @@ -777,6 +782,7 @@ def streamable_http_app(self) -> Starlette: event_store=self._event_store, json_response=self.settings.json_response, stateless=self.settings.stateless_http, # Use the stateless setting + security_settings=self.settings.transport_security, ) # Create the ASGI handler diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index bae2bbf52e..782578b785 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -52,6 +52,10 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types +from mcp.server.transport_security import ( + TransportSecurityMiddleware, + TransportSecuritySettings, +) from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -71,16 +75,24 @@ class SseServerTransport: _endpoint: str _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] + _security: TransportSecurityMiddleware - def __init__(self, endpoint: str) -> None: + def __init__( + self, endpoint: str, security_settings: TransportSecuritySettings | None = None + ) -> None: """ Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given. + + Args: + endpoint: The relative or absolute URL for POST messages. + security_settings: Optional security settings for DNS rebinding protection. """ super().__init__() self._endpoint = endpoint self._read_stream_writers = {} + self._security = TransportSecurityMiddleware(security_settings) logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager @@ -89,6 +101,13 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") + # Validate request headers for DNS rebinding protection + request = Request(scope, receive) + error_response = await self._security.validate_request(request, is_post=False) + if error_response: + await error_response(scope, receive, send) + raise ValueError("Request validation failed") + logger.debug("Setting up SSE connection") read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] @@ -169,6 +188,11 @@ async def handle_post_message( ) -> None: logger.debug("Handling POST message") request = Request(scope, receive) + + # Validate request headers for DNS rebinding protection + error_response = await self._security.validate_request(request, is_post=True) + if error_response: + return await error_response(scope, receive, send) session_id_param = request.query_params.get("session_id") if session_id_param is None: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 8f4a1f5126..29c9491b42 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -24,6 +24,10 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.transport_security import ( + TransportSecurityMiddleware, + TransportSecuritySettings, +) from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.types import ( INTERNAL_ERROR, @@ -131,12 +135,14 @@ class StreamableHTTPServerTransport: _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None _write_stream: MemoryObjectSendStream[SessionMessage] | None = None _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None + _security: TransportSecurityMiddleware def __init__( self, mcp_session_id: str | None, is_json_response_enabled: bool = False, event_store: EventStore | None = None, + security_settings: TransportSecuritySettings | None = None, ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -149,6 +155,7 @@ def __init__( event_store: Event store for resumability support. If provided, resumability will be enabled, allowing clients to reconnect and resume messages. + security_settings: Optional security settings for DNS rebinding protection. Raises: ValueError: If the session ID contains invalid characters. @@ -163,6 +170,7 @@ def __init__( self.mcp_session_id = mcp_session_id self.is_json_response_enabled = is_json_response_enabled self._event_store = event_store + self._security = TransportSecurityMiddleware(security_settings) self._request_streams: dict[ RequestId, tuple[ @@ -260,6 +268,14 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None: async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Application entry point that handles all HTTP requests""" request = Request(scope, receive) + + # Validate request headers for DNS rebinding protection + is_post = request.method == "POST" + error_response = await self._security.validate_request(request, is_post=is_post) + if error_response: + await error_response(scope, receive, send) + return + if self._terminated: # If the session has been terminated, return 404 Not Found response = self._create_error_response( diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index e5ef8b4aa1..3ec1fd168c 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -22,6 +22,7 @@ EventStore, StreamableHTTPServerTransport, ) +from mcp.server.transport_security import TransportSecuritySettings logger = logging.getLogger(__name__) @@ -60,11 +61,13 @@ def __init__( event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, + security_settings: TransportSecuritySettings | None = None, ): self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless + self.security_settings = security_settings # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() @@ -162,6 +165,7 @@ async def _handle_stateless_request( mcp_session_id=None, # No session tracking in stateless mode is_json_response_enabled=self.json_response, event_store=None, # No event store in stateless mode + security_settings=self.security_settings, ) # Start server in a new task @@ -222,6 +226,7 @@ async def _handle_stateful_request( mcp_session_id=new_session_id, is_json_response_enabled=self.json_response, event_store=self.event_store, # May be None (no resumability) + security_settings=self.security_settings, ) assert http_transport.mcp_session_id is not None diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py new file mode 100644 index 0000000000..63cacbc9e7 --- /dev/null +++ b/src/mcp/server/transport_security.py @@ -0,0 +1,135 @@ +"""DNS rebinding protection for MCP server transports.""" + +import logging + +from pydantic import BaseModel, Field +from starlette.requests import Request +from starlette.responses import Response + +logger = logging.getLogger(__name__) + + +class TransportSecuritySettings(BaseModel): + """Settings for MCP transport security features. + + These settings help protect against DNS rebinding attacks by validating + incoming request headers. + """ + + enable_dns_rebinding_protection: bool = Field( + default=True, + description="Enable DNS rebinding protection (recommended for production)" + ) + + allowed_hosts: list[str] = Field( + default=[], + description="List of allowed Host header values. If None, all hosts " + "are allowed when protection is disabled, or only localhost/127.0.0.1 " + "when enabled." + ) + + allowed_origins: list[str] = Field( + default=[], + description="List of allowed Origin header values. If None, all " + "origins are allowed when protection is disabled, or only localhost " + "origins when enabled." + ) + + +class TransportSecurityMiddleware: + """Middleware to enforce DNS rebinding protection for MCP transport endpoints.""" + + def __init__(self, settings: TransportSecuritySettings | None = None): + # If not specified, disable DNS rebinding protection by default + # for backwards compatibility + self.settings = settings or TransportSecuritySettings( + enable_dns_rebinding_protection=False + ) + + def _validate_host(self, host: str | None) -> bool: + """Validate the Host header against allowed values.""" + if not self.settings.enable_dns_rebinding_protection: + return True + + if not host: + logger.warning("Missing Host header in request") + return False + + # Check exact match first + if host in self.settings.allowed_hosts: + return True + + # Check wildcard port patterns + for allowed in self.settings.allowed_hosts: + if allowed.endswith(":*"): + # Extract base host from pattern + base_host = allowed[:-2] + # Check if the actual host starts with base host and has a port + if host.startswith(base_host + ":"): + return True + + logger.warning(f"Invalid Host header: {host}") + return False + + def _validate_origin(self, origin: str | None) -> bool: + """Validate the Origin header against allowed values.""" + if not self.settings.enable_dns_rebinding_protection: + return True + + # Origin can be absent for same-origin requests + if not origin: + return True + + # Check exact match first + if origin in self.settings.allowed_origins: + return True + + # Check wildcard port patterns + for allowed in self.settings.allowed_origins: + if allowed.endswith(":*"): + # Extract base origin from pattern + base_origin = allowed[:-2] + # Check if the actual origin starts with base origin and has a port + if origin.startswith(base_origin + ":"): + return True + + logger.warning(f"Invalid Origin header: {origin}") + return False + + def _validate_content_type(self, content_type: str | None) -> bool: + """Validate the Content-Type header for POST requests.""" + if not content_type: + logger.warning("Missing Content-Type header in POST request") + return False + + # Content-Type must start with application/json + if not content_type.lower().startswith("application/json"): + logger.warning(f"Invalid Content-Type header: {content_type}") + return False + + return True + + async def validate_request( + self, request: Request, is_post: bool = False + ) -> Response | None: + """Validate request headers for DNS rebinding protection. + + Returns None if validation passes, or an error Response if validation fails. + """ + # Validate Host header + host = request.headers.get("host") + if not self._validate_host(host): + return Response("Invalid Host header", status_code=400) + + # Validate Origin header + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response("Invalid Origin header", status_code=400) + + # Validate Content-Type for POST requests + if is_post: + content_type = request.headers.get("content-type") + if not self._validate_content_type(content_type): + return Response("Invalid Content-Type header", status_code=400) + + return None \ No newline at end of file diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 79285ecb1b..eb4a031631 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -78,9 +78,13 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str: # Create a function to make the FastMCP server app def make_fastmcp_app(): """Create a FastMCP server without auth settings.""" - from starlette.applications import Starlette - - mcp = FastMCP(name="NoAuthServer") + from mcp.server.transport_security import TransportSecuritySettings + + transport_security = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + mcp = FastMCP(name="NoAuthServer", transport_security=transport_security) # Add a simple tool @mcp.tool(description="A simple echo tool") @@ -96,8 +100,13 @@ def echo(message: str) -> str: def make_everything_fastmcp() -> FastMCP: """Create a FastMCP server with all features enabled for testing.""" from mcp.server.fastmcp import Context - - mcp = FastMCP(name="EverythingServer") + from mcp.server.transport_security import TransportSecuritySettings + + transport_security = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + mcp = FastMCP(name="EverythingServer", transport_security=transport_security) # Tool with context for logging and progress @mcp.tool(description="A tool that demonstrates logging and progress") @@ -208,9 +217,13 @@ def make_everything_fastmcp_app(): def make_fastmcp_streamable_http_app(): """Create a FastMCP server with StreamableHTTP transport.""" - from starlette.applications import Starlette - - mcp = FastMCP(name="NoAuthServer") + from mcp.server.transport_security import TransportSecuritySettings + + transport_security = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + mcp = FastMCP(name="NoAuthServer", transport_security=transport_security) # Add a simple tool @mcp.tool(description="A simple echo tool") @@ -237,9 +250,17 @@ def make_everything_fastmcp_streamable_http_app(): def make_fastmcp_stateless_http_app(): """Create a FastMCP server with stateless StreamableHTTP transport.""" - from starlette.applications import Starlette - - mcp = FastMCP(name="StatelessServer", stateless_http=True) + from mcp.server.transport_security import TransportSecuritySettings + + transport_security = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + mcp = FastMCP( + name="StatelessServer", + stateless_http=True, + transport_security=transport_security + ) # Add a simple tool @mcp.tool(description="A simple echo tool") @@ -435,6 +456,179 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None: assert tool_result.content[0].text == "Echo: hello" +def make_fastmcp_with_context_app(): + """Create a FastMCP server that can access request context.""" + from mcp.server.transport_security import TransportSecuritySettings + + transport_security = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + mcp = FastMCP(name="ContextServer", transport_security=transport_security) + + # Tool that echoes request headers + @mcp.tool(description="Echo request headers from context") + def echo_headers(ctx: Context[Any, Any, Request]) -> str: + """Returns the request headers as JSON.""" + headers_info = {} + if ctx.request_context.request: + # Now the type system knows request is a Starlette Request object + headers_info = dict(ctx.request_context.request.headers) + return json.dumps(headers_info) + + # Tool that returns full request context + @mcp.tool(description="Echo request context with custom data") + def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str: + """Returns request context including headers and custom data.""" + context_data = { + "custom_request_id": custom_request_id, + "headers": {}, + "method": None, + "path": None, + } + if ctx.request_context.request: + request = ctx.request_context.request + context_data["headers"] = dict(request.headers) + context_data["method"] = request.method + context_data["path"] = request.url.path + return json.dumps(context_data) + + # Create the SSE app + app = mcp.sse_app() + return mcp, app + + +def run_context_server(server_port: int) -> None: + """Run the context-aware FastMCP server.""" + _, app = make_fastmcp_with_context_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting context server on port {server_port}") + server.run() + + +@pytest.fixture() +def context_aware_server(server_port: int) -> Generator[None, None, None]: + """Start the context-aware server in a separate process.""" + proc = multiprocessing.Process( + target=run_context_server, args=(server_port,), daemon=True + ) + print("Starting context-aware server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for context-aware server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Context server failed to start after {max_attempts} attempts" + ) + + yield + + print("Killing context-aware server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Context server process failed to terminate") + + +@pytest.mark.anyio +async def test_fast_mcp_with_request_context( + context_aware_server: None, server_url: str +) -> None: + """Test that FastMCP properly propagates request context to tools.""" + # Test with custom headers + custom_headers = { + "Authorization": "Bearer fastmcp-test-token", + "X-Custom-Header": "fastmcp-value", + "X-Request-Id": "req-123", + } + + async with sse_client(server_url + "/sse", headers=custom_headers) as streams: + async with ClientSession(*streams) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "ContextServer" + + # Test 1: Call tool that echoes headers + headers_result = await session.call_tool("echo_headers", {}) + assert len(headers_result.content) == 1 + assert isinstance(headers_result.content[0], TextContent) + + headers_data = json.loads(headers_result.content[0].text) + assert headers_data.get("authorization") == "Bearer fastmcp-test-token" + assert headers_data.get("x-custom-header") == "fastmcp-value" + assert headers_data.get("x-request-id") == "req-123" + + # Test 2: Call tool that returns full context + context_result = await session.call_tool( + "echo_context", {"custom_request_id": "test-123"} + ) + assert len(context_result.content) == 1 + assert isinstance(context_result.content[0], TextContent) + + context_data = json.loads(context_result.content[0].text) + assert context_data["custom_request_id"] == "test-123" + assert ( + context_data["headers"].get("authorization") + == "Bearer fastmcp-test-token" + ) + assert context_data["method"] == "POST" # + + +@pytest.mark.anyio +async def test_fast_mcp_request_context_isolation( + context_aware_server: None, server_url: str +) -> None: + """Test that request contexts are isolated between different FastMCP clients.""" + contexts = [] + + # Create multiple clients with different headers + for i in range(3): + headers = { + "Authorization": f"Bearer token-{i}", + "X-Request-Id": f"fastmcp-req-{i}", + "X-Custom-Value": f"value-{i}", + } + + async with sse_client(server_url + "/sse", headers=headers) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + + # Call the tool that returns context + tool_result = await session.call_tool( + "echo_context", {"custom_request_id": f"test-req-{i}"} + ) + + # Parse and store the result + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + context_data = json.loads(tool_result.content[0].text) + contexts.append(context_data) + + # Verify each request had its own isolated context + assert len(contexts) == 3 + for i, ctx in enumerate(contexts): + assert ctx["custom_request_id"] == f"test-req-{i}" + assert ctx["headers"].get("authorization") == f"Bearer token-{i}" + assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}" + assert ctx["headers"].get("x-custom-value") == f"value-{i}" + + @pytest.mark.anyio async def test_fastmcp_streamable_http( streamable_http_server: None, http_server_url: str diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py new file mode 100644 index 0000000000..a260f32103 --- /dev/null +++ b/tests/server/test_sse_security.py @@ -0,0 +1,342 @@ +"""Tests for SSE server DNS rebinding protection.""" + +import logging +import multiprocessing +import socket +import time + +import httpx +import pytest +import uvicorn +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route + +from mcp.server import Server +from mcp.server.sse import SseServerTransport +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import Tool + +logger = logging.getLogger(__name__) +SERVER_NAME = "test_sse_security_server" + + +@pytest.fixture +def server_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://127.0.0.1:{server_port}" + + +class SecurityTestServer(Server): + def __init__(self): + super().__init__(SERVER_NAME) + + async def on_list_tools(self) -> list[Tool]: + return [] + + +def run_server_with_settings( + port: int, security_settings: TransportSecuritySettings | None = None +): + """Run the SSE server with specified security settings.""" + app = SecurityTestServer() + sse_transport = SseServerTransport("/messages/", security_settings) + + async def handle_sse(request: Request): + try: + async with sse_transport.connect_sse( + request.scope, request.receive, request._send + ) as streams: + if streams: + await app.run( + streams[0], streams[1], app.create_initialization_options() + ) + except ValueError as e: + # Validation error was already handled inside connect_sse + logger.debug(f"SSE connection failed validation: {e}") + return Response() + + routes = [ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse_transport.handle_post_message), + ] + + starlette_app = Starlette(routes=routes) + uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") + + +def start_server_process( + port: int, security_settings: TransportSecuritySettings | None = None +): + """Start server in a separate process.""" + process = multiprocessing.Process( + target=run_server_with_settings, args=(port, security_settings) + ) + process.start() + # Give server time to start + time.sleep(1) + return process + + +@pytest.mark.anyio +async def test_sse_security_default_settings(server_port: int): + """Test SSE with default security settings (protection disabled).""" + process = start_server_process(server_port) + + try: + headers = {"Host": "evil.com", "Origin": "http://evil.com"} + + async with httpx.AsyncClient(timeout=5.0) as client: + async with client.stream( + "GET", + f"http://127.0.0.1:{server_port}/sse", + headers=headers + ) as response: + assert response.status_code == 200 + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_invalid_host_header(server_port: int): + """Test SSE with invalid Host header.""" + # Enable security by providing settings with an empty allowed_hosts list + security_settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["example.com"] + ) + process = start_server_process(server_port, security_settings) + + try: + # Test with invalid host header + headers = {"Host": "evil.com"} + + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{server_port}/sse", + headers=headers + ) + assert response.status_code == 400 + assert response.text == "Invalid Host header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_invalid_origin_header(server_port: int): + """Test SSE with invalid Origin header.""" + # Configure security to allow the host but restrict origins + security_settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["127.0.0.1:*"], + allowed_origins=["http://localhost:*"] + ) + process = start_server_process(server_port, security_settings) + + try: + # Test with invalid origin header + headers = {"Origin": "http://evil.com"} + + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{server_port}/sse", + headers=headers + ) + assert response.status_code == 400 + assert response.text == "Invalid Origin header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_post_invalid_content_type(server_port: int): + """Test POST endpoint with invalid Content-Type header.""" + # Configure security to allow the host + security_settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["127.0.0.1:*"], + allowed_origins=["http://127.0.0.1:*"] + ) + process = start_server_process(server_port, security_settings) + + try: + async with httpx.AsyncClient(timeout=5.0) as client: + # Test POST with invalid content type + fake_session_id = "12345678123456781234567812345678" + response = await client.post( + f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + headers={"Content-Type": "text/plain"}, + content="test" + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + # Test POST with missing content type + response = await client.post( + f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + content="test" + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_disabled(server_port: int): + """Test SSE with security disabled.""" + settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) + process = start_server_process(server_port, settings) + + try: + # Test with invalid host header - should still work + headers = {"Host": "evil.com"} + + async with httpx.AsyncClient(timeout=5.0) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream( + "GET", + f"http://127.0.0.1:{server_port}/sse", + headers=headers + ) as response: + # Should connect successfully even with invalid host + assert response.status_code == 200 + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_custom_allowed_hosts(server_port: int): + """Test SSE with custom allowed hosts.""" + settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["localhost", "127.0.0.1", "custom.host"], + allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"] + ) + process = start_server_process(server_port, settings) + + try: + # Test with custom allowed host + headers = {"Host": "custom.host"} + + async with httpx.AsyncClient(timeout=5.0) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream( + "GET", + f"http://127.0.0.1:{server_port}/sse", + headers=headers + ) as response: + # Should connect successfully with custom host + assert response.status_code == 200 + + # Test with non-allowed host + headers = {"Host": "evil.com"} + + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{server_port}/sse", + headers=headers + ) + assert response.status_code == 400 + assert response.text == "Invalid Host header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_wildcard_ports(server_port: int): + """Test SSE with wildcard port patterns.""" + settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["localhost:*", "127.0.0.1:*"], + allowed_origins=["http://localhost:*", "http://127.0.0.1:*"] + ) + process = start_server_process(server_port, settings) + + try: + # Test with various port numbers + for test_port in [8080, 3000, 9999]: + headers = {"Host": f"localhost:{test_port}"} + + async with httpx.AsyncClient(timeout=5.0) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream( + "GET", + f"http://127.0.0.1:{server_port}/sse", + headers=headers + ) as response: + # Should connect successfully with any port + assert response.status_code == 200 + + headers = {"Origin": f"http://localhost:{test_port}"} + + async with httpx.AsyncClient(timeout=5.0) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream( + "GET", + f"http://127.0.0.1:{server_port}/sse", + headers=headers + ) as response: + # Should connect successfully with any port + assert response.status_code == 200 + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_post_valid_content_type(server_port: int): + """Test POST endpoint with valid Content-Type headers.""" + # Configure security to allow the host + security_settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["127.0.0.1:*"], + allowed_origins=["http://127.0.0.1:*"] + ) + process = start_server_process(server_port, security_settings) + + try: + async with httpx.AsyncClient() as client: + # Test with various valid content types + valid_content_types = [ + "application/json", + "application/json; charset=utf-8", + "application/json;charset=utf-8", + "APPLICATION/JSON", # Case insensitive + ] + + for content_type in valid_content_types: + # Use a valid UUID format (even though session won't exist) + fake_session_id = "12345678123456781234567812345678" + response = await client.post( + f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + headers={"Content-Type": content_type}, + json={"test": "data"} + ) + # Will get 404 because session doesn't exist, but that's OK + # We're testing that it passes the content-type check + assert response.status_code == 404 + assert response.text == "Could not find session" + + finally: + process.terminate() + process.join() \ No newline at end of file diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py new file mode 100644 index 0000000000..3330bcfe27 --- /dev/null +++ b/tests/server/test_streamable_http_security.py @@ -0,0 +1,333 @@ +"""Tests for StreamableHTTP server DNS rebinding protection.""" + +import logging +import multiprocessing +import socket +import time +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +import httpx +import pytest +import uvicorn +from starlette.applications import Starlette +from starlette.routing import Mount +from starlette.types import Receive, Scope, Send + +from mcp.server import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import Tool + +logger = logging.getLogger(__name__) +SERVER_NAME = "test_streamable_http_security_server" + + +@pytest.fixture +def server_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://127.0.0.1:{server_port}" + + +class SecurityTestServer(Server): + def __init__(self): + super().__init__(SERVER_NAME) + + async def on_list_tools(self) -> list[Tool]: + return [] + + +def run_server_with_settings( + port: int, security_settings: TransportSecuritySettings | None = None +): + """Run the StreamableHTTP server with specified security settings.""" + app = SecurityTestServer() + + # Create session manager with security settings + session_manager = StreamableHTTPSessionManager( + app=app, + json_response=False, + stateless=False, + security_settings=security_settings, + ) + + # Create the ASGI handler + async def handle_streamable_http( + scope: Scope, receive: Receive, send: Send + ) -> None: + await session_manager.handle_request(scope, receive, send) + + # Create Starlette app with lifespan + @asynccontextmanager + async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: + async with session_manager.run(): + yield + + routes = [ + Mount("/", app=handle_streamable_http), + ] + + starlette_app = Starlette(routes=routes, lifespan=lifespan) + uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") + + +def start_server_process( + port: int, security_settings: TransportSecuritySettings | None = None +): + """Start server in a separate process.""" + process = multiprocessing.Process( + target=run_server_with_settings, args=(port, security_settings) + ) + process.start() + # Give server time to start + time.sleep(1) + return process + + +@pytest.mark.anyio +async def test_streamable_http_security_default_settings(server_port: int): + """Test StreamableHTTP with default security settings (protection enabled).""" + process = start_server_process(server_port) + + try: + # Test with valid localhost headers + async with httpx.AsyncClient(timeout=5.0) as client: + # POST request to initialize session + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + ) + assert response.status_code == 200 + assert "mcp-session-id" in response.headers + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_invalid_host_header(server_port: int): + """Test StreamableHTTP with invalid Host header.""" + security_settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True + ) + process = start_server_process(server_port, security_settings) + + try: + # Test with invalid host header + headers = { + "Host": "evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers=headers + ) + assert response.status_code == 400 + assert response.text == "Invalid Host header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_invalid_origin_header(server_port: int): + """Test StreamableHTTP with invalid Origin header.""" + security_settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["127.0.0.1:*"] + ) + process = start_server_process(server_port, security_settings) + + try: + # Test with invalid origin header + headers = { + "Origin": "http://evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers=headers + ) + assert response.status_code == 400 + assert response.text == "Invalid Origin header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_invalid_content_type(server_port: int): + """Test StreamableHTTP POST with invalid Content-Type header.""" + process = start_server_process(server_port) + + try: + async with httpx.AsyncClient(timeout=5.0) as client: + # Test POST with invalid content type + response = await client.post( + f"http://127.0.0.1:{server_port}/", + headers={ + "Content-Type": "text/plain", + "Accept": "application/json, text/event-stream", + }, + content="test" + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + # Test POST with missing content type + response = await client.post( + f"http://127.0.0.1:{server_port}/", + headers={"Accept": "application/json, text/event-stream"}, + content="test" + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_disabled(server_port: int): + """Test StreamableHTTP with security disabled.""" + settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) + process = start_server_process(server_port, settings) + + try: + # Test with invalid host header - should still work + headers = { + "Host": "evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers=headers + ) + # Should connect successfully even with invalid host + assert response.status_code == 200 + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_custom_allowed_hosts(server_port: int): + """Test StreamableHTTP with custom allowed hosts.""" + settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["localhost", "127.0.0.1", "custom.host"], + allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"] + ) + process = start_server_process(server_port, settings) + + try: + # Test with custom allowed host + headers = { + "Host": "custom.host", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers=headers + ) + # Should connect successfully with custom host + assert response.status_code == 200 + + # Test with non-allowed host + headers = { + "Host": "evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers=headers + ) + assert response.status_code == 400, response.text + assert response.text == "Invalid Host header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_get_request(server_port: int): + """Test StreamableHTTP GET request with security.""" + security_settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["127.0.0.1"] + ) + process = start_server_process(server_port, security_settings) + + try: + # Test GET request with invalid host header + headers = { + "Host": "evil.com", + "Accept": "text/event-stream", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.get( + f"http://127.0.0.1:{server_port}/", + headers=headers + ) + assert response.status_code == 400 + assert response.text == "Invalid Host header" + + # Test GET request with valid host header + headers = { + "Host": "127.0.0.1", + "Accept": "text/event-stream", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + # GET requests need a session ID in StreamableHTTP + # So it will fail with "Missing session ID" not security error + response = await client.get( + f"http://127.0.0.1:{server_port}/", + headers=headers + ) + # This should pass security but fail on session validation + assert response.status_code == 400 + body = response.json() + assert "Missing session ID" in body["error"]["message"] + + finally: + process.terminate() + process.join() + diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index e55983e01f..0cadb87d6b 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -81,7 +81,13 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: # Test fixtures def make_server_app() -> Starlette: """Create test Starlette app with SSE transport""" - sse = SseServerTransport("/messages/") + from mcp.server.transport_security import TransportSecuritySettings + # Configure security with allowed hosts/origins for testing + security_settings = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + sse = SseServerTransport("/messages/", security_settings=security_settings) server = ServerTest() async def handle_sse(request: Request) -> Response: @@ -318,3 +324,193 @@ async def test_sse_client_basic_connection_mounted_app( # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) + + +# Test server with request context that returns headers in the response +class RequestContextServer(Server[object, Request]): + def __init__(self): + super().__init__("request_context_server") + + @self.call_tool() + async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + headers_info = {} + context = self.request_context + if context.request: + headers_info = dict(context.request.headers) + + if name == "echo_headers": + return [TextContent(type="text", text=json.dumps(headers_info))] + elif name == "echo_context": + context_data = { + "request_id": args.get("request_id"), + "headers": headers_info, + } + return [TextContent(type="text", text=json.dumps(context_data))] + + return [TextContent(type="text", text=f"Called {name}")] + + @self.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="echo_headers", + description="Echoes request headers", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echoes request context", + inputSchema={ + "type": "object", + "properties": {"request_id": {"type": "string"}}, + "required": ["request_id"], + }, + ), + ] + + +def run_context_server(server_port: int) -> None: + """Run a server that captures request context""" + from mcp.server.transport_security import TransportSecuritySettings + # Configure security with allowed hosts/origins for testing + security_settings = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + sse = SseServerTransport("/messages/", security_settings=security_settings) + context_server = RequestContextServer() + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await context_server.run( + streams[0], streams[1], context_server.create_initialization_options() + ) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ] + ) + + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"starting context server on {server_port}") + server.run() + + +@pytest.fixture() +def context_server(server_port: int) -> Generator[None, None, None]: + """Fixture that provides a server with request context capture""" + proc = multiprocessing.Process( + target=run_context_server, kwargs={"server_port": server_port}, daemon=True + ) + print("starting context server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("waiting for context server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Context server failed to start after {max_attempts} attempts" + ) + + yield + + print("killing context server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("context server process failed to terminate") + + +@pytest.mark.anyio +async def test_request_context_propagation( + context_server: None, server_url: str +) -> None: + """Test that request context is properly propagated through SSE transport.""" + # Test with custom headers + custom_headers = { + "Authorization": "Bearer test-token", + "X-Custom-Header": "test-value", + "X-Trace-Id": "trace-123", + } + + async with sse_client(server_url + "/sse", headers=custom_headers) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the tool that echoes headers back + tool_result = await session.call_tool("echo_headers", {}) + + # Parse the JSON response + + assert len(tool_result.content) == 1 + headers_data = json.loads( + tool_result.content[0].text + if tool_result.content[0].type == "text" + else "{}" + ) + + # Verify headers were propagated + assert headers_data.get("authorization") == "Bearer test-token" + assert headers_data.get("x-custom-header") == "test-value" + assert headers_data.get("x-trace-id") == "trace-123" + + +@pytest.mark.anyio +async def test_request_context_isolation(context_server: None, server_url: str) -> None: + """Test that request contexts are isolated between different SSE clients.""" + contexts = [] + + # Create multiple clients with different headers + for i in range(3): + headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} + + async with sse_client(server_url + "/sse", headers=headers) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Call the tool that echoes context + tool_result = await session.call_tool( + "echo_context", {"request_id": f"request-{i}"} + ) + + assert len(tool_result.content) == 1 + context_data = json.loads( + tool_result.content[0].text + if tool_result.content[0].type == "text" + else "{}" + ) + contexts.append(context_data) + + # Verify each request had its own context + assert len(contexts) == 3 + for i, ctx in enumerate(contexts): + assert ctx["request_id"] == f"request-{i}" + assert ctx["headers"].get("x-request-id") == f"request-{i}" + assert ctx["headers"].get("x-custom-value") == f"value-{i}" diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f1c7ef8097..6292e79739 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -225,10 +225,16 @@ def create_app( server = ServerTest() # Create the session manager + from mcp.server.transport_security import TransportSecuritySettings + security_settings = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) session_manager = StreamableHTTPSessionManager( app=server, event_store=event_store, json_response=is_json_response_enabled, + security_settings=security_settings, ) # Create an ASGI application that uses the session manager @@ -438,8 +444,12 @@ def test_content_type_validation(basic_server, basic_server_url): }, data="This is not JSON", ) - assert response.status_code == 415 - assert "Unsupported Media Type" in response.text + # May return 400 (security middleware) or 415 (transport validation) + assert response.status_code in (400, 415) + assert any( + msg in response.text + for msg in ["Invalid Content-Type", "Unsupported Media Type"] + ) def test_json_validation(basic_server, basic_server_url): From 29a5e3ac996b526c5194319f2c748ee2586bad18 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Fri, 30 May 2025 09:39:47 -0700 Subject: [PATCH 2/7] Update tests --- tests/server/fastmcp/test_integration.py | 7 ++----- tests/server/test_streamable_http_security.py | 17 ----------------- 2 files changed, 2 insertions(+), 22 deletions(-) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index a13df65de2..14313bacce 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -22,7 +22,7 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client -from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.resources import FunctionResource from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import RequestContext @@ -101,10 +101,7 @@ def echo(message: str) -> str: def make_everything_fastmcp() -> FastMCP: - """Create a FastMCP server with all features enabled for testing.""" - from mcp.server.fastmcp import Context - from mcp.server.transport_security import TransportSecuritySettings - + """Create a FastMCP server with all features enabled for testing.""" transport_security = TransportSecuritySettings( allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 3330bcfe27..a35053d2c2 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -263,23 +263,6 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int): ) # Should connect successfully with custom host assert response.status_code == 200 - - # Test with non-allowed host - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers - ) - assert response.status_code == 400, response.text - assert response.text == "Invalid Host header" - finally: process.terminate() process.join() From aeab63120efc943aadd7cf3062c1cc7456b65d68 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Fri, 30 May 2025 09:47:39 -0700 Subject: [PATCH 3/7] Clean up --- src/mcp/server/transport_security.py | 10 +- tests/server/fastmcp/test_integration.py | 173 ----------------------- tests/shared/test_sse.py | 1 - tests/shared/test_streamable_http.py | 11 +- 4 files changed, 8 insertions(+), 187 deletions(-) diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 63cacbc9e7..ee17291653 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -23,16 +23,14 @@ class TransportSecuritySettings(BaseModel): allowed_hosts: list[str] = Field( default=[], - description="List of allowed Host header values. If None, all hosts " - "are allowed when protection is disabled, or only localhost/127.0.0.1 " - "when enabled." + description="List of allowed Host header values. Only applies when " + + "enable_dns_rebinding_protection is True." ) allowed_origins: list[str] = Field( default=[], - description="List of allowed Origin header values. If None, all " - "origins are allowed when protection is disabled, or only localhost " - "origins when enabled." + description="List of allowed Origin header values. Only applies when " + + "enable_dns_rebinding_protection is True." ) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 14313bacce..54378f4cad 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -475,179 +475,6 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None: assert tool_result.content[0].text == "Echo: hello" -def make_fastmcp_with_context_app(): - """Create a FastMCP server that can access request context.""" - from mcp.server.transport_security import TransportSecuritySettings - - transport_security = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - mcp = FastMCP(name="ContextServer", transport_security=transport_security) - - # Tool that echoes request headers - @mcp.tool(description="Echo request headers from context") - def echo_headers(ctx: Context[Any, Any, Request]) -> str: - """Returns the request headers as JSON.""" - headers_info = {} - if ctx.request_context.request: - # Now the type system knows request is a Starlette Request object - headers_info = dict(ctx.request_context.request.headers) - return json.dumps(headers_info) - - # Tool that returns full request context - @mcp.tool(description="Echo request context with custom data") - def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str: - """Returns request context including headers and custom data.""" - context_data = { - "custom_request_id": custom_request_id, - "headers": {}, - "method": None, - "path": None, - } - if ctx.request_context.request: - request = ctx.request_context.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return json.dumps(context_data) - - # Create the SSE app - app = mcp.sse_app() - return mcp, app - - -def run_context_server(server_port: int) -> None: - """Run the context-aware FastMCP server.""" - _, app = make_fastmcp_with_context_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) - print(f"Starting context server on port {server_port}") - server.run() - - -@pytest.fixture() -def context_aware_server(server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process( - target=run_context_server, args=(server_port,), daemon=True - ) - print("Starting context-aware server process") - proc.start() - - # Wait for server to be running - max_attempts = 20 - attempt = 0 - print("Waiting for context-aware server to start") - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError( - f"Context server failed to start after {max_attempts} attempts" - ) - - yield - - print("Killing context-aware server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("Context server process failed to terminate") - - -@pytest.mark.anyio -async def test_fast_mcp_with_request_context( - context_aware_server: None, server_url: str -) -> None: - """Test that FastMCP properly propagates request context to tools.""" - # Test with custom headers - custom_headers = { - "Authorization": "Bearer fastmcp-test-token", - "X-Custom-Header": "fastmcp-value", - "X-Request-Id": "req-123", - } - - async with sse_client(server_url + "/sse", headers=custom_headers) as streams: - async with ClientSession(*streams) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "ContextServer" - - # Test 1: Call tool that echoes headers - headers_result = await session.call_tool("echo_headers", {}) - assert len(headers_result.content) == 1 - assert isinstance(headers_result.content[0], TextContent) - - headers_data = json.loads(headers_result.content[0].text) - assert headers_data.get("authorization") == "Bearer fastmcp-test-token" - assert headers_data.get("x-custom-header") == "fastmcp-value" - assert headers_data.get("x-request-id") == "req-123" - - # Test 2: Call tool that returns full context - context_result = await session.call_tool( - "echo_context", {"custom_request_id": "test-123"} - ) - assert len(context_result.content) == 1 - assert isinstance(context_result.content[0], TextContent) - - context_data = json.loads(context_result.content[0].text) - assert context_data["custom_request_id"] == "test-123" - assert ( - context_data["headers"].get("authorization") - == "Bearer fastmcp-test-token" - ) - assert context_data["method"] == "POST" # - - -@pytest.mark.anyio -async def test_fast_mcp_request_context_isolation( - context_aware_server: None, server_url: str -) -> None: - """Test that request contexts are isolated between different FastMCP clients.""" - contexts = [] - - # Create multiple clients with different headers - for i in range(3): - headers = { - "Authorization": f"Bearer token-{i}", - "X-Request-Id": f"fastmcp-req-{i}", - "X-Custom-Value": f"value-{i}", - } - - async with sse_client(server_url + "/sse", headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - - # Call the tool that returns context - tool_result = await session.call_tool( - "echo_context", {"custom_request_id": f"test-req-{i}"} - ) - - # Parse and store the result - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - context_data = json.loads(tool_result.content[0].text) - contexts.append(context_data) - - # Verify each request had its own isolated context - assert len(contexts) == 3 - for i, ctx in enumerate(contexts): - assert ctx["custom_request_id"] == f"test-req-{i}" - assert ctx["headers"].get("authorization") == f"Bearer token-{i}" - assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}" - assert ctx["headers"].get("x-custom-value") == f"value-{i}" - - @pytest.mark.anyio async def test_fastmcp_streamable_http( streamable_http_server: None, http_server_url: str diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5f253f7196..8e527d38f2 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -83,7 +83,6 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: # Test fixtures def make_server_app() -> Starlette: """Create test Starlette app with SSE transport""" - from mcp.server.transport_security import TransportSecuritySettings # Configure security with allowed hosts/origins for testing security_settings = TransportSecuritySettings( allowed_hosts=["127.0.0.1:*", "localhost:*"], diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3ac261d30d..8966f737e6 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -36,6 +36,7 @@ StreamId, ) from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ( @@ -227,7 +228,6 @@ def create_app( server = ServerTest() # Create the session manager - from mcp.server.transport_security import TransportSecuritySettings security_settings = TransportSecuritySettings( allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] @@ -446,12 +446,9 @@ def test_content_type_validation(basic_server, basic_server_url): }, data="This is not JSON", ) - # May return 400 (security middleware) or 415 (transport validation) - assert response.status_code in (400, 415) - assert any( - msg in response.text - for msg in ["Invalid Content-Type", "Unsupported Media Type"] - ) + + assert response.status_code == 400 + assert "Invalid Content-Type" in response.text def test_json_validation(basic_server, basic_server_url): From fb3ce680ee1eaff84f06850f052d73a2d5767d06 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Fri, 30 May 2025 09:50:45 -0700 Subject: [PATCH 4/7] Rerun tests From b2bbcd105ec92459bd5ca2accb53aa2962b69d61 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Wed, 11 Jun 2025 09:00:36 -0700 Subject: [PATCH 5/7] Move gate to validate_request to avoid calling functions unnecessarily --- src/mcp/server/transport_security.py | 72 ++++++++++++++-------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index ee17291653..5cc5640fb3 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -15,48 +15,45 @@ class TransportSecuritySettings(BaseModel): These settings help protect against DNS rebinding attacks by validating incoming request headers. """ - + enable_dns_rebinding_protection: bool = Field( default=True, - description="Enable DNS rebinding protection (recommended for production)" + description="Enable DNS rebinding protection (recommended for production)", ) - + allowed_hosts: list[str] = Field( default=[], - description="List of allowed Host header values. Only applies when " + - "enable_dns_rebinding_protection is True." + description="List of allowed Host header values. Only applies when " + + "enable_dns_rebinding_protection is True.", ) - + allowed_origins: list[str] = Field( default=[], - description="List of allowed Origin header values. Only applies when " + - "enable_dns_rebinding_protection is True." + description="List of allowed Origin header values. Only applies when " + + "enable_dns_rebinding_protection is True.", ) class TransportSecurityMiddleware: """Middleware to enforce DNS rebinding protection for MCP transport endpoints.""" - + def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default # for backwards compatibility self.settings = settings or TransportSecuritySettings( enable_dns_rebinding_protection=False ) - + def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" - if not self.settings.enable_dns_rebinding_protection: - return True - if not host: logger.warning("Missing Host header in request") return False - + # Check exact match first if host in self.settings.allowed_hosts: return True - + # Check wildcard port patterns for allowed in self.settings.allowed_hosts: if allowed.endswith(":*"): @@ -65,23 +62,20 @@ def _validate_host(self, host: str | None) -> bool: # Check if the actual host starts with base host and has a port if host.startswith(base_host + ":"): return True - + logger.warning(f"Invalid Host header: {host}") return False - + def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" - if not self.settings.enable_dns_rebinding_protection: - return True - # Origin can be absent for same-origin requests if not origin: return True - + # Check exact match first if origin in self.settings.allowed_origins: return True - + # Check wildcard port patterns for allowed in self.settings.allowed_origins: if allowed.endswith(":*"): @@ -90,44 +84,48 @@ def _validate_origin(self, origin: str | None) -> bool: # Check if the actual origin starts with base origin and has a port if origin.startswith(base_origin + ":"): return True - + logger.warning(f"Invalid Origin header: {origin}") return False - + def _validate_content_type(self, content_type: str | None) -> bool: """Validate the Content-Type header for POST requests.""" if not content_type: logger.warning("Missing Content-Type header in POST request") return False - + # Content-Type must start with application/json if not content_type.lower().startswith("application/json"): logger.warning(f"Invalid Content-Type header: {content_type}") return False - + return True - + async def validate_request( self, request: Request, is_post: bool = False ) -> Response | None: """Validate request headers for DNS rebinding protection. - + Returns None if validation passes, or an error Response if validation fails. """ + # Always validate Content-Type for POST requests + if is_post: + content_type = request.headers.get("content-type") + if not self._validate_content_type(content_type): + return Response("Invalid Content-Type header", status_code=400) + + # Skip remaining validation if DNS rebinding protection is disabled + if not self.settings.enable_dns_rebinding_protection: + return None + # Validate Host header host = request.headers.get("host") if not self._validate_host(host): return Response("Invalid Host header", status_code=400) - + # Validate Origin header origin = request.headers.get("origin") if not self._validate_origin(origin): return Response("Invalid Origin header", status_code=400) - - # Validate Content-Type for POST requests - if is_post: - content_type = request.headers.get("content-type") - if not self._validate_content_type(content_type): - return Response("Invalid Content-Type header", status_code=400) - - return None \ No newline at end of file + + return None From f349d6fbff7b7469fdf3039672694c9327bb72e1 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Wed, 11 Jun 2025 09:35:07 -0700 Subject: [PATCH 6/7] Fix formatting --- src/mcp/server/fastmcp/server.py | 2 +- src/mcp/server/sse.py | 8 +- src/mcp/server/streamable_http.py | 4 +- src/mcp/server/transport_security.py | 10 +- tests/server/fastmcp/test_integration.py | 24 +-- tests/server/test_sse_security.py | 157 ++++++------------ tests/server/test_streamable_http_security.py | 111 +++++-------- tests/shared/test_sse.py | 6 +- tests/shared/test_streamable_http.py | 3 +- 9 files changed, 118 insertions(+), 207 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 550205fee7..a6edffe46c 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -119,7 +119,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): ) auth: AuthSettings | None = None - + # Transport security settings (DNS rebinding protection) transport_security: TransportSecuritySettings | None = None diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 55361dd8b1..41145e49f6 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -77,13 +77,11 @@ class SseServerTransport: _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] _security: TransportSecurityMiddleware - def __init__( - self, endpoint: str, security_settings: TransportSecuritySettings | None = None - ) -> None: + def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: """ Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given. - + Args: endpoint: The relative or absolute URL for POST messages. security_settings: Optional security settings for DNS rebinding protection. @@ -178,7 +176,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") request = Request(scope, receive) - + # Validate request headers for DNS rebinding protection error_response = await self._security.validate_request(request, is_post=True) if error_response: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index af8023fa09..dc5f7a986c 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -256,14 +256,14 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None: async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Application entry point that handles all HTTP requests""" request = Request(scope, receive) - + # Validate request headers for DNS rebinding protection is_post = request.method == "POST" error_response = await self._security.validate_request(request, is_post=is_post) if error_response: await error_response(scope, receive, send) return - + if self._terminated: # If the session has been terminated, return 404 Not Found response = self._create_error_response( diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 5cc5640fb3..b7eec474d6 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -11,7 +11,7 @@ class TransportSecuritySettings(BaseModel): """Settings for MCP transport security features. - + These settings help protect against DNS rebinding attacks by validating incoming request headers. """ @@ -40,9 +40,7 @@ class TransportSecurityMiddleware: def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default # for backwards compatibility - self.settings = settings or TransportSecuritySettings( - enable_dns_rebinding_protection=False - ) + self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" @@ -101,9 +99,7 @@ def _validate_content_type(self, content_type: str | None) -> bool: return True - async def validate_request( - self, request: Request, is_post: bool = False - ) -> Response | None: + async def validate_request(self, request: Request, is_post: bool = False) -> Response | None: """Validate request headers for DNS rebinding protection. Returns None if validation passes, or an error Response if validation fails. diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index dfa97c70da..7128eccf17 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -84,8 +84,7 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str: def make_fastmcp_app(): """Create a FastMCP server without auth settings.""" transport_security = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) mcp = FastMCP(name="NoAuthServer", transport_security=transport_security) @@ -101,10 +100,9 @@ def echo(message: str) -> str: def make_everything_fastmcp() -> FastMCP: - """Create a FastMCP server with all features enabled for testing.""" + """Create a FastMCP server with all features enabled for testing.""" transport_security = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) mcp = FastMCP(name="EverythingServer", transport_security=transport_security) @@ -237,10 +235,9 @@ def make_everything_fastmcp_app(): def make_fastmcp_streamable_http_app(): - """Create a FastMCP server with StreamableHTTP transport.""" + """Create a FastMCP server with StreamableHTTP transport.""" transport_security = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) mcp = FastMCP(name="NoAuthServer", transport_security=transport_security) @@ -266,16 +263,11 @@ def make_everything_fastmcp_streamable_http_app(): def make_fastmcp_stateless_http_app(): - """Create a FastMCP server with stateless StreamableHTTP transport.""" + """Create a FastMCP server with stateless StreamableHTTP transport.""" transport_security = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - mcp = FastMCP( - name="StatelessServer", - stateless_http=True, - transport_security=transport_security + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) + mcp = FastMCP(name="StatelessServer", stateless_http=True, transport_security=transport_security) # Add a simple tool @mcp.tool(description="A simple echo tool") diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index a260f32103..7db7297644 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -37,48 +37,38 @@ def server_url(server_port: int) -> str: class SecurityTestServer(Server): def __init__(self): super().__init__(SERVER_NAME) - + async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings( - port: int, security_settings: TransportSecuritySettings | None = None -): +def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): """Run the SSE server with specified security settings.""" app = SecurityTestServer() sse_transport = SseServerTransport("/messages/", security_settings) - + async def handle_sse(request: Request): try: - async with sse_transport.connect_sse( - request.scope, request.receive, request._send - ) as streams: + async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: if streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) + await app.run(streams[0], streams[1], app.create_initialization_options()) except ValueError as e: # Validation error was already handled inside connect_sse logger.debug(f"SSE connection failed validation: {e}") return Response() - + routes = [ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse_transport.handle_post_message), ] - + starlette_app = Starlette(routes=routes) uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") -def start_server_process( - port: int, security_settings: TransportSecuritySettings | None = None -): +def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): """Start server in a separate process.""" - process = multiprocessing.Process( - target=run_server_with_settings, args=(port, security_settings) - ) + process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() # Give server time to start time.sleep(1) @@ -89,16 +79,12 @@ def start_server_process( async def test_sse_security_default_settings(server_port: int): """Test SSE with default security settings (protection disabled).""" process = start_server_process(server_port) - + try: headers = {"Host": "evil.com", "Origin": "http://evil.com"} - + async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream( - "GET", - f"http://127.0.0.1:{server_port}/sse", - headers=headers - ) as response: + async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: assert response.status_code == 200 finally: process.terminate() @@ -109,24 +95,18 @@ async def test_sse_security_default_settings(server_port: int): async def test_sse_security_invalid_host_header(server_port: int): """Test SSE with invalid Host header.""" # Enable security by providing settings with an empty allowed_hosts list - security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, - allowed_hosts=["example.com"] - ) + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) process = start_server_process(server_port, security_settings) try: # Test with invalid host header headers = {"Host": "evil.com"} - + async with httpx.AsyncClient() as client: - response = await client.get( - f"http://127.0.0.1:{server_port}/sse", - headers=headers - ) + response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 400 assert response.text == "Invalid Host header" - + finally: process.terminate() process.join() @@ -137,24 +117,19 @@ async def test_sse_security_invalid_origin_header(server_port: int): """Test SSE with invalid Origin header.""" # Configure security to allow the host but restrict origins security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, - allowed_hosts=["127.0.0.1:*"], - allowed_origins=["http://localhost:*"] + enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) process = start_server_process(server_port, security_settings) - + try: # Test with invalid origin header headers = {"Origin": "http://evil.com"} - + async with httpx.AsyncClient() as client: - response = await client.get( - f"http://127.0.0.1:{server_port}/sse", - headers=headers - ) + response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 400 assert response.text == "Invalid Origin header" - + finally: process.terminate() process.join() @@ -165,12 +140,10 @@ async def test_sse_security_post_invalid_content_type(server_port: int): """Test POST endpoint with invalid Content-Type header.""" # Configure security to allow the host security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, - allowed_hosts=["127.0.0.1:*"], - allowed_origins=["http://127.0.0.1:*"] + enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) process = start_server_process(server_port, security_settings) - + try: async with httpx.AsyncClient(timeout=5.0) as client: # Test POST with invalid content type @@ -178,19 +151,18 @@ async def test_sse_security_post_invalid_content_type(server_port: int): response = await client.post( f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", headers={"Content-Type": "text/plain"}, - content="test" + content="test", ) assert response.status_code == 400 assert response.text == "Invalid Content-Type header" - + # Test POST with missing content type response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - content="test" + f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" ) assert response.status_code == 400 assert response.text == "Invalid Content-Type header" - + finally: process.terminate() process.join() @@ -201,21 +173,17 @@ async def test_sse_security_disabled(server_port: int): """Test SSE with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) process = start_server_process(server_port, settings) - + try: # Test with invalid host header - should still work headers = {"Host": "evil.com"} - + async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream( - "GET", - f"http://127.0.0.1:{server_port}/sse", - headers=headers - ) as response: + async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: # Should connect successfully even with invalid host assert response.status_code == 200 - + finally: process.terminate() process.join() @@ -227,35 +195,28 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], - allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"] + allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) process = start_server_process(server_port, settings) - + try: # Test with custom allowed host headers = {"Host": "custom.host"} - + async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream( - "GET", - f"http://127.0.0.1:{server_port}/sse", - headers=headers - ) as response: + async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: # Should connect successfully with custom host assert response.status_code == 200 - + # Test with non-allowed host headers = {"Host": "evil.com"} - + async with httpx.AsyncClient() as client: - response = await client.get( - f"http://127.0.0.1:{server_port}/sse", - headers=headers - ) + response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 400 assert response.text == "Invalid Host header" - + finally: process.terminate() process.join() @@ -267,53 +228,43 @@ async def test_sse_security_wildcard_ports(server_port: int): settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost:*", "127.0.0.1:*"], - allowed_origins=["http://localhost:*", "http://127.0.0.1:*"] + allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) process = start_server_process(server_port, settings) - + try: # Test with various port numbers for test_port in [8080, 3000, 9999]: headers = {"Host": f"localhost:{test_port}"} - + async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream( - "GET", - f"http://127.0.0.1:{server_port}/sse", - headers=headers - ) as response: + async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: # Should connect successfully with any port assert response.status_code == 200 - + headers = {"Origin": f"http://localhost:{test_port}"} - + async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream( - "GET", - f"http://127.0.0.1:{server_port}/sse", - headers=headers - ) as response: + async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: # Should connect successfully with any port assert response.status_code == 200 - + finally: process.terminate() process.join() -@pytest.mark.anyio +@pytest.mark.anyio async def test_sse_security_post_valid_content_type(server_port: int): """Test POST endpoint with valid Content-Type headers.""" # Configure security to allow the host security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, - allowed_hosts=["127.0.0.1:*"], - allowed_origins=["http://127.0.0.1:*"] + enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) process = start_server_process(server_port, security_settings) - + try: async with httpx.AsyncClient() as client: # Test with various valid content types @@ -323,20 +274,20 @@ async def test_sse_security_post_valid_content_type(server_port: int): "application/json;charset=utf-8", "APPLICATION/JSON", # Case insensitive ] - + for content_type in valid_content_types: # Use a valid UUID format (even though session won't exist) fake_session_id = "12345678123456781234567812345678" response = await client.post( f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", headers={"Content-Type": content_type}, - json={"test": "data"} + json={"test": "data"}, ) # Will get 404 because session doesn't exist, but that's OK # We're testing that it passes the content-type check assert response.status_code == 404 assert response.text == "Could not find session" - + finally: process.terminate() - process.join() \ No newline at end of file + process.join() diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index a35053d2c2..b5dc81c4aa 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -38,17 +38,15 @@ def server_url(server_port: int) -> str: class SecurityTestServer(Server): def __init__(self): super().__init__(SERVER_NAME) - + async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings( - port: int, security_settings: TransportSecuritySettings | None = None -): +def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): """Run the StreamableHTTP server with specified security settings.""" app = SecurityTestServer() - + # Create session manager with security settings session_manager = StreamableHTTPSessionManager( app=app, @@ -56,34 +54,28 @@ def run_server_with_settings( stateless=False, security_settings=security_settings, ) - + # Create the ASGI handler - async def handle_streamable_http( - scope: Scope, receive: Receive, send: Send - ) -> None: + async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: await session_manager.handle_request(scope, receive, send) - + # Create Starlette app with lifespan @asynccontextmanager async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: async with session_manager.run(): yield - + routes = [ Mount("/", app=handle_streamable_http), ] - + starlette_app = Starlette(routes=routes, lifespan=lifespan) uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") -def start_server_process( - port: int, security_settings: TransportSecuritySettings | None = None -): +def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): """Start server in a separate process.""" - process = multiprocessing.Process( - target=run_server_with_settings, args=(port, security_settings) - ) + process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() # Give server time to start time.sleep(1) @@ -94,7 +86,7 @@ def start_server_process( async def test_streamable_http_security_default_settings(server_port: int): """Test StreamableHTTP with default security settings (protection enabled).""" process = start_server_process(server_port) - + try: # Test with valid localhost headers async with httpx.AsyncClient(timeout=5.0) as client: @@ -105,11 +97,11 @@ async def test_streamable_http_security_default_settings(server_port: int): headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", - } + }, ) assert response.status_code == 200 assert "mcp-session-id" in response.headers - + finally: process.terminate() process.join() @@ -118,11 +110,9 @@ async def test_streamable_http_security_default_settings(server_port: int): @pytest.mark.anyio async def test_streamable_http_security_invalid_host_header(server_port: int): """Test StreamableHTTP with invalid Host header.""" - security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True - ) + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) process = start_server_process(server_port, security_settings) - + try: # Test with invalid host header headers = { @@ -130,16 +120,16 @@ async def test_streamable_http_security_invalid_host_header(server_port: int): "Accept": "application/json, text/event-stream", "Content-Type": "application/json", } - + async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( f"http://127.0.0.1:{server_port}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers + headers=headers, ) assert response.status_code == 400 assert response.text == "Invalid Host header" - + finally: process.terminate() process.join() @@ -148,12 +138,9 @@ async def test_streamable_http_security_invalid_host_header(server_port: int): @pytest.mark.anyio async def test_streamable_http_security_invalid_origin_header(server_port: int): """Test StreamableHTTP with invalid Origin header.""" - security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, - allowed_hosts=["127.0.0.1:*"] - ) + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) process = start_server_process(server_port, security_settings) - + try: # Test with invalid origin header headers = { @@ -161,16 +148,16 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int): "Accept": "application/json, text/event-stream", "Content-Type": "application/json", } - + async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( f"http://127.0.0.1:{server_port}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers + headers=headers, ) assert response.status_code == 400 assert response.text == "Invalid Origin header" - + finally: process.terminate() process.join() @@ -180,7 +167,7 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int): async def test_streamable_http_security_invalid_content_type(server_port: int): """Test StreamableHTTP POST with invalid Content-Type header.""" process = start_server_process(server_port) - + try: async with httpx.AsyncClient(timeout=5.0) as client: # Test POST with invalid content type @@ -190,20 +177,20 @@ async def test_streamable_http_security_invalid_content_type(server_port: int): "Content-Type": "text/plain", "Accept": "application/json, text/event-stream", }, - content="test" + content="test", ) assert response.status_code == 400 assert response.text == "Invalid Content-Type header" - + # Test POST with missing content type response = await client.post( f"http://127.0.0.1:{server_port}/", headers={"Accept": "application/json, text/event-stream"}, - content="test" + content="test", ) assert response.status_code == 400 assert response.text == "Invalid Content-Type header" - + finally: process.terminate() process.join() @@ -214,7 +201,7 @@ async def test_streamable_http_security_disabled(server_port: int): """Test StreamableHTTP with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) process = start_server_process(server_port, settings) - + try: # Test with invalid host header - should still work headers = { @@ -222,16 +209,16 @@ async def test_streamable_http_security_disabled(server_port: int): "Accept": "application/json, text/event-stream", "Content-Type": "application/json", } - + async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( f"http://127.0.0.1:{server_port}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers + headers=headers, ) # Should connect successfully even with invalid host assert response.status_code == 200 - + finally: process.terminate() process.join() @@ -243,10 +230,10 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int): settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], - allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"] + allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) process = start_server_process(server_port, settings) - + try: # Test with custom allowed host headers = { @@ -254,12 +241,12 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int): "Accept": "application/json, text/event-stream", "Content-Type": "application/json", } - + async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( f"http://127.0.0.1:{server_port}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers + headers=headers, ) # Should connect successfully with custom host assert response.status_code == 200 @@ -271,46 +258,36 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int): @pytest.mark.anyio async def test_streamable_http_security_get_request(server_port: int): """Test StreamableHTTP GET request with security.""" - security_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=True, - allowed_hosts=["127.0.0.1"] - ) + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) process = start_server_process(server_port, security_settings) - + try: # Test GET request with invalid host header headers = { "Host": "evil.com", "Accept": "text/event-stream", } - + async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get( - f"http://127.0.0.1:{server_port}/", - headers=headers - ) + response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) assert response.status_code == 400 assert response.text == "Invalid Host header" - + # Test GET request with valid host header headers = { "Host": "127.0.0.1", "Accept": "text/event-stream", } - + async with httpx.AsyncClient(timeout=5.0) as client: # GET requests need a session ID in StreamableHTTP # So it will fail with "Missing session ID" not security error - response = await client.get( - f"http://127.0.0.1:{server_port}/", - headers=headers - ) + response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) # This should pass security but fail on session validation assert response.status_code == 400 body = response.json() assert "Missing session ID" in body["error"]["message"] - + finally: process.terminate() process.join() - diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index d2d81a4e1f..8e1912e9bd 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -83,8 +83,7 @@ def make_server_app() -> Starlette: """Create test Starlette app with SSE transport""" # Configure security with allowed hosts/origins for testing security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) sse = SseServerTransport("/messages/", security_settings=security_settings) server = ServerTest() @@ -347,8 +346,7 @@ def run_context_server(server_port: int) -> None: """Run a server that captures request context""" # Configure security with allowed hosts/origins for testing security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) sse = SseServerTransport("/messages/", security_settings=security_settings) context_server = RequestContextServer() diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index c306627955..0f2296fab9 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -217,8 +217,7 @@ def create_app(is_json_response_enabled=False, event_store: EventStore | None = # Create the session manager security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) session_manager = StreamableHTTPSessionManager( app=server, From 64ddbe21f74ebda19ed783d9ae637923f53cd1f8 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Fri, 13 Jun 2025 14:37:52 -0700 Subject: [PATCH 7/7] Use HTTP 421 for invalid Host headers in DNS rebinding protection --- src/mcp/server/transport_security.py | 2 +- tests/server/test_sse_security.py | 4 ++-- tests/server/test_streamable_http_security.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index b7eec474d6..3a884ee2b5 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -117,7 +117,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res # Validate Host header host = request.headers.get("host") if not self._validate_host(host): - return Response("Invalid Host header", status_code=400) + return Response("Invalid Host header", status_code=421) # Validate Origin header origin = request.headers.get("origin") diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 7db7297644..43af35061b 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -104,7 +104,7 @@ async def test_sse_security_invalid_host_header(server_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 400 + assert response.status_code == 421 assert response.text == "Invalid Host header" finally: @@ -214,7 +214,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 400 + assert response.status_code == 421 assert response.text == "Invalid Host header" finally: diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index b5dc81c4aa..eed7919249 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -127,7 +127,7 @@ async def test_streamable_http_security_invalid_host_header(server_port: int): json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers=headers, ) - assert response.status_code == 400 + assert response.status_code == 421 assert response.text == "Invalid Host header" finally: @@ -270,7 +270,7 @@ async def test_streamable_http_security_get_request(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - assert response.status_code == 400 + assert response.status_code == 421 assert response.text == "Invalid Host header" # Test GET request with valid host header