Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques
)
)
)
case types.PingRequest():
# Ping requests are allowed at any time
pass
case _:
if self._initialization_state != InitializationState.Initialized:
raise RuntimeError("Received request before initialization was complete")
Expand Down
126 changes: 126 additions & 0 deletions tests/server/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,129 @@ async def mock_client():

assert received_initialized
assert received_protocol_version == "2024-11-05"


@pytest.mark.anyio
async def test_ping_request_before_initialization():
"""Test that ping requests are allowed before initialization is complete."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

ping_response_received = False
ping_response_id = None

async def run_server():
async with ServerSession(
client_to_server_receive,
server_to_client_send,
InitializationOptions(
server_name="mcp",
server_version="0.1.0",
capabilities=ServerCapabilities(),
),
) as server_session:
async for message in server_session.incoming_messages:
if isinstance(message, Exception):
raise message

# We should receive a ping request before initialization
if isinstance(message, RequestResponder) and isinstance(message.request.root, types.PingRequest):
# Respond to the ping
with message:
await message.respond(types.ServerResult(types.EmptyResult()))
return

async def mock_client():
nonlocal ping_response_received, ping_response_id

# Send ping request before any initialization
await client_to_server_send.send(
SessionMessage(
types.JSONRPCMessage(
types.JSONRPCRequest(
jsonrpc="2.0",
id=42,
method="ping",
)
)
)
)

# Wait for the ping response
ping_response_message = await server_to_client_receive.receive()
assert isinstance(ping_response_message.message.root, types.JSONRPCResponse)

ping_response_received = True
ping_response_id = ping_response_message.message.root.id

async with (
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
tg.start_soon(mock_client)

assert ping_response_received
assert ping_response_id == 42


@pytest.mark.anyio
async def test_other_requests_blocked_before_initialization():
"""Test that non-ping requests are still blocked before initialization."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

error_response_received = False
error_code = None

async def run_server():
async with ServerSession(
client_to_server_receive,
server_to_client_send,
InitializationOptions(
server_name="mcp",
server_version="0.1.0",
capabilities=ServerCapabilities(),
),
):
# Server should handle the request and send an error response
# No need to process incoming_messages since the error is handled automatically
await anyio.sleep(0.1) # Give time for the request to be processed

async def mock_client():
nonlocal error_response_received, error_code

# Try to send a non-ping request before initialization
await client_to_server_send.send(
SessionMessage(
types.JSONRPCMessage(
types.JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="prompts/list",
)
)
)
)

# Wait for the error response
error_message = await server_to_client_receive.receive()
if isinstance(error_message.message.root, types.JSONRPCError):
error_response_received = True
error_code = error_message.message.root.error.code

async with (
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
tg.start_soon(mock_client)

assert error_response_received
assert error_code == types.INVALID_PARAMS
Loading