@@ -213,3 +213,129 @@ async def mock_client():
213
213
214
214
assert received_initialized
215
215
assert received_protocol_version == "2024-11-05"
216
+
217
+
218
+ @pytest .mark .anyio
219
+ async def test_ping_request_before_initialization ():
220
+ """Test that ping requests are allowed before initialization is complete."""
221
+ server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
222
+ client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
223
+
224
+ ping_response_received = False
225
+ ping_response_id = None
226
+
227
+ async def run_server ():
228
+ async with ServerSession (
229
+ client_to_server_receive ,
230
+ server_to_client_send ,
231
+ InitializationOptions (
232
+ server_name = "mcp" ,
233
+ server_version = "0.1.0" ,
234
+ capabilities = ServerCapabilities (),
235
+ ),
236
+ ) as server_session :
237
+ async for message in server_session .incoming_messages :
238
+ if isinstance (message , Exception ):
239
+ raise message
240
+
241
+ # We should receive a ping request before initialization
242
+ if isinstance (message , RequestResponder ) and isinstance (message .request .root , types .PingRequest ):
243
+ # Respond to the ping
244
+ with message :
245
+ await message .respond (types .ServerResult (types .EmptyResult ()))
246
+ return
247
+
248
+ async def mock_client ():
249
+ nonlocal ping_response_received , ping_response_id
250
+
251
+ # Send ping request before any initialization
252
+ await client_to_server_send .send (
253
+ SessionMessage (
254
+ types .JSONRPCMessage (
255
+ types .JSONRPCRequest (
256
+ jsonrpc = "2.0" ,
257
+ id = 42 ,
258
+ method = "ping" ,
259
+ )
260
+ )
261
+ )
262
+ )
263
+
264
+ # Wait for the ping response
265
+ ping_response_message = await server_to_client_receive .receive ()
266
+ assert isinstance (ping_response_message .message .root , types .JSONRPCResponse )
267
+
268
+ ping_response_received = True
269
+ ping_response_id = ping_response_message .message .root .id
270
+
271
+ async with (
272
+ client_to_server_send ,
273
+ client_to_server_receive ,
274
+ server_to_client_send ,
275
+ server_to_client_receive ,
276
+ anyio .create_task_group () as tg ,
277
+ ):
278
+ tg .start_soon (run_server )
279
+ tg .start_soon (mock_client )
280
+
281
+ assert ping_response_received
282
+ assert ping_response_id == 42
283
+
284
+
285
+ @pytest .mark .anyio
286
+ async def test_other_requests_blocked_before_initialization ():
287
+ """Test that non-ping requests are still blocked before initialization."""
288
+ server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
289
+ client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
290
+
291
+ error_response_received = False
292
+ error_code = None
293
+
294
+ async def run_server ():
295
+ async with ServerSession (
296
+ client_to_server_receive ,
297
+ server_to_client_send ,
298
+ InitializationOptions (
299
+ server_name = "mcp" ,
300
+ server_version = "0.1.0" ,
301
+ capabilities = ServerCapabilities (),
302
+ ),
303
+ ):
304
+ # Server should handle the request and send an error response
305
+ # No need to process incoming_messages since the error is handled automatically
306
+ await anyio .sleep (0.1 ) # Give time for the request to be processed
307
+
308
+ async def mock_client ():
309
+ nonlocal error_response_received , error_code
310
+
311
+ # Try to send a non-ping request before initialization
312
+ await client_to_server_send .send (
313
+ SessionMessage (
314
+ types .JSONRPCMessage (
315
+ types .JSONRPCRequest (
316
+ jsonrpc = "2.0" ,
317
+ id = 1 ,
318
+ method = "prompts/list" ,
319
+ )
320
+ )
321
+ )
322
+ )
323
+
324
+ # Wait for the error response
325
+ error_message = await server_to_client_receive .receive ()
326
+ if isinstance (error_message .message .root , types .JSONRPCError ):
327
+ error_response_received = True
328
+ error_code = error_message .message .root .error .code
329
+
330
+ async with (
331
+ client_to_server_send ,
332
+ client_to_server_receive ,
333
+ server_to_client_send ,
334
+ server_to_client_receive ,
335
+ anyio .create_task_group () as tg ,
336
+ ):
337
+ tg .start_soon (run_server )
338
+ tg .start_soon (mock_client )
339
+
340
+ assert error_response_received
341
+ assert error_code == types .INVALID_PARAMS
0 commit comments