diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index b7ff33280..bb81ea3c1 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -2,11 +2,16 @@ SSE Server Transport Module This module implements a Server-Sent Events (SSE) transport layer for MCP servers. +Endpoints are specified as relative paths. This aligns with common client URL +construction patterns (for example, `urllib.parse.urljoin`) and works correctly +when applications are deployed behind proxies or at subpaths. Example usage: -``` - # Create an SSE transport at an endpoint - sse = SseServerTransport("/messages/") +```python + # Recommended: provide a relative path segment (no scheme/host/query/fragment). + # Using "messages/" works well with clients that build absolute URLs using + # `urllib.parse.urljoin`, including in proxied/subpath deployments. + sse = SseServerTransport("messages/") # Create Starlette routes for SSE and message handling routes = [ @@ -30,6 +35,17 @@ async def handle_sse(request): uvicorn.run(starlette_app, host="127.0.0.1", port=port) ``` +Path behavior examples inside the server (final path emitted to clients): +- root_path="" and endpoint="messages/" -> "/messages/" +- root_path="/api" and endpoint="messages/" -> "/api/messages/" + +Note: When clients use `urllib.parse.urljoin(base, path)`, joining a segment that +starts with "/" replaces the base path. Providing a relative segment like +`"messages/?id=1"` preserves the base path as intended. + +For servers behind proxies or mounted at subpaths, prefer a relative path without +leading slash (e.g., "messages/") to ensure correct joining with `urljoin`. + Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' object is not callable" error when client disconnects. The example above returns an empty Response() after the SSE connection ends to fix this. @@ -83,8 +99,10 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | messages to the relative path given. Args: - endpoint: A relative path where messages should be posted - (e.g., "/messages/"). + endpoint: Relative path segment where messages should be posted + (e.g., "messages/"). Avoid scheme/host/query/fragment. When + clients construct absolute URLs using `urllib.parse.urljoin`, + relative segments preserve any existing base path. security_settings: Optional security settings for DNS rebinding protection. Note: @@ -96,28 +114,60 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | 3. Portability: The same endpoint configuration works across different environments (development, staging, production) + The endpoint path handling preserves the provided relative path and is + suitable for deployments under proxies or subpaths. + Raises: ValueError: If the endpoint is a full URL instead of a relative path """ super().__init__() - # Validate that endpoint is a relative path and not a full URL + # Validate that endpoint is a relative path and not a full URL. if "://" in endpoint or endpoint.startswith("//") or "?" in endpoint or "#" in endpoint: raise ValueError( - f"Given endpoint: {endpoint} is not a relative path (e.g., '/messages/'), " - "expecting a relative path (e.g., '/messages/')." + f"Given endpoint: {endpoint} is not a relative path (e.g., 'messages/'), " + "expecting a relative path with no scheme/host/query/fragment." ) - # Ensure endpoint starts with a forward slash - if not endpoint.startswith("/"): - endpoint = "/" + endpoint - + # Store the endpoint as provided to retain relative-path semantics and make + # client URL construction predictable across deployment topologies. self._endpoint = endpoint self._read_stream_writers = {} self._security = TransportSecurityMiddleware(security_settings) logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") + def _build_message_path(self, root_path: str) -> str: + """ + Helper method to properly construct the message path + + Constructs the message path relative to the app's mount point and the + provided `root_path`. The stored endpoint is treated as path-absolute if + it starts with "/", otherwise as a relative segment. + + Args: + root_path: The root path from ASGI scope (e.g., "" or "/api_prefix") + + Returns: + The properly constructed path for client message posting + """ + # Clean up the root path + clean_root_path = root_path.rstrip("/") + + # If endpoint starts with "/", treat it as path-absolute from the app mount; + # otherwise, treat it as relative to `root_path`. + if self._endpoint.startswith("/"): + # Path-absolute within the app mount - just concatenate + full_path = clean_root_path + self._endpoint + else: + # Relative path - ensure proper joining + if clean_root_path: + full_path = clean_root_path + "/" + self._endpoint + else: + full_path = "/" + self._endpoint + + return full_path + @asynccontextmanager async def connect_sse(self, scope: Scope, receive: Receive, send: Send): if scope["type"] != "http": @@ -145,17 +195,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") - # Determine the full path for the message endpoint to be sent to the client. - # scope['root_path'] is the prefix where the current Starlette app - # instance is mounted. - # e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix". + # Use the new helper method for proper path construction root_path = scope.get("root_path", "") - - # self._endpoint is the path *within* this app, e.g., "/messages". - # Concatenating them gives the full absolute path from the server root. - # e.g., "" + "/messages" -> "/messages" - # e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages" - full_message_path_for_client = root_path.rstrip("/") + self._endpoint + full_message_path_for_client = self._build_message_path(root_path) # This is the URI (path + query) the client will use to POST messages. client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 43af35061..8cbe20398 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -291,3 +291,47 @@ async def test_sse_security_post_valid_content_type(server_port: int): finally: process.terminate() process.join() + + +@pytest.mark.anyio +async def test_endpoint_validation_rejects_absolute_urls(): + """Validate endpoint format: relative path segments only. + + Context on URL joining (urllib.parse.urljoin): + - Joining a segment starting with "/" resets to the host root: + urljoin("http://host/app/sse", "/messages") -> "http://host/messages" + - Joining a relative segment appends relative to the base: + urljoin("http://host/hello/world", "messages") -> "http://host/hello/messages" + urljoin("http://host/hello/world/", "messages") -> "http://host/hello/world/messages" + + This test ensures the transport accepts relative path segments (e.g., "messages/"), + rejects full URLs or paths containing query/fragment components, and stores accepted + values verbatim (no normalization). Both leading-slash and non-leading-slash forms + are permitted because the server handles construction relative to its mount path. + """ + # Reject: fully-qualified URLs and segments that include query/fragment + invalid_endpoints = [ + "http://example.com/messages/", + "https://example.com/messages/", + "//example.com/messages/", + "/messages/?query=test", + "/messages/#fragment", + ] + + for invalid_endpoint in invalid_endpoints: + with pytest.raises(ValueError, match="is not a relative path"): + SseServerTransport(invalid_endpoint) + + # Accept: relative path forms; endpoint is stored as provided (no normalization) + valid_endpoints_and_expected = [ + ("/messages/", "/messages/"), # Leading-slash path segment + ("messages/", "messages/"), # Non-leading-slash path segment + ("/api/v1/messages/", "/api/v1/messages/"), + ("api/v1/messages/", "api/v1/messages/"), + ] + + for valid_endpoint, expected_stored_value in valid_endpoints_and_expected: + # Should not raise an exception + transport = SseServerTransport(valid_endpoint) + # Endpoint should be stored exactly as provided (no normalization) + assert transport._endpoint == expected_stored_value diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 7b0d89cb4..a68c56b49 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -488,9 +488,9 @@ def test_sse_message_id_coercion(): @pytest.mark.parametrize( "endpoint, expected_result", [ - # Valid endpoints - should normalize and work + # Accept: relative path forms; endpoint is stored verbatim (no normalization) ("/messages/", "/messages/"), - ("messages/", "/messages/"), + ("messages/", "messages/"), ("/", "/"), # Invalid endpoints - should raise ValueError ("http://example.com/messages/", ValueError), @@ -501,13 +501,23 @@ def test_sse_message_id_coercion(): ], ) def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]): - """Test that SseServerTransport properly validates and normalizes endpoints.""" - if isinstance(expected_result, type): + """Validate relative endpoint semantics and storage. + + Context on URL joining (urllib.parse.urljoin): + - Joining a segment starting with "/" resets to the host root: + urljoin("http://host/hello/world", "/messages") -> "http://host/messages" + - Joining a relative segment appends relative to the base: + urljoin("http://host/hello/world", "messages") -> "http://host/hello/messages" + urljoin("http://host/hello/world/", "messages/") -> "http://host/hello/world/messages/" + + The transport validates that endpoints are relative path segments (no scheme/host/query/fragment) + and stores accepted values exactly as provided. + """ + if isinstance(expected_result, type) and issubclass(expected_result, Exception): # Test invalid endpoints that should raise an exception with pytest.raises(expected_result, match="is not a relative path.*expecting a relative path"): SseServerTransport(endpoint) else: - # Test valid endpoints that should normalize correctly + # Endpoint should be stored exactly as provided (no normalization) sse = SseServerTransport(endpoint) assert sse._endpoint == expected_result - assert sse._endpoint.startswith("/")