Skip to content

Commit 512cf58

Browse files
authored
Merge branch 'main' into 307_redirect
2 parents 17f5444 + f55831e commit 512cf58

File tree

5 files changed

+214
-15
lines changed

5 files changed

+214
-15
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -955,9 +955,16 @@ async def list_prompts(self) -> list[MCPPrompt]:
955955
async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult:
956956
"""Get a prompt by name with arguments."""
957957
try:
958-
messages = await self._prompt_manager.render_prompt(name, arguments)
958+
prompt = self._prompt_manager.get_prompt(name)
959+
if not prompt:
960+
raise ValueError(f"Unknown prompt: {name}")
959961

960-
return GetPromptResult(messages=pydantic_core.to_jsonable_python(messages))
962+
messages = await prompt.render(arguments)
963+
964+
return GetPromptResult(
965+
description=prompt.description,
966+
messages=pydantic_core.to_jsonable_python(messages),
967+
)
961968
except Exception as e:
962969
logger.exception(f"Error getting prompt {name}")
963970
raise ValueError(str(e))

src/mcp/server/streamable_http.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ def __init__(
174174
] = {}
175175
self._terminated = False
176176

177+
@property
178+
def is_terminated(self) -> bool:
179+
"""Check if this transport has been explicitly terminated."""
180+
return self._terminated
181+
177182
def _create_error_response(
178183
self,
179184
error_message: str,

src/mcp/server/streamable_http_manager.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ class StreamableHTTPSessionManager:
5252
json_response: Whether to use JSON responses instead of SSE streams
5353
stateless: If True, creates a completely fresh transport for each request
5454
with no session tracking or state persistence between requests.
55-
5655
"""
5756

5857
def __init__(
@@ -173,12 +172,15 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
173172
async with http_transport.connect() as streams:
174173
read_stream, write_stream = streams
175174
task_status.started()
176-
await self.app.run(
177-
read_stream,
178-
write_stream,
179-
self.app.create_initialization_options(),
180-
stateless=True,
181-
)
175+
try:
176+
await self.app.run(
177+
read_stream,
178+
write_stream,
179+
self.app.create_initialization_options(),
180+
stateless=True,
181+
)
182+
except Exception:
183+
logger.exception("Stateless session crashed")
182184

183185
# Assert task group is not None for type checking
184186
assert self._task_group is not None
@@ -233,12 +235,31 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
233235
async with http_transport.connect() as streams:
234236
read_stream, write_stream = streams
235237
task_status.started()
236-
await self.app.run(
237-
read_stream,
238-
write_stream,
239-
self.app.create_initialization_options(),
240-
stateless=False, # Stateful mode
241-
)
238+
try:
239+
await self.app.run(
240+
read_stream,
241+
write_stream,
242+
self.app.create_initialization_options(),
243+
stateless=False, # Stateful mode
244+
)
245+
except Exception as e:
246+
logger.error(
247+
f"Session {http_transport.mcp_session_id} crashed: {e}",
248+
exc_info=True,
249+
)
250+
finally:
251+
# Only remove from instances if not terminated
252+
if (
253+
http_transport.mcp_session_id
254+
and http_transport.mcp_session_id in self._server_instances
255+
and not http_transport.is_terminated
256+
):
257+
logger.info(
258+
"Cleaning up crashed session "
259+
f"{http_transport.mcp_session_id} from "
260+
"active instances."
261+
)
262+
del self._server_instances[http_transport.mcp_session_id]
242263

243264
# Assert task group is not None for type checking
244265
assert self._task_group is not None

tests/server/fastmcp/test_server.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,46 @@ def fn(name: str) -> str:
982982
assert isinstance(content, TextContent)
983983
assert content.text == "Hello, World!"
984984

985+
@pytest.mark.anyio
986+
async def test_get_prompt_with_description(self):
987+
"""Test getting a prompt through MCP protocol."""
988+
mcp = FastMCP()
989+
990+
@mcp.prompt(description="Test prompt description")
991+
def fn(name: str) -> str:
992+
return f"Hello, {name}!"
993+
994+
async with client_session(mcp._mcp_server) as client:
995+
result = await client.get_prompt("fn", {"name": "World"})
996+
assert result.description == "Test prompt description"
997+
998+
@pytest.mark.anyio
999+
async def test_get_prompt_without_description(self):
1000+
"""Test getting a prompt without description returns empty string."""
1001+
mcp = FastMCP()
1002+
1003+
@mcp.prompt()
1004+
def fn(name: str) -> str:
1005+
return f"Hello, {name}!"
1006+
1007+
async with client_session(mcp._mcp_server) as client:
1008+
result = await client.get_prompt("fn", {"name": "World"})
1009+
assert result.description == ""
1010+
1011+
@pytest.mark.anyio
1012+
async def test_get_prompt_with_docstring_description(self):
1013+
"""Test prompt uses docstring as description when not explicitly provided."""
1014+
mcp = FastMCP()
1015+
1016+
@mcp.prompt()
1017+
def fn(name: str) -> str:
1018+
"""This is the function docstring."""
1019+
return f"Hello, {name}!"
1020+
1021+
async with client_session(mcp._mcp_server) as client:
1022+
result = await client.get_prompt("fn", {"name": "World"})
1023+
assert result.description == "This is the function docstring."
1024+
9851025
@pytest.mark.anyio
9861026
async def test_get_prompt_with_resource(self):
9871027
"""Test getting a prompt that returns resource content."""

tests/server/test_streamable_http_manager.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Tests for StreamableHTTPSessionManager."""
22

3+
from unittest.mock import AsyncMock
4+
35
import anyio
46
import pytest
57

68
from mcp.server.lowlevel import Server
9+
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER
710
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
811

912

@@ -71,3 +74,126 @@ async def send(message):
7174
await manager.handle_request(scope, receive, send)
7275

7376
assert "Task group is not initialized. Make sure to use run()." in str(excinfo.value)
77+
78+
79+
class TestException(Exception):
80+
__test__ = False # Prevent pytest from collecting this as a test class
81+
pass
82+
83+
84+
@pytest.fixture
85+
async def running_manager():
86+
app = Server("test-cleanup-server")
87+
# It's important that the app instance used by the manager is the one we can patch
88+
manager = StreamableHTTPSessionManager(app=app)
89+
async with manager.run():
90+
# Patch app.run here if it's simpler, or patch it within the test
91+
yield manager, app
92+
93+
94+
@pytest.mark.anyio
95+
async def test_stateful_session_cleanup_on_graceful_exit(running_manager):
96+
manager, app = running_manager
97+
98+
mock_mcp_run = AsyncMock(return_value=None)
99+
# This will be called by StreamableHTTPSessionManager's run_server -> self.app.run
100+
app.run = mock_mcp_run
101+
102+
sent_messages = []
103+
104+
async def mock_send(message):
105+
sent_messages.append(message)
106+
107+
scope = {
108+
"type": "http",
109+
"method": "POST",
110+
"path": "/mcp",
111+
"headers": [(b"content-type", b"application/json")],
112+
}
113+
114+
async def mock_receive():
115+
return {"type": "http.request", "body": b"", "more_body": False}
116+
117+
# Trigger session creation
118+
await manager.handle_request(scope, mock_receive, mock_send)
119+
120+
# Extract session ID from response headers
121+
session_id = None
122+
for msg in sent_messages:
123+
if msg["type"] == "http.response.start":
124+
for header_name, header_value in msg.get("headers", []):
125+
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower():
126+
session_id = header_value.decode()
127+
break
128+
if session_id: # Break outer loop if session_id is found
129+
break
130+
131+
assert session_id is not None, "Session ID not found in response headers"
132+
133+
# Ensure MCPServer.run was called
134+
mock_mcp_run.assert_called_once()
135+
136+
# At this point, mock_mcp_run has completed, and the finally block in
137+
# StreamableHTTPSessionManager's run_server should have executed.
138+
139+
# To ensure the task spawned by handle_request finishes and cleanup occurs:
140+
# Give other tasks a chance to run. This is important for the finally block.
141+
await anyio.sleep(0.01)
142+
143+
assert session_id not in manager._server_instances, (
144+
"Session ID should be removed from _server_instances after graceful exit"
145+
)
146+
assert not manager._server_instances, "No sessions should be tracked after the only session exits gracefully"
147+
148+
149+
@pytest.mark.anyio
150+
async def test_stateful_session_cleanup_on_exception(running_manager):
151+
manager, app = running_manager
152+
153+
mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash"))
154+
app.run = mock_mcp_run
155+
156+
sent_messages = []
157+
158+
async def mock_send(message):
159+
sent_messages.append(message)
160+
# If an exception occurs, the transport might try to send an error response
161+
# For this test, we mostly care that the session is established enough
162+
# to get an ID
163+
if message["type"] == "http.response.start" and message["status"] >= 500:
164+
pass # Expected if TestException propagates that far up the transport
165+
166+
scope = {
167+
"type": "http",
168+
"method": "POST",
169+
"path": "/mcp",
170+
"headers": [(b"content-type", b"application/json")],
171+
}
172+
173+
async def mock_receive():
174+
return {"type": "http.request", "body": b"", "more_body": False}
175+
176+
# Trigger session creation
177+
await manager.handle_request(scope, mock_receive, mock_send)
178+
179+
session_id = None
180+
for msg in sent_messages:
181+
if msg["type"] == "http.response.start":
182+
for header_name, header_value in msg.get("headers", []):
183+
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower():
184+
session_id = header_value.decode()
185+
break
186+
if session_id: # Break outer loop if session_id is found
187+
break
188+
189+
assert session_id is not None, "Session ID not found in response headers"
190+
191+
mock_mcp_run.assert_called_once()
192+
193+
# Give other tasks a chance to run to ensure the finally block executes
194+
await anyio.sleep(0.01)
195+
196+
assert session_id not in manager._server_instances, (
197+
"Session ID should be removed from _server_instances after an exception"
198+
)
199+
assert not manager._server_instances, "No sessions should be tracked after the only session crashes"

0 commit comments

Comments
 (0)