diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 48df1171d..7b3680f7c 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -161,6 +161,9 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques ) ) ) + case types.PingRequest(): + # Ping requests are allowed at any time + pass case _: if self._initialization_state != InitializationState.Initialized: raise RuntimeError("Received request before initialization was complete") diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 89e807b29..664867511 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -213,3 +213,129 @@ async def mock_client(): assert received_initialized assert received_protocol_version == "2024-11-05" + + +@pytest.mark.anyio +async def test_ping_request_before_initialization(): + """Test that ping requests are allowed before initialization is complete.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + ping_response_received = False + ping_response_id = None + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="mcp", + server_version="0.1.0", + capabilities=ServerCapabilities(), + ), + ) as server_session: + async for message in server_session.incoming_messages: + if isinstance(message, Exception): + raise message + + # We should receive a ping request before initialization + if isinstance(message, RequestResponder) and isinstance(message.request.root, types.PingRequest): + # Respond to the ping + with message: + await message.respond(types.ServerResult(types.EmptyResult())) + return + + async def mock_client(): + nonlocal ping_response_received, ping_response_id + + # Send ping request before any initialization + await client_to_server_send.send( + SessionMessage( + types.JSONRPCMessage( + types.JSONRPCRequest( + jsonrpc="2.0", + id=42, + method="ping", + ) + ) + ) + ) + + # Wait for the ping response + ping_response_message = await server_to_client_receive.receive() + assert isinstance(ping_response_message.message.root, types.JSONRPCResponse) + + ping_response_received = True + ping_response_id = ping_response_message.message.root.id + + async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + tg.start_soon(mock_client) + + assert ping_response_received + assert ping_response_id == 42 + + +@pytest.mark.anyio +async def test_other_requests_blocked_before_initialization(): + """Test that non-ping requests are still blocked before initialization.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + error_response_received = False + error_code = None + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="mcp", + server_version="0.1.0", + capabilities=ServerCapabilities(), + ), + ): + # Server should handle the request and send an error response + # No need to process incoming_messages since the error is handled automatically + await anyio.sleep(0.1) # Give time for the request to be processed + + async def mock_client(): + nonlocal error_response_received, error_code + + # Try to send a non-ping request before initialization + await client_to_server_send.send( + SessionMessage( + types.JSONRPCMessage( + types.JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="prompts/list", + ) + ) + ) + ) + + # Wait for the error response + error_message = await server_to_client_receive.receive() + if isinstance(error_message.message.root, types.JSONRPCError): + error_response_received = True + error_code = error_message.message.root.error.code + + async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + tg.start_soon(mock_client) + + assert error_response_received + assert error_code == types.INVALID_PARAMS