1+ """Test that cancelled requests don't cause double responses."""
2+
3+ import asyncio
4+ from unittest .mock import MagicMock
5+
6+ import pytest
7+
8+ import mcp .types as types
9+ from mcp .server .lowlevel .server import Server
10+ from mcp .types import PingRequest , ServerResult
11+
12+
13+ # Shared mock class
14+ class MockRequestResponder :
15+ def __init__ (self ):
16+ self .request_id = "test-123"
17+ self ._responded = False
18+ self .request_meta = {}
19+ self .message_metadata = None
20+
21+ async def send (self , response ):
22+ if self ._responded :
23+ raise AssertionError (f"Request { self .request_id } already responded to" )
24+ self ._responded = True
25+
26+ async def respond (self , response ):
27+ await self .send (response )
28+
29+ def cancel (self ):
30+ """Simulate the cancel() method sending an error response."""
31+ asyncio .create_task (self .send (ServerResult (
32+ error = types .ErrorData (
33+ code = - 32800 ,
34+ message = "Request cancelled"
35+ )
36+ )))
37+
38+
39+ @pytest .mark .asyncio
40+ async def test_cancelled_request_no_double_response ():
41+ """Verify server handles cancelled requests without double response."""
42+
43+ # Create a server instance
44+ server = Server ("test-server" )
45+
46+ # Track if multiple responses are attempted
47+ response_count = 0
48+
49+ # Override the send method to track calls
50+ mock_message = MockRequestResponder ()
51+ original_send = mock_message .send
52+
53+ async def tracked_send (response ):
54+ nonlocal response_count
55+ response_count += 1
56+ await original_send (response )
57+
58+ mock_message .send = tracked_send
59+
60+ # Create a slow handler that will be cancelled
61+ async def slow_handler (req ):
62+ await asyncio .sleep (10 )
63+ return types .ServerResult (types .EmptyResult ())
64+
65+ # Use PingRequest as it's a valid request type
66+ server .request_handlers [types .PingRequest ] = slow_handler
67+
68+ # Create mock message and session
69+ mock_req = PingRequest (method = "ping" , params = {})
70+ mock_session = MagicMock ()
71+ mock_context = None
72+
73+ # Start the request
74+ handle_task = asyncio .create_task (
75+ server ._handle_request (
76+ mock_message ,
77+ mock_req ,
78+ mock_session ,
79+ mock_context ,
80+ raise_exceptions = False
81+ )
82+ )
83+
84+ # Give it time to start
85+ await asyncio .sleep (0.1 )
86+
87+ # Simulate cancellation
88+ mock_message .cancel ()
89+ handle_task .cancel ()
90+
91+ # Wait for cancellation to propagate
92+ try :
93+ await handle_task
94+ except asyncio .CancelledError :
95+ pass
96+
97+ # Give time for any duplicate response attempts
98+ await asyncio .sleep (0.1 )
99+
100+ # Should only have one response (from cancel())
101+ assert response_count == 1 , f"Expected 1 response, got { response_count } "
102+
103+
104+ @pytest .mark .asyncio
105+ async def test_server_remains_functional_after_cancel ():
106+ """Verify server can handle new requests after a cancellation."""
107+
108+ server = Server ("test-server" )
109+
110+ # Add handlers
111+ async def slow_handler (req ):
112+ await asyncio .sleep (5 )
113+ return types .ServerResult (types .EmptyResult ())
114+
115+ async def fast_handler (req ):
116+ return types .ServerResult (types .EmptyResult ())
117+
118+ # Override ping handler for our test
119+ server .request_handlers [types .PingRequest ] = slow_handler
120+
121+ # First request (will be cancelled)
122+ mock_message1 = MockRequestResponder ()
123+ mock_req1 = PingRequest (method = "ping" , params = {})
124+
125+ handle_task = asyncio .create_task (
126+ server ._handle_request (
127+ mock_message1 ,
128+ mock_req1 ,
129+ MagicMock (),
130+ None ,
131+ raise_exceptions = False
132+ )
133+ )
134+
135+ await asyncio .sleep (0.1 )
136+ mock_message1 .cancel ()
137+ handle_task .cancel ()
138+
139+ try :
140+ await handle_task
141+ except asyncio .CancelledError :
142+ pass
143+
144+ # Change handler to fast one
145+ server .request_handlers [types .PingRequest ] = fast_handler
146+
147+ # Second request (should work normally)
148+ mock_message2 = MockRequestResponder ()
149+ mock_req2 = PingRequest (method = "ping" , params = {})
150+
151+ # This should complete successfully
152+ await server ._handle_request (
153+ mock_message2 ,
154+ mock_req2 ,
155+ MagicMock (),
156+ None ,
157+ raise_exceptions = False
158+ )
159+
160+ # Server handled the second request successfully
161+ assert mock_message2 ._responded
0 commit comments