|
1 | 1 | import anyio
|
2 | 2 | import pytest
|
3 |
| -from pydantic import AnyUrl |
4 | 3 |
|
5 | 4 | from mcp.server.fastmcp import FastMCP
|
6 |
| -from mcp.shared.memory import ( |
7 |
| - create_connected_server_and_client_session as create_session, |
8 |
| -) |
9 |
| - |
10 |
| -_sleep_time_seconds = 0.01 |
11 |
| -_resource_name = "slow://slow_resource" |
| 5 | +from mcp.shared.memory import create_connected_server_and_client_session as create_session |
12 | 6 |
|
13 | 7 |
|
14 | 8 | @pytest.mark.anyio
|
15 | 9 | async def test_messages_are_executed_concurrently():
|
16 | 10 | server = FastMCP("test")
|
17 |
| - call_timestamps = [] |
| 11 | + event = anyio.Event() |
| 12 | + tool_started = anyio.Event() |
| 13 | + call_order = [] |
18 | 14 |
|
19 | 15 | @server.tool("sleep")
|
20 | 16 | async def sleep_tool():
|
21 |
| - call_timestamps.append(("tool_start_time", anyio.current_time())) |
22 |
| - await anyio.sleep(_sleep_time_seconds) |
23 |
| - call_timestamps.append(("tool_end_time", anyio.current_time())) |
| 17 | + call_order.append("waiting_for_event") |
| 18 | + tool_started.set() |
| 19 | + await event.wait() |
| 20 | + call_order.append("tool_end") |
24 | 21 | return "done"
|
25 | 22 |
|
26 |
| - @server.resource(_resource_name) |
27 |
| - async def slow_resource(): |
28 |
| - call_timestamps.append(("resource_start_time", anyio.current_time())) |
29 |
| - await anyio.sleep(_sleep_time_seconds) |
30 |
| - call_timestamps.append(("resource_end_time", anyio.current_time())) |
| 23 | + @server.tool("trigger") |
| 24 | + async def trigger(): |
| 25 | + # Wait for tool to start before setting the event |
| 26 | + await tool_started.wait() |
| 27 | + call_order.append("trigger_started") |
| 28 | + event.set() |
| 29 | + call_order.append("trigger_end") |
31 | 30 | return "slow"
|
32 | 31 |
|
33 | 32 | async with create_session(server._mcp_server) as client_session:
|
| 33 | + # First tool will wait on event, second will set it |
34 | 34 | async with anyio.create_task_group() as tg:
|
35 |
| - for _ in range(10): |
36 |
| - tg.start_soon(client_session.call_tool, "sleep") |
37 |
| - tg.start_soon(client_session.read_resource, AnyUrl(_resource_name)) |
38 |
| - |
39 |
| - active_calls = 0 |
40 |
| - max_concurrent_calls = 0 |
41 |
| - for call_type, _ in sorted(call_timestamps, key=lambda x: x[1]): |
42 |
| - if "start" in call_type: |
43 |
| - active_calls += 1 |
44 |
| - max_concurrent_calls = max(max_concurrent_calls, active_calls) |
45 |
| - else: |
46 |
| - active_calls -= 1 |
47 |
| - print(f"Max concurrent calls: {max_concurrent_calls}") |
48 |
| - assert max_concurrent_calls > 1, "No concurrent calls were executed" |
49 |
| - |
50 |
| - |
51 |
| -def main(): |
52 |
| - anyio.run(test_messages_are_executed_concurrently) |
53 |
| - |
54 |
| - |
55 |
| -if __name__ == "__main__": |
56 |
| - import logging |
57 |
| - |
58 |
| - logging.basicConfig(level=logging.DEBUG) |
59 |
| - |
60 |
| - main() |
| 35 | + # Start the tool first (it will wait on event) |
| 36 | + tg.start_soon(client_session.call_tool, "sleep") |
| 37 | + # Then the trigger tool will set the event to allow the first tool to continue |
| 38 | + await client_session.call_tool("trigger") |
| 39 | + |
| 40 | + # Verify that both ran concurrently |
| 41 | + assert call_order == [ |
| 42 | + "waiting_for_event", |
| 43 | + "trigger_started", |
| 44 | + "trigger_end", |
| 45 | + "tool_end", |
| 46 | + ], f"Expected concurrent execution, but got: {call_order}" |
0 commit comments