diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index 9ccffefa9..0f9cda920 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -14,29 +14,38 @@ @pytest.mark.anyio async def test_messages_are_executed_concurrently(): server = FastMCP("test") + call_timestamps = [] @server.tool("sleep") async def sleep_tool(): + call_timestamps.append(("tool_start_time", anyio.current_time())) await anyio.sleep(_sleep_time_seconds) + call_timestamps.append(("tool_end_time", anyio.current_time())) return "done" @server.resource(_resource_name) async def slow_resource(): + call_timestamps.append(("resource_start_time", anyio.current_time())) await anyio.sleep(_sleep_time_seconds) + call_timestamps.append(("resource_end_time", anyio.current_time())) return "slow" async with create_session(server._mcp_server) as client_session: - start_time = anyio.current_time() async with anyio.create_task_group() as tg: for _ in range(10): tg.start_soon(client_session.call_tool, "sleep") tg.start_soon(client_session.read_resource, AnyUrl(_resource_name)) - end_time = anyio.current_time() - - duration = end_time - start_time - assert duration < 10 * _sleep_time_seconds - print(duration) + active_calls = 0 + max_concurrent_calls = 0 + for call_type, _ in sorted(call_timestamps, key=lambda x: x[1]): + if "start" in call_type: + active_calls += 1 + max_concurrent_calls = max(max_concurrent_calls, active_calls) + else: + active_calls -= 1 + print(f"Max concurrent calls: {max_concurrent_calls}") + assert max_concurrent_calls > 1, "No concurrent calls were executed" def main():