Skip to content
86 changes: 64 additions & 22 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
SSE Server Transport Module

This module implements a Server-Sent Events (SSE) transport layer for MCP servers.
Endpoints are specified as relative paths. This aligns with common client URL
construction patterns (for example, `urllib.parse.urljoin`) and works correctly
when applications are deployed behind proxies or at subpaths.

Example usage:
```
# Create an SSE transport at an endpoint
sse = SseServerTransport("/messages/")
```python
# Recommended: provide a relative path segment (no scheme/host/query/fragment).
# Using "messages/" works well with clients that build absolute URLs using
# `urllib.parse.urljoin`, including in proxied/subpath deployments.
sse = SseServerTransport("messages/")

# Create Starlette routes for SSE and message handling
routes = [
Expand All @@ -30,6 +35,17 @@ async def handle_sse(request):
uvicorn.run(starlette_app, host="127.0.0.1", port=port)
```

Path behavior examples inside the server (final path emitted to clients):
- root_path="" and endpoint="messages/" -> "/messages/"
- root_path="/api" and endpoint="messages/" -> "/api/messages/"

Note: When clients use `urllib.parse.urljoin(base, path)`, joining a segment that
starts with "/" replaces the base path. Providing a relative segment like
`"messages/?id=1"` preserves the base path as intended.

For servers behind proxies or mounted at subpaths, prefer a relative path without
leading slash (e.g., "messages/") to ensure correct joining with `urljoin`.

Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType'
object is not callable" error when client disconnects. The example above returns
an empty Response() after the SSE connection ends to fix this.
Expand Down Expand Up @@ -83,8 +99,10 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
messages to the relative path given.

Args:
endpoint: A relative path where messages should be posted
(e.g., "/messages/").
endpoint: Relative path segment where messages should be posted
(e.g., "messages/"). Avoid scheme/host/query/fragment. When
clients construct absolute URLs using `urllib.parse.urljoin`,
relative segments preserve any existing base path.
security_settings: Optional security settings for DNS rebinding protection.

Note:
Expand All @@ -96,28 +114,60 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
3. Portability: The same endpoint configuration works across different
environments (development, staging, production)

The endpoint path handling preserves the provided relative path and is
suitable for deployments under proxies or subpaths.

Raises:
ValueError: If the endpoint is a full URL instead of a relative path
"""

super().__init__()

# Validate that endpoint is a relative path and not a full URL
# Validate that endpoint is a relative path and not a full URL.
if "://" in endpoint or endpoint.startswith("//") or "?" in endpoint or "#" in endpoint:
raise ValueError(
f"Given endpoint: {endpoint} is not a relative path (e.g., '/messages/'), "
"expecting a relative path (e.g., '/messages/')."
f"Given endpoint: {endpoint} is not a relative path (e.g., 'messages/'), "
"expecting a relative path with no scheme/host/query/fragment."
)

# Ensure endpoint starts with a forward slash
if not endpoint.startswith("/"):
endpoint = "/" + endpoint

# Store the endpoint as provided to retain relative-path semantics and make
# client URL construction predictable across deployment topologies.
self._endpoint = endpoint
self._read_stream_writers = {}
self._security = TransportSecurityMiddleware(security_settings)
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")

def _build_message_path(self, root_path: str) -> str:
"""
Helper method to properly construct the message path

Constructs the message path relative to the app's mount point and the
provided `root_path`. The stored endpoint is treated as path-absolute if
it starts with "/", otherwise as a relative segment.

Args:
root_path: The root path from ASGI scope (e.g., "" or "/api_prefix")

Returns:
The properly constructed path for client message posting
"""
# Clean up the root path
clean_root_path = root_path.rstrip("/")

# If endpoint starts with "/", treat it as path-absolute from the app mount;
# otherwise, treat it as relative to `root_path`.
if self._endpoint.startswith("/"):
# Path-absolute within the app mount - just concatenate
full_path = clean_root_path + self._endpoint
else:
# Relative path - ensure proper joining
if clean_root_path:
full_path = clean_root_path + "/" + self._endpoint
else:
full_path = "/" + self._endpoint

return full_path

@asynccontextmanager
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
Expand Down Expand Up @@ -145,17 +195,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")

# Determine the full path for the message endpoint to be sent to the client.
# scope['root_path'] is the prefix where the current Starlette app
# instance is mounted.
# e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix".
# Use the new helper method for proper path construction
root_path = scope.get("root_path", "")

# self._endpoint is the path *within* this app, e.g., "/messages".
# Concatenating them gives the full absolute path from the server root.
# e.g., "" + "/messages" -> "/messages"
# e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages"
full_message_path_for_client = root_path.rstrip("/") + self._endpoint
full_message_path_for_client = self._build_message_path(root_path)

# This is the URI (path + query) the client will use to POST messages.
client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
Expand Down
44 changes: 44 additions & 0 deletions tests/server/test_sse_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,47 @@ async def test_sse_security_post_valid_content_type(server_port: int):
finally:
process.terminate()
process.join()


@pytest.mark.anyio
async def test_endpoint_validation_rejects_absolute_urls():
"""Validate endpoint format: relative path segments only.

Context on URL joining (urllib.parse.urljoin):
- Joining a segment starting with "/" resets to the host root:
urljoin("http://host/app/sse", "/messages") -> "http://host/messages"
- Joining a relative segment appends relative to the base:
urljoin("http://host/hello/world", "messages") -> "http://host/hello/messages"
urljoin("http://host/hello/world/", "messages") -> "http://host/hello/world/messages"

This test ensures the transport accepts relative path segments (e.g., "messages/"),
rejects full URLs or paths containing query/fragment components, and stores accepted
values verbatim (no normalization). Both leading-slash and non-leading-slash forms
are permitted because the server handles construction relative to its mount path.
"""
# Reject: fully-qualified URLs and segments that include query/fragment
invalid_endpoints = [
"http://example.com/messages/",
"https://example.com/messages/",
"//example.com/messages/",
"/messages/?query=test",
"/messages/#fragment",
]

for invalid_endpoint in invalid_endpoints:
with pytest.raises(ValueError, match="is not a relative path"):
SseServerTransport(invalid_endpoint)

# Accept: relative path forms; endpoint is stored as provided (no normalization)
valid_endpoints_and_expected = [
("/messages/", "/messages/"), # Leading-slash path segment
("messages/", "messages/"), # Non-leading-slash path segment
("/api/v1/messages/", "/api/v1/messages/"),
("api/v1/messages/", "api/v1/messages/"),
]

for valid_endpoint, expected_stored_value in valid_endpoints_and_expected:
# Should not raise an exception
transport = SseServerTransport(valid_endpoint)
# Endpoint should be stored exactly as provided (no normalization)
assert transport._endpoint == expected_stored_value
22 changes: 16 additions & 6 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,9 @@ def test_sse_message_id_coercion():
@pytest.mark.parametrize(
"endpoint, expected_result",
[
# Valid endpoints - should normalize and work
# Accept: relative path forms; endpoint is stored verbatim (no normalization)
("/messages/", "/messages/"),
("messages/", "/messages/"),
("messages/", "messages/"),
("/", "/"),
# Invalid endpoints - should raise ValueError
("http://example.com/messages/", ValueError),
Expand All @@ -501,13 +501,23 @@ def test_sse_message_id_coercion():
],
)
def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]):
"""Test that SseServerTransport properly validates and normalizes endpoints."""
if isinstance(expected_result, type):
"""Validate relative endpoint semantics and storage.

Context on URL joining (urllib.parse.urljoin):
- Joining a segment starting with "/" resets to the host root:
urljoin("http://host/hello/world", "/messages") -> "http://host/messages"
- Joining a relative segment appends relative to the base:
urljoin("http://host/hello/world", "messages") -> "http://host/hello/messages"
urljoin("http://host/hello/world/", "messages/") -> "http://host/hello/world/messages/"

The transport validates that endpoints are relative path segments (no scheme/host/query/fragment)
and stores accepted values exactly as provided.
"""
if isinstance(expected_result, type) and issubclass(expected_result, Exception):
# Test invalid endpoints that should raise an exception
with pytest.raises(expected_result, match="is not a relative path.*expecting a relative path"):
SseServerTransport(endpoint)
else:
# Test valid endpoints that should normalize correctly
# Endpoint should be stored exactly as provided (no normalization)
sse = SseServerTransport(endpoint)
assert sse._endpoint == expected_result
assert sse._endpoint.startswith("/")
Loading