Skip to content

Commit 47d35f0

Browse files
authored
Allow ping requests before initialization (#1312)
1 parent 346e794 commit 47d35f0

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

src/mcp/server/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques
161161
)
162162
)
163163
)
164+
case types.PingRequest():
165+
# Ping requests are allowed at any time
166+
pass
164167
case _:
165168
if self._initialization_state != InitializationState.Initialized:
166169
raise RuntimeError("Received request before initialization was complete")

tests/server/test_session.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,129 @@ async def mock_client():
213213

214214
assert received_initialized
215215
assert received_protocol_version == "2024-11-05"
216+
217+
218+
@pytest.mark.anyio
219+
async def test_ping_request_before_initialization():
220+
"""Test that ping requests are allowed before initialization is complete."""
221+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
222+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
223+
224+
ping_response_received = False
225+
ping_response_id = None
226+
227+
async def run_server():
228+
async with ServerSession(
229+
client_to_server_receive,
230+
server_to_client_send,
231+
InitializationOptions(
232+
server_name="mcp",
233+
server_version="0.1.0",
234+
capabilities=ServerCapabilities(),
235+
),
236+
) as server_session:
237+
async for message in server_session.incoming_messages:
238+
if isinstance(message, Exception):
239+
raise message
240+
241+
# We should receive a ping request before initialization
242+
if isinstance(message, RequestResponder) and isinstance(message.request.root, types.PingRequest):
243+
# Respond to the ping
244+
with message:
245+
await message.respond(types.ServerResult(types.EmptyResult()))
246+
return
247+
248+
async def mock_client():
249+
nonlocal ping_response_received, ping_response_id
250+
251+
# Send ping request before any initialization
252+
await client_to_server_send.send(
253+
SessionMessage(
254+
types.JSONRPCMessage(
255+
types.JSONRPCRequest(
256+
jsonrpc="2.0",
257+
id=42,
258+
method="ping",
259+
)
260+
)
261+
)
262+
)
263+
264+
# Wait for the ping response
265+
ping_response_message = await server_to_client_receive.receive()
266+
assert isinstance(ping_response_message.message.root, types.JSONRPCResponse)
267+
268+
ping_response_received = True
269+
ping_response_id = ping_response_message.message.root.id
270+
271+
async with (
272+
client_to_server_send,
273+
client_to_server_receive,
274+
server_to_client_send,
275+
server_to_client_receive,
276+
anyio.create_task_group() as tg,
277+
):
278+
tg.start_soon(run_server)
279+
tg.start_soon(mock_client)
280+
281+
assert ping_response_received
282+
assert ping_response_id == 42
283+
284+
285+
@pytest.mark.anyio
286+
async def test_other_requests_blocked_before_initialization():
287+
"""Test that non-ping requests are still blocked before initialization."""
288+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
289+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
290+
291+
error_response_received = False
292+
error_code = None
293+
294+
async def run_server():
295+
async with ServerSession(
296+
client_to_server_receive,
297+
server_to_client_send,
298+
InitializationOptions(
299+
server_name="mcp",
300+
server_version="0.1.0",
301+
capabilities=ServerCapabilities(),
302+
),
303+
):
304+
# Server should handle the request and send an error response
305+
# No need to process incoming_messages since the error is handled automatically
306+
await anyio.sleep(0.1) # Give time for the request to be processed
307+
308+
async def mock_client():
309+
nonlocal error_response_received, error_code
310+
311+
# Try to send a non-ping request before initialization
312+
await client_to_server_send.send(
313+
SessionMessage(
314+
types.JSONRPCMessage(
315+
types.JSONRPCRequest(
316+
jsonrpc="2.0",
317+
id=1,
318+
method="prompts/list",
319+
)
320+
)
321+
)
322+
)
323+
324+
# Wait for the error response
325+
error_message = await server_to_client_receive.receive()
326+
if isinstance(error_message.message.root, types.JSONRPCError):
327+
error_response_received = True
328+
error_code = error_message.message.root.error.code
329+
330+
async with (
331+
client_to_server_send,
332+
client_to_server_receive,
333+
server_to_client_send,
334+
server_to_client_receive,
335+
anyio.create_task_group() as tg,
336+
):
337+
tg.start_soon(run_server)
338+
tg.start_soon(mock_client)
339+
340+
assert error_response_received
341+
assert error_code == types.INVALID_PARAMS

0 commit comments

Comments
 (0)