|  | 
|  | 1 | +"""Test that cancelled requests don't cause double responses.""" | 
|  | 2 | + | 
|  | 3 | +import anyio | 
|  | 4 | +import pytest | 
|  | 5 | + | 
|  | 6 | +import mcp.types as types | 
|  | 7 | +from mcp.server.lowlevel.server import Server | 
|  | 8 | +from mcp.shared.exceptions import McpError | 
|  | 9 | +from mcp.shared.memory import create_connected_server_and_client_session | 
|  | 10 | +from mcp.types import ( | 
|  | 11 | +    CallToolRequest, | 
|  | 12 | +    CallToolRequestParams, | 
|  | 13 | +    CallToolResult, | 
|  | 14 | +    CancelledNotification, | 
|  | 15 | +    CancelledNotificationParams, | 
|  | 16 | +    ClientNotification, | 
|  | 17 | +    ClientRequest, | 
|  | 18 | +    Tool, | 
|  | 19 | +) | 
|  | 20 | + | 
|  | 21 | + | 
|  | 22 | +@pytest.mark.anyio | 
|  | 23 | +async def test_server_remains_functional_after_cancel(): | 
|  | 24 | +    """Verify server can handle new requests after a cancellation.""" | 
|  | 25 | + | 
|  | 26 | +    server = Server("test-server") | 
|  | 27 | + | 
|  | 28 | +    # Track tool calls | 
|  | 29 | +    call_count = 0 | 
|  | 30 | +    ev_first_call = anyio.Event() | 
|  | 31 | +    first_request_id = None | 
|  | 32 | + | 
|  | 33 | +    @server.list_tools() | 
|  | 34 | +    async def handle_list_tools() -> list[Tool]: | 
|  | 35 | +        return [ | 
|  | 36 | +            Tool( | 
|  | 37 | +                name="test_tool", | 
|  | 38 | +                description="Tool for testing", | 
|  | 39 | +                inputSchema={}, | 
|  | 40 | +            ) | 
|  | 41 | +        ] | 
|  | 42 | + | 
|  | 43 | +    @server.call_tool() | 
|  | 44 | +    async def handle_call_tool(name: str, arguments: dict | None) -> list: | 
|  | 45 | +        nonlocal call_count, first_request_id | 
|  | 46 | +        if name == "test_tool": | 
|  | 47 | +            call_count += 1 | 
|  | 48 | +            if call_count == 1: | 
|  | 49 | +                first_request_id = server.request_context.request_id | 
|  | 50 | +                ev_first_call.set() | 
|  | 51 | +                await anyio.sleep(5)  # First call is slow | 
|  | 52 | +            return [types.TextContent(type="text", text=f"Call number: {call_count}")] | 
|  | 53 | +        raise ValueError(f"Unknown tool: {name}") | 
|  | 54 | + | 
|  | 55 | +    async with create_connected_server_and_client_session(server) as client: | 
|  | 56 | +        # First request (will be cancelled) | 
|  | 57 | +        async def first_request(): | 
|  | 58 | +            try: | 
|  | 59 | +                await client.send_request( | 
|  | 60 | +                    ClientRequest( | 
|  | 61 | +                        CallToolRequest( | 
|  | 62 | +                            method="tools/call", | 
|  | 63 | +                            params=CallToolRequestParams(name="test_tool", arguments={}), | 
|  | 64 | +                        ) | 
|  | 65 | +                    ), | 
|  | 66 | +                    CallToolResult, | 
|  | 67 | +                ) | 
|  | 68 | +                pytest.fail("First request should have been cancelled") | 
|  | 69 | +            except McpError: | 
|  | 70 | +                pass  # Expected | 
|  | 71 | + | 
|  | 72 | +        # Start first request | 
|  | 73 | +        async with anyio.create_task_group() as tg: | 
|  | 74 | +            tg.start_soon(first_request) | 
|  | 75 | + | 
|  | 76 | +            # Wait for it to start | 
|  | 77 | +            await ev_first_call.wait() | 
|  | 78 | + | 
|  | 79 | +            # Cancel it | 
|  | 80 | +            assert first_request_id is not None | 
|  | 81 | +            await client.send_notification( | 
|  | 82 | +                ClientNotification( | 
|  | 83 | +                    CancelledNotification( | 
|  | 84 | +                        method="notifications/cancelled", | 
|  | 85 | +                        params=CancelledNotificationParams( | 
|  | 86 | +                            requestId=first_request_id, | 
|  | 87 | +                            reason="Testing server recovery", | 
|  | 88 | +                        ), | 
|  | 89 | +                    ) | 
|  | 90 | +                ) | 
|  | 91 | +            ) | 
|  | 92 | + | 
|  | 93 | +        # Second request (should work normally) | 
|  | 94 | +        result = await client.send_request( | 
|  | 95 | +            ClientRequest( | 
|  | 96 | +                CallToolRequest( | 
|  | 97 | +                    method="tools/call", | 
|  | 98 | +                    params=CallToolRequestParams(name="test_tool", arguments={}), | 
|  | 99 | +                ) | 
|  | 100 | +            ), | 
|  | 101 | +            CallToolResult, | 
|  | 102 | +        ) | 
|  | 103 | + | 
|  | 104 | +        # Verify second request completed successfully | 
|  | 105 | +        assert len(result.content) == 1 | 
|  | 106 | +        # Type narrowing for pyright | 
|  | 107 | +        content = result.content[0] | 
|  | 108 | +        assert content.type == "text" | 
|  | 109 | +        assert isinstance(content, types.TextContent) | 
|  | 110 | +        assert content.text == "Call number: 2" | 
|  | 111 | +        assert call_count == 2 | 
0 commit comments