Skip to content

Commit 428e7a4

Browse files
committed
Fix sync/async detection and add failing test for reconnects
1 parent 17fc21e commit 428e7a4

File tree

2 files changed

+60
-10
lines changed

2 files changed

+60
-10
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -355,24 +355,20 @@ def _get_invocation_mode(self, info: Tool, client_supports_async: bool) -> Liter
355355

356356
# New clients see the invocationMode field
357357
modes = info.invocation_modes
358-
if self._is_async_only(modes):
359-
return "async" # Async-only
360-
if self._is_sync_only(modes) or self._is_hybrid(modes):
361-
return "sync" # Hybrid or explicit sync
358+
if self._is_async_capable(modes):
359+
return "async" # Hybrid or explicit async
360+
if self._is_sync_only(modes):
361+
return "sync"
362362
return None
363363

364-
def _is_async_only(self, modes: list[InvocationMode]) -> bool:
364+
def _is_async_capable(self, modes: list[InvocationMode]) -> bool:
365365
"""Return True if invocation_modes is async-only."""
366-
return modes == ["async"]
366+
return "async" in modes
367367

368368
def _is_sync_only(self, modes: list[InvocationMode]) -> bool:
369369
"""Return True if invocation_modes is sync-only."""
370370
return modes == ["sync"]
371371

372-
def _is_hybrid(self, modes: list[InvocationMode]) -> bool:
373-
"""Return True if invocation_modes contains both sync and async."""
374-
return "sync" in modes and "async" in modes and len(modes) > 1
375-
376372
async def list_tools(self) -> list[MCPTool]:
377373
"""List all available tools."""
378374
tools = self._tool_manager.list_tools()

tests/server/fastmcp/test_integration.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,7 @@ async def test_async_tool_basic(server_transport: str, server_url: str) -> None:
757757
@pytest.mark.parametrize(
758758
"server_transport",
759759
[
760+
# ("async_tool_basic", "sse"),
760761
("async_tool_basic", "streamable-http"),
761762
],
762763
indirect=True,
@@ -795,6 +796,59 @@ async def test_async_tool_basic_legacy_protocol(server_transport: str, server_ur
795796
assert "Processed: HELLO" in hybrid_result.content[0].text
796797

797798

799+
@pytest.mark.anyio
800+
@pytest.mark.parametrize(
801+
"server_transport",
802+
[
803+
# ("async_tool_basic", "sse"),
804+
("async_tool_basic", "streamable-http"),
805+
],
806+
indirect=True,
807+
)
808+
async def test_async_tool_reconnection(server_transport: str, server_url: str) -> None:
809+
"""Test that async operations can be retrieved after reconnecting with a new session."""
810+
transport = server_transport
811+
client_cm1 = create_client_for_transport(transport, server_url)
812+
813+
# Start async operation in first session
814+
async with client_cm1 as client_streams:
815+
read_stream, write_stream = unpack_streams(client_streams)
816+
async with ClientSession(read_stream, write_stream, protocol_version="next") as session1:
817+
await session1.initialize()
818+
819+
# Start async operation
820+
result = await session1.call_tool("process_text", {"text": "test data"})
821+
assert result.operation is not None
822+
token = result.operation.token
823+
824+
# Reconnect with new session and retrieve result
825+
client_cm2 = create_client_for_transport(transport, server_url)
826+
async with client_cm2 as client_streams:
827+
read_stream, write_stream = unpack_streams(client_streams)
828+
async with ClientSession(read_stream, write_stream, protocol_version="next") as session2:
829+
await session2.initialize()
830+
831+
# Poll for completion in new session
832+
max_attempts = 20
833+
attempt = 0
834+
while attempt < max_attempts:
835+
status = await session2.get_operation_status(token)
836+
if status.status == "completed":
837+
final_result = await session2.get_operation_result(token)
838+
assert not final_result.result.isError
839+
assert len(final_result.result.content) == 1
840+
content = final_result.result.content[0]
841+
assert isinstance(content, TextContent)
842+
break
843+
elif status.status == "failed":
844+
pytest.fail(f"Operation failed: {status.error}")
845+
846+
attempt += 1
847+
await anyio.sleep(0.5)
848+
else:
849+
pytest.fail("Async operation timed out")
850+
851+
798852
# Test structured output example
799853
@pytest.mark.anyio
800854
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)