Skip to content

Commit 11c7ced

Browse files
committed
Merge branch 'main' into fix-sse-client-blocks-indefinitely-when-server-has-incorrect-base-url
2 parents 7d2df66 + 58c5e72 commit 11c7ced

File tree

8 files changed

+184
-22
lines changed

8 files changed

+184
-22
lines changed

examples/servers/simple-prompt/mcp_simple_prompt/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ async def get_prompt(
9090
if transport == "sse":
9191
from mcp.server.sse import SseServerTransport
9292
from starlette.applications import Starlette
93+
from starlette.responses import Response
9394
from starlette.routing import Mount, Route
9495

9596
sse = SseServerTransport("/messages/")
@@ -101,6 +102,7 @@ async def handle_sse(request):
101102
await app.run(
102103
streams[0], streams[1], app.create_initialization_options()
103104
)
105+
return Response()
104106

105107
starlette_app = Starlette(
106108
debug=True,

examples/servers/simple-resource/mcp_simple_resource/server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ async def read_resource(uri: FileUrl) -> str | bytes:
4646
if transport == "sse":
4747
from mcp.server.sse import SseServerTransport
4848
from starlette.applications import Starlette
49+
from starlette.responses import Response
4950
from starlette.routing import Mount, Route
5051

5152
sse = SseServerTransport("/messages/")
@@ -57,11 +58,12 @@ async def handle_sse(request):
5758
await app.run(
5859
streams[0], streams[1], app.create_initialization_options()
5960
)
61+
return Response()
6062

6163
starlette_app = Starlette(
6264
debug=True,
6365
routes=[
64-
Route("/sse", endpoint=handle_sse),
66+
Route("/sse", endpoint=handle_sse, methods=["GET"]),
6567
Mount("/messages/", app=sse.handle_post_message),
6668
],
6769
)

examples/servers/simple-tool/mcp_simple_tool/server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ async def list_tools() -> list[types.Tool]:
6060
if transport == "sse":
6161
from mcp.server.sse import SseServerTransport
6262
from starlette.applications import Starlette
63+
from starlette.responses import Response
6364
from starlette.routing import Mount, Route
6465

6566
sse = SseServerTransport("/messages/")
@@ -71,11 +72,12 @@ async def handle_sse(request):
7172
await app.run(
7273
streams[0], streams[1], app.create_initialization_options()
7374
)
75+
return Response()
7476

7577
starlette_app = Starlette(
7678
debug=True,
7779
routes=[
78-
Route("/sse", endpoint=handle_sse),
80+
Route("/sse", endpoint=handle_sse, methods=["GET"]),
7981
Mount("/messages/", app=sse.handle_post_message),
8082
],
8183
)

src/mcp/server/fastmcp/server.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send):
589589
streams[1],
590590
self._mcp_server.create_initialization_options(),
591591
)
592+
return Response()
592593

593594
# Create routes
594595
routes: list[Route | Mount] = []
@@ -624,19 +625,42 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send):
624625
)
625626
)
626627

627-
routes.append(
628-
Route(
629-
self.settings.sse_path,
630-
endpoint=RequireAuthMiddleware(handle_sse, required_scopes),
631-
methods=["GET"],
628+
# When auth is not configured, we shouldn't require auth
629+
if self._auth_server_provider:
630+
# Auth is enabled, wrap the endpoints with RequireAuthMiddleware
631+
routes.append(
632+
Route(
633+
self.settings.sse_path,
634+
endpoint=RequireAuthMiddleware(handle_sse, required_scopes),
635+
methods=["GET"],
636+
)
632637
)
633-
)
634-
routes.append(
635-
Mount(
636-
self.settings.message_path,
637-
app=RequireAuthMiddleware(sse.handle_post_message, required_scopes),
638+
routes.append(
639+
Mount(
640+
self.settings.message_path,
641+
app=RequireAuthMiddleware(sse.handle_post_message, required_scopes),
642+
)
643+
)
644+
else:
645+
# Auth is disabled, no need for RequireAuthMiddleware
646+
# Since handle_sse is an ASGI app, we need to create a compatible endpoint
647+
async def sse_endpoint(request: Request) -> None:
648+
# Convert the Starlette request to ASGI parameters
649+
await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage]
650+
651+
routes.append(
652+
Route(
653+
self.settings.sse_path,
654+
endpoint=sse_endpoint,
655+
methods=["GET"],
656+
)
657+
)
658+
routes.append(
659+
Mount(
660+
self.settings.message_path,
661+
app=sse.handle_post_message,
662+
)
638663
)
639-
)
640664
# mount these routes last, so they have the lowest route matching precedence
641665
routes.extend(self._custom_starlette_routes)
642666

src/mcp/server/session.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,6 @@ def __init__(
104104
self._exit_stack.push_async_callback(
105105
lambda: self._incoming_message_stream_reader.aclose()
106106
)
107-
self._exit_stack.push_async_callback(
108-
lambda: self._incoming_message_stream_writer.aclose()
109-
)
110107

111108
@property
112109
def client_params(self) -> types.InitializeRequestParams | None:
@@ -144,6 +141,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
144141

145142
return True
146143

144+
async def _receive_loop(self) -> None:
145+
async with self._incoming_message_stream_writer:
146+
await super()._receive_loop()
147+
147148
async def _received_request(
148149
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
149150
):

src/mcp/server/sse.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
1111
# Create Starlette routes for SSE and message handling
1212
routes = [
13-
Route("/sse", endpoint=handle_sse),
13+
Route("/sse", endpoint=handle_sse, methods=["GET"]),
1414
Mount("/messages/", app=sse.handle_post_message),
1515
]
1616
@@ -22,12 +22,18 @@ async def handle_sse(request):
2222
await app.run(
2323
streams[0], streams[1], app.create_initialization_options()
2424
)
25+
# Return empty response to avoid NoneType error
26+
return Response()
2527
2628
# Create and run Starlette app
2729
starlette_app = Starlette(routes=routes)
2830
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
2931
```
3032
33+
Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType'
34+
object is not callable" error when client disconnects. The example above returns
35+
an empty Response() after the SSE connection ends to fix this.
36+
3137
See SseServerTransport class documentation for more details.
3238
"""
3339

@@ -120,11 +126,22 @@ async def sse_writer():
120126
)
121127

122128
async with anyio.create_task_group() as tg:
123-
response = EventSourceResponse(
124-
content=sse_stream_reader, data_sender_callable=sse_writer
125-
)
129+
130+
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
131+
"""
132+
The EventSourceResponse returning signals a client close / disconnect.
133+
In this case we close our side of the streams to signal the client that
134+
the connection has been closed.
135+
"""
136+
await EventSourceResponse(
137+
content=sse_stream_reader, data_sender_callable=sse_writer
138+
)(scope, receive, send)
139+
await read_stream_writer.aclose()
140+
await write_stream_reader.aclose()
141+
logging.debug(f"Client session disconnected {session_id}")
142+
126143
logger.debug("Starting SSE response task")
127-
tg.start_soon(response, scope, receive, send)
144+
tg.start_soon(response_wrapper, scope, receive, send)
128145

129146
logger.debug("Yielding read and write streams")
130147
yield (read_stream, write_stream)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""
2+
Integration tests for FastMCP server functionality.
3+
4+
These tests validate the proper functioning of FastMCP in various configurations,
5+
including with and without authentication.
6+
"""
7+
8+
import multiprocessing
9+
import socket
10+
import time
11+
from collections.abc import Generator
12+
13+
import pytest
14+
import uvicorn
15+
16+
from mcp.client.session import ClientSession
17+
from mcp.client.sse import sse_client
18+
from mcp.server.fastmcp import FastMCP
19+
from mcp.types import InitializeResult, TextContent
20+
21+
22+
@pytest.fixture
23+
def server_port() -> int:
24+
"""Get a free port for testing."""
25+
with socket.socket() as s:
26+
s.bind(("127.0.0.1", 0))
27+
return s.getsockname()[1]
28+
29+
30+
@pytest.fixture
31+
def server_url(server_port: int) -> str:
32+
"""Get the server URL for testing."""
33+
return f"http://127.0.0.1:{server_port}"
34+
35+
36+
# Create a function to make the FastMCP server app
37+
def make_fastmcp_app():
38+
"""Create a FastMCP server without auth settings."""
39+
from starlette.applications import Starlette
40+
41+
mcp = FastMCP(name="NoAuthServer")
42+
43+
# Add a simple tool
44+
@mcp.tool(description="A simple echo tool")
45+
def echo(message: str) -> str:
46+
return f"Echo: {message}"
47+
48+
# Create the SSE app
49+
app: Starlette = mcp.sse_app()
50+
51+
return mcp, app
52+
53+
54+
def run_server(server_port: int) -> None:
55+
"""Run the server."""
56+
_, app = make_fastmcp_app()
57+
server = uvicorn.Server(
58+
config=uvicorn.Config(
59+
app=app, host="127.0.0.1", port=server_port, log_level="error"
60+
)
61+
)
62+
print(f"Starting server on port {server_port}")
63+
server.run()
64+
65+
66+
@pytest.fixture()
67+
def server(server_port: int) -> Generator[None, None, None]:
68+
"""Start the server in a separate process and clean up after the test."""
69+
proc = multiprocessing.Process(target=run_server, args=(server_port,), daemon=True)
70+
print("Starting server process")
71+
proc.start()
72+
73+
# Wait for server to be running
74+
max_attempts = 20
75+
attempt = 0
76+
print("Waiting for server to start")
77+
while attempt < max_attempts:
78+
try:
79+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
80+
s.connect(("127.0.0.1", server_port))
81+
break
82+
except ConnectionRefusedError:
83+
time.sleep(0.1)
84+
attempt += 1
85+
else:
86+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
87+
88+
yield
89+
90+
print("Killing server")
91+
proc.kill()
92+
proc.join(timeout=2)
93+
if proc.is_alive():
94+
print("Server process failed to terminate")
95+
96+
97+
@pytest.mark.anyio
98+
async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
99+
"""Test that FastMCP works when auth settings are not provided."""
100+
# Connect to the server
101+
async with sse_client(server_url + "/sse") as streams:
102+
async with ClientSession(*streams) as session:
103+
# Test initialization
104+
result = await session.initialize()
105+
assert isinstance(result, InitializeResult)
106+
assert result.serverInfo.name == "NoAuthServer"
107+
108+
# Test that we can call tools without authentication
109+
tool_result = await session.call_tool("echo", {"message": "hello"})
110+
assert len(tool_result.content) == 1
111+
assert isinstance(tool_result.content[0], TextContent)
112+
assert tool_result.content[0].text == "Echo: hello"

tests/shared/test_sse.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pydantic import AnyUrl
1111
from starlette.applications import Starlette
1212
from starlette.requests import Request
13+
from starlette.responses import Response
1314
from starlette.routing import Mount, Route
1415

1516
from mcp.client.session import ClientSession
@@ -83,13 +84,14 @@ def make_server_app() -> Starlette:
8384
sse = SseServerTransport("/messages/")
8485
server = ServerTest()
8586

86-
async def handle_sse(request: Request) -> None:
87+
async def handle_sse(request: Request) -> Response:
8788
async with sse.connect_sse(
8889
request.scope, request.receive, request._send
8990
) as streams:
9091
await server.run(
9192
streams[0], streams[1], server.create_initialization_options()
9293
)
94+
return Response()
9395

9496
app = Starlette(
9597
routes=[

0 commit comments

Comments
 (0)