diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index 0552f2770..6e27b2a6c 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -90,6 +90,7 @@ async def get_prompt( if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.responses import Response from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") @@ -101,6 +102,7 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response("MCP SSE") starlette_app = Starlette( debug=True, diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 0ec1d926a..cbc494ee3 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -46,6 +46,7 @@ async def read_resource(uri: FileUrl) -> str | bytes: if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.responses import Response from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") @@ -57,6 +58,7 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response("MCP SSE") starlette_app = Starlette( debug=True, diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 3eace52ea..4722bf993 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -60,6 +60,7 @@ async def list_tools() -> list[types.Tool]: if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.responses import Response from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") @@ -71,6 +72,7 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response("MCP SSE") starlette_app = Starlette( debug=True, diff --git a/pyproject.toml b/pyproject.toml index 2a5558a42..907cef79f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1", + "typer>=0.12.4", ] [project.optional-dependencies] @@ -48,6 +49,7 @@ dev-dependencies = [ "trio>=0.26.2", "pytest-flakefinder>=1.1.0", "pytest-xdist>=3.6.1", + "sse-starlette>=2.3.4", ] [build-system] diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 122acebb4..40f743afc 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -476,6 +476,7 @@ async def run_stdio_async(self) -> None: async def run_sse_async(self) -> None: """Run the server using SSE transport.""" from starlette.applications import Starlette + from starlette.responses import Response from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") @@ -489,6 +490,7 @@ async def handle_sse(request): streams[1], self._mcp_server.create_initialization_options(), ) + return Response("OK") starlette_app = Starlette( debug=self.settings.debug, diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 0127753d0..e1b11f57f 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -6,6 +6,8 @@ Example usage: ``` # Create an SSE transport at an endpoint + from starlette.responses import Response + sse = SseServerTransport("/messages/") # Create Starlette routes for SSE and message handling @@ -22,6 +24,7 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response("MCP SSE") # Create and run Starlette app starlette_app = Starlette(routes=routes) @@ -43,7 +46,7 @@ async def handle_sse(request): from sse_starlette import EventSourceResponse from starlette.requests import Request from starlette.responses import Response -from starlette.types import Receive, Scope, Send +from starlette.types import Message, Receive, Scope, Send import mcp.types as types @@ -120,9 +123,19 @@ async def sse_writer(): } ) + async def handle_see_disconnect(message: Message) -> None: + logger.debug(f"Disconnect sse {session_id}") + del self._read_stream_writers[session_id] + await read_stream.aclose() + await read_stream_writer.aclose() + await write_stream.aclose() + await write_stream_reader.aclose() + async with anyio.create_task_group() as tg: response = EventSourceResponse( - content=sse_stream_reader, data_sender_callable=sse_writer + content=sse_stream_reader, + data_sender_callable=sse_writer, + client_close_handler_callable=handle_see_disconnect, ) logger.debug("Starting SSE response task") tg.start_soon(response, scope, receive, send) diff --git a/tests/server/test_sse_disconnect.py b/tests/server/test_sse_disconnect.py new file mode 100644 index 000000000..3dc6e0055 --- /dev/null +++ b/tests/server/test_sse_disconnect.py @@ -0,0 +1,54 @@ +import asyncio +from uuid import UUID + +import pytest +from starlette.types import Message, Scope + +from mcp.server.sse import SseServerTransport + + +@pytest.mark.anyio +async def test_sse_disconnect_handle(): + transport = SseServerTransport(endpoint="/sse") + # Create a minimal ASGI scope for an HTTP GET request + scope: Scope = { + "type": "http", + "method": "GET", + "path": "/sse", + "headers": [], + } + send_disconnect = False + + # Dummy receive and send functions + async def receive() -> dict: + nonlocal send_disconnect + if not send_disconnect: + send_disconnect = True + return {"type": "http.request"} + else: + return {"type": "http.disconnect"} + + async def send(message: Message) -> None: + await asyncio.sleep(0) + + # Run the connect_sse context manager + async with transport.connect_sse(scope, receive, send) as ( + read_stream, + write_stream, + ): + # Assert that streams are provided + assert read_stream is not None + assert write_stream is not None + + # There should be exactly one session + assert len(transport._read_stream_writers) == 1 + # Check that the session key is a UUID + session_id = next(iter(transport._read_stream_writers.keys())) + assert isinstance(session_id, UUID) + + # Check that the writer is still open + writer = transport._read_stream_writers[session_id] + assert writer is not None + + # After context exits, session should be cleaned up + assert len(transport._read_stream_writers) == 0 diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 87129ba91..8fc4fc3cf 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -80,16 +80,19 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: # Test fixtures def make_server_app() -> Starlette: """Create test Starlette app with SSE transport""" + from starlette.responses import Response + sse = SseServerTransport("/messages/") server = ServerTest() - async def handle_sse(request: Request) -> None: + async def handle_sse(request: Request) -> Response: async with sse.connect_sse( request.scope, request.receive, request._send ) as streams: await server.run( streams[0], streams[1], server.create_initialization_options() ) + return Response("MCP SSE") app = Starlette( routes=[ diff --git a/uv.lock b/uv.lock index 7ff1a3ea2..9151902d7 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" [options] @@ -23,17 +24,17 @@ wheels = [ [[package]] name = "anyio" -version = "4.5.0" +version = "4.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "idna" }, { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/44/66874c5256e9fbc30103b31927fd9341c8da6ccafd4721b2b3e81e6ef176/anyio-4.5.0.tar.gz", hash = "sha256:c5a275fe5ca0afd788001f58fca1e69e29ce706d746e317d660e21f70c530ef9", size = 169376 } +sdist = { url = "https://files.pythonhosted.org/packages/f6/40/318e58f669b1a9e00f5c4453910682e2d9dd594334539c7b7817dabb765f/anyio-4.7.0.tar.gz", hash = "sha256:2f834749c602966b7d456a7567cafcb309f96482b5081d14ac93ccd457f9dd48", size = 177076 } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/68/f9e9bf6324c46e6b8396610aef90ad423ec3e18c9079547ceafea3dce0ec/anyio-4.5.0-py3-none-any.whl", hash = "sha256:fdeb095b7cc5a5563175eedd926ec4ae55413bb4be5770c424af0ba46ccb4a78", size = 89250 }, + { url = "https://files.pythonhosted.org/packages/a0/7a/4daaf3b6c08ad7ceffea4634ec206faeff697526421c20f07628c7372156/anyio-4.7.0-py3-none-any.whl", hash = "sha256:ea60c3723ab42ba6fff7e8ccb0488c898ec538ff4df1f1d5e642c3601d07e352", size = 93052 }, ] [[package]] @@ -78,7 +79,7 @@ name = "click" version = "8.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/45/2b/7ebad1e59a99207d417c0784f7fb67893465eef84b5b47c788324f1b4095/click-8.1.0.tar.gz", hash = "sha256:977c213473c7665d3aa092b41ff12063227751c41d7b17165013e10069cc5cd2", size = 329986 } wheels = [ @@ -191,7 +192,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.3.0.dev0" +version = "1.3.0" source = { editable = "." } dependencies = [ { name = "anyio" }, @@ -201,6 +202,7 @@ dependencies = [ { name = "pydantic-settings" }, { name = "sse-starlette" }, { name = "starlette" }, + { name = "typer" }, { name = "uvicorn" }, ] @@ -220,6 +222,7 @@ dev = [ { name = "pytest-flakefinder" }, { name = "pytest-xdist" }, { name = "ruff" }, + { name = "sse-starlette" }, { name = "trio" }, ] @@ -234,9 +237,11 @@ requires-dist = [ { name = "rich", marker = "extra == 'rich'", specifier = ">=13.9.4" }, { name = "sse-starlette", specifier = ">=1.6.1" }, { name = "starlette", specifier = ">=0.27" }, + { name = "typer", specifier = ">=0.12.4" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.4" }, { name = "uvicorn", specifier = ">=0.23.1" }, ] +provides-extras = ["rich", "cli"] [package.metadata.requires-dev] dev = [ @@ -245,6 +250,7 @@ dev = [ { name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.8.5" }, + { name = "sse-starlette", specifier = ">=2.3.4" }, { name = "trio", specifier = ">=0.26.2" }, ] @@ -647,26 +653,27 @@ wheels = [ [[package]] name = "sse-starlette" -version = "1.6.1" +version = "2.3.5" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "anyio" }, { name = "starlette" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/40/88/0af7f586894cfe61bd212f33e571785c4570085711b24fb7445425a5eeb0/sse-starlette-1.6.1.tar.gz", hash = "sha256:6208af2bd7d0887c92f1379da14bd1f4db56bd1274cc5d36670c683d2aa1de6a", size = 14555 } +sdist = { url = "https://files.pythonhosted.org/packages/10/5f/28f45b1ff14bee871bacafd0a97213f7ec70e389939a80c60c0fb72a9fc9/sse_starlette-2.3.5.tar.gz", hash = "sha256:228357b6e42dcc73a427990e2b4a03c023e2495ecee82e14f07ba15077e334b2", size = 17511 } wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/f7/499e5d0c181a52a205d5b0982fd71cf162d1e070c97dca90c60520bbf8bf/sse_starlette-1.6.1-py3-none-any.whl", hash = "sha256:d8f18f1c633e355afe61cc5e9c92eea85badcb8b2d56ec8cfb0a006994aa55da", size = 9553 }, + { url = "https://files.pythonhosted.org/packages/c8/48/3e49cf0f64961656402c0023edbc51844fe17afe53ab50e958a6dbbbd499/sse_starlette-2.3.5-py3-none-any.whl", hash = "sha256:251708539a335570f10eaaa21d1848a10c42ee6dc3a9cf37ef42266cdb1c52a8", size = 10233 }, ] [[package]] name = "starlette" -version = "0.27.0" +version = "0.41.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/68/559bed5484e746f1ab2ebbe22312f2c25ec62e4b534916d41a8c21147bf8/starlette-0.27.0.tar.gz", hash = "sha256:6a6b0d042acb8d469a01eba54e9cda6cbd24ac602c4cd016723117d6a7e73b75", size = 51394 } +sdist = { url = "https://files.pythonhosted.org/packages/1a/4c/9b5764bd22eec91c4039ef4c55334e9187085da2d8a2df7bd570869aae18/starlette-0.41.3.tar.gz", hash = "sha256:0e4ab3d16522a255be6b28260b938eae2482f98ce5cc934cb08dce8dc3ba5835", size = 2574159 } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/f8/e2cca22387965584a409795913b774235752be4176d276714e15e1a58884/starlette-0.27.0-py3-none-any.whl", hash = "sha256:918416370e846586541235ccd38a474c08b80443ed31c578a418e2209b3eef91", size = 66978 }, + { url = "https://files.pythonhosted.org/packages/96/00/2b325970b3060c7cecebab6d295afe763365822b1306a12eeab198f74323/starlette-0.41.3-py3-none-any.whl", hash = "sha256:44cedb2b7c77a9de33a8b74b2b90e9f50d11fcf25d8270ea525ad71a25374ff7", size = 73225 }, ] [[package]]