From b3d06575000b3fd997c13a3d8b17214a487522ce Mon Sep 17 00:00:00 2001 From: Akash DSouza Date: Thu, 1 May 2025 17:11:47 -0700 Subject: [PATCH 1/5] Fix --- .../simple-prompt/mcp_simple_prompt/server.py | 10 +++--- .../mcp_simple_resource/server.py | 10 +++--- .../simple-tool/mcp_simple_tool/server.py | 10 +++--- src/mcp/server/fastmcp/server.py | 10 ++---- src/mcp/server/session.py | 4 +++ src/mcp/server/sse.py | 31 +++++++++++++------ tests/shared/test_sse.py | 11 +++---- 7 files changed, 45 insertions(+), 41 deletions(-) diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index 0552f2770..3e4e8a75a 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -90,14 +90,12 @@ async def get_prompt( if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette - from starlette.routing import Mount, Route + from starlette.routing import Mount sse = SseServerTransport("/messages/") - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: + async def handle_sse(scope, receive, send): + async with sse.connect_sse(scope, receive, send) as streams: await app.run( streams[0], streams[1], app.create_initialization_options() ) @@ -105,7 +103,7 @@ async def handle_sse(request): starlette_app = Starlette( debug=True, routes=[ - Route("/sse", endpoint=handle_sse), + Mount("/sse", app=handle_sse), Mount("/messages/", app=sse.handle_post_message), ], ) diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 0ec1d926a..9053ca1a4 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -46,14 +46,12 @@ async def read_resource(uri: FileUrl) -> str | bytes: if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette - from starlette.routing import Mount, Route + from starlette.routing import Mount sse = SseServerTransport("/messages/") - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: + async def handle_sse(scope, receive, send): + async with sse.connect_sse(scope, receive, send) as streams: await app.run( streams[0], streams[1], app.create_initialization_options() ) @@ -61,7 +59,7 @@ async def handle_sse(request): starlette_app = Starlette( debug=True, routes=[ - Route("/sse", endpoint=handle_sse), + Mount("/sse", app=handle_sse), Mount("/messages/", app=sse.handle_post_message), ], ) diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 3eace52ea..8ec2eb12c 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -60,14 +60,12 @@ async def list_tools() -> list[types.Tool]: if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette - from starlette.routing import Mount, Route + from starlette.routing import Mount sse = SseServerTransport("/messages/") - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: + async def handle_sse(scope, receive, send): + async with sse.connect_sse(scope, receive, send) as streams: await app.run( streams[0], streams[1], app.create_initialization_options() ) @@ -75,7 +73,7 @@ async def handle_sse(request): starlette_app = Starlette( debug=True, routes=[ - Route("/sse", endpoint=handle_sse), + Mount("/sse", app=handle_sse), Mount("/messages/", app=sse.handle_post_message), ], ) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index f3bb2586a..dbc008811 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -19,8 +19,8 @@ from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette -from starlette.requests import Request from starlette.routing import Mount, Route +from starlette.types import Receive, Scope, Send from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager @@ -481,12 +481,8 @@ def sse_app(self) -> Starlette: """Return an instance of the SSE server app.""" sse = SseServerTransport(self.settings.message_path) - async def handle_sse(request: Request) -> None: - async with sse.connect_sse( - request.scope, - request.receive, - request._send, # type: ignore[reportPrivateUsage] - ) as streams: + async def handle_sse(scope: Scope, receive: Receive, send: Send) -> None: + async with sse.connect_sse(scope, receive, send) as streams: await self._mcp_server.run( streams[0], streams[1], diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 568ecd4b9..6321a54a7 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -137,6 +137,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: return True + async def _receive_loop(self) -> None: + async with self._incoming_message_stream_writer: + await super()._receive_loop() + async def _received_request( self, responder: RequestResponder[types.ClientRequest, types.ServerResult] ): diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index d051c25bf..94c3b6e26 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -10,15 +10,13 @@ # Create Starlette routes for SSE and message handling routes = [ - Route("/sse", endpoint=handle_sse), + Mount("/sse", app=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] # Define handler functions - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: + async def handle_sse(scope, receive, send): + async with sse.connect_sse(scope, receive, send) as streams: await app.run( streams[0], streams[1], app.create_initialization_options() ) @@ -28,6 +26,11 @@ async def handle_sse(request): uvicorn.run(starlette_app, host="0.0.0.0", port=port) ``` +Note: If you get a "TypeError: 'NoneType' object is not callable" error, +you need to change from Route("/sse", endpoint=handle_sse) to +Mount("/sse", app=handle_sse) and update the handle_sse signature to +accept (scope, receive, send) instead of (request). See examples for details. + See SseServerTransport class documentation for more details. """ @@ -121,11 +124,21 @@ async def sse_writer(): ) async with anyio.create_task_group() as tg: - response = EventSourceResponse( - content=sse_stream_reader, data_sender_callable=sse_writer - ) + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """ + The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse( + content=sse_stream_reader, data_sender_callable=sse_writer + )(scope, receive, send) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") + logger.debug("Starting SSE response task") - tg.start_soon(response, scope, receive, send) + tg.start_soon(response_wrapper, scope, receive, send) logger.debug("Yielding read and write streams") yield (read_stream, write_stream) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index f5158c3c3..9aa3cbf09 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -9,8 +9,7 @@ import uvicorn from pydantic import AnyUrl from starlette.applications import Starlette -from starlette.requests import Request -from starlette.routing import Mount, Route +from starlette.routing import Mount from mcp.client.session import ClientSession from mcp.client.sse import sse_client @@ -83,17 +82,15 @@ def make_server_app() -> Starlette: sse = SseServerTransport("/messages/") server = ServerTest() - async def handle_sse(request: Request) -> None: - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: + async def handle_sse(scope, receive, send) -> None: + async with sse.connect_sse(scope, receive, send) as streams: await server.run( streams[0], streams[1], server.create_initialization_options() ) app = Starlette( routes=[ - Route("/sse", endpoint=handle_sse), + Mount("/sse", app=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) From 74dec3ae18d29e3682056dbf656864924f0682e0 Mon Sep 17 00:00:00 2001 From: Akash DSouza Date: Thu, 1 May 2025 17:26:35 -0700 Subject: [PATCH 2/5] trailing slashes for mount --- src/mcp/server/fastmcp/server.py | 2 +- tests/shared/test_sse.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 65d342e1a..ad0def90d 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -87,7 +87,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # HTTP settings host: str = "0.0.0.0" port: int = 8000 - sse_path: str = "/sse" + sse_path: str = "/sse/" message_path: str = "/messages/" # resource settings diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 9aa3cbf09..1f9b404b9 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -161,7 +161,7 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: async with anyio.create_task_group(): async def connection_test() -> None: - async with http_client.stream("GET", "/sse") as response: + async with http_client.stream("GET", "/sse/") as response: assert response.status_code == 200 assert ( response.headers["content-type"] @@ -185,7 +185,7 @@ async def connection_test() -> None: @pytest.mark.anyio async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: + async with sse_client(server_url + "/sse/") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -201,7 +201,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non async def initialized_sse_client_session( server, server_url: str ) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: + async with sse_client(server_url + "/sse/", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() yield session From 79f79b3b58da0aef43f2e289329b45beb604985b Mon Sep 17 00:00:00 2001 From: Akash DSouza Date: Thu, 1 May 2025 17:31:42 -0700 Subject: [PATCH 3/5] remove exit stack cleanup now that we have context mgr --- src/mcp/server/session.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 6321a54a7..a2229c73c 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -97,9 +97,6 @@ def __init__( self._exit_stack.push_async_callback( lambda: self._incoming_message_stream_reader.aclose() ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_writer.aclose() - ) @property def client_params(self) -> types.InitializeRequestParams | None: From 41ba6698a59a184d12c5b88827c24d42521b7c03 Mon Sep 17 00:00:00 2001 From: Akash DSouza Date: Thu, 1 May 2025 18:17:49 -0700 Subject: [PATCH 4/5] fix --- .../simple-prompt/mcp_simple_prompt/server.py | 11 ++++++++--- .../mcp_simple_resource/server.py | 12 ++++++++---- .../simple-tool/mcp_simple_tool/server.py | 12 ++++++++---- src/mcp/server/fastmcp/server.py | 3 ++- src/mcp/server/sse.py | 18 +++++++++++------- tests/shared/test_sse.py | 14 +++++++++----- 6 files changed, 46 insertions(+), 24 deletions(-) diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index 3e4e8a75a..e79b5e700 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -2,6 +2,8 @@ import click import mcp.types as types from mcp.server.lowlevel import Server +from starlette.responses import Response +from starlette.routing import Route def create_messages( @@ -94,16 +96,19 @@ async def get_prompt( sse = SseServerTransport("/messages/") - async def handle_sse(scope, receive, send): - async with sse.connect_sse(scope, receive, send) as streams: + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response() starlette_app = Starlette( debug=True, routes=[ - Mount("/sse", app=handle_sse), + Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ], ) diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 9053ca1a4..06f567fbe 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -46,20 +46,24 @@ async def read_resource(uri: FileUrl) -> str | bytes: if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette - from starlette.routing import Mount + from starlette.responses import Response + from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") - async def handle_sse(scope, receive, send): - async with sse.connect_sse(scope, receive, send) as streams: + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response() starlette_app = Starlette( debug=True, routes=[ - Mount("/sse", app=handle_sse), + Route("/sse", endpoint=handle_sse, methods=["GET"]), Mount("/messages/", app=sse.handle_post_message), ], ) diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 8ec2eb12c..04224af5d 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -60,20 +60,24 @@ async def list_tools() -> list[types.Tool]: if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette - from starlette.routing import Mount + from starlette.responses import Response + from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") - async def handle_sse(scope, receive, send): - async with sse.connect_sse(scope, receive, send) as streams: + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response() starlette_app = Starlette( debug=True, routes=[ - Mount("/sse", app=handle_sse), + Route("/sse", endpoint=handle_sse, methods=["GET"]), Mount("/messages/", app=sse.handle_post_message), ], ) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index ad0def90d..09824df18 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -87,7 +87,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # HTTP settings host: str = "0.0.0.0" port: int = 8000 - sse_path: str = "/sse/" + sse_path: str = "/sse" message_path: str = "/messages/" # resource settings @@ -589,6 +589,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): streams[1], self._mcp_server.create_initialization_options(), ) + return Response() # Create routes routes: list[Route | Mount] = [] diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 94c3b6e26..9ad74e798 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -10,26 +10,29 @@ # Create Starlette routes for SSE and message handling routes = [ - Mount("/sse", app=handle_sse), + Route("/sse", endpoint=handle_sse, methods=["GET"]), Mount("/messages/", app=sse.handle_post_message), ] # Define handler functions - async def handle_sse(scope, receive, send): - async with sse.connect_sse(scope, receive, send) as streams: + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: await app.run( streams[0], streams[1], app.create_initialization_options() ) + # Return empty response to avoid NoneType error + return Response() # Create and run Starlette app starlette_app = Starlette(routes=routes) uvicorn.run(starlette_app, host="0.0.0.0", port=port) ``` -Note: If you get a "TypeError: 'NoneType' object is not callable" error, -you need to change from Route("/sse", endpoint=handle_sse) to -Mount("/sse", app=handle_sse) and update the handle_sse signature to -accept (scope, receive, send) instead of (request). See examples for details. +Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' +object is not callable" error when client disconnects. The example above shows the +correct approach by returning an empty Response() after the SSE connection ends. See SseServerTransport class documentation for more details. """ @@ -124,6 +127,7 @@ async def sse_writer(): ) async with anyio.create_task_group() as tg: + async def response_wrapper(scope: Scope, receive: Receive, send: Send): """ The EventSourceResponse returning signals a client close / disconnect. diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 1f9b404b9..befed02c1 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -9,7 +9,8 @@ import uvicorn from pydantic import AnyUrl from starlette.applications import Starlette -from starlette.routing import Mount +from starlette.responses import Response +from starlette.routing import Mount, Route from mcp.client.session import ClientSession from mcp.client.sse import sse_client @@ -82,15 +83,18 @@ def make_server_app() -> Starlette: sse = SseServerTransport("/messages/") server = ServerTest() - async def handle_sse(scope, receive, send) -> None: - async with sse.connect_sse(scope, receive, send) as streams: + async def handle_sse(request) -> None: + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: await server.run( streams[0], streams[1], server.create_initialization_options() ) + return Response() app = Starlette( routes=[ - Mount("/sse", app=handle_sse), + Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) @@ -161,7 +165,7 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: async with anyio.create_task_group(): async def connection_test() -> None: - async with http_client.stream("GET", "/sse/") as response: + async with http_client.stream("GET", "/sse") as response: assert response.status_code == 200 assert ( response.headers["content-type"] From 952b434eb3573ecaa4cbf577ec8ecf303763df19 Mon Sep 17 00:00:00 2001 From: Akash DSouza Date: Thu, 1 May 2025 18:20:48 -0700 Subject: [PATCH 5/5] misc fixes --- examples/servers/simple-prompt/mcp_simple_prompt/server.py | 5 ++--- src/mcp/server/sse.py | 4 ++-- tests/shared/test_sse.py | 7 ++++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index e79b5e700..bc14b7cd0 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -2,8 +2,6 @@ import click import mcp.types as types from mcp.server.lowlevel import Server -from starlette.responses import Response -from starlette.routing import Route def create_messages( @@ -92,7 +90,8 @@ async def get_prompt( if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette - from starlette.routing import Mount + from starlette.responses import Response + from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9ad74e798..5b0f6004d 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -31,8 +31,8 @@ async def handle_sse(request): ``` Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' -object is not callable" error when client disconnects. The example above shows the -correct approach by returning an empty Response() after the SSE connection ends. +object is not callable" error when client disconnects. The example above returns +an empty Response() after the SSE connection ends to fix this. See SseServerTransport class documentation for more details. """ diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index befed02c1..4558bb88c 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -9,6 +9,7 @@ import uvicorn from pydantic import AnyUrl from starlette.applications import Starlette +from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route @@ -83,7 +84,7 @@ def make_server_app() -> Starlette: sse = SseServerTransport("/messages/") server = ServerTest() - async def handle_sse(request) -> None: + async def handle_sse(request: Request) -> Response: async with sse.connect_sse( request.scope, request.receive, request._send ) as streams: @@ -189,7 +190,7 @@ async def connection_test() -> None: @pytest.mark.anyio async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse/") as streams: + async with sse_client(server_url + "/sse") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -205,7 +206,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non async def initialized_sse_client_session( server, server_url: str ) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse/", sse_read_timeout=0.5) as streams: + async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() yield session