13
13
from mcp .shared .context import RequestContext
14
14
from mcp .shared .message import SessionMessage
15
15
from mcp .shared .progress import progress
16
- from mcp .shared .session import BaseSession , RequestResponder , SessionMessage
16
+ from mcp .shared .session import BaseSession , RequestResponder
17
17
18
18
19
19
@pytest .mark .anyio
@@ -62,6 +62,7 @@ async def handle_progress(
62
62
progress : float ,
63
63
total : float | None ,
64
64
message : str | None ,
65
+ session : ServerSession | None ,
65
66
):
66
67
server_progress_updates .append (
67
68
{
@@ -228,6 +229,7 @@ async def handle_progress(
228
229
progress : float ,
229
230
total : float | None ,
230
231
message : str | None ,
232
+ session : ServerSession | None ,
231
233
):
232
234
server_progress_updates .append (
233
235
{"token" : progress_token , "progress" : progress , "total" : total , "message" : message }
@@ -332,9 +334,15 @@ async def test_initialized_notification():
332
334
333
335
server = Server ("test" )
334
336
initialized_received = asyncio .Event ()
337
+ received_session : ServerSession | None = None
335
338
336
339
@server .initialized_notification ()
337
- async def handle_initialized (notification : types .InitializedNotification ):
340
+ async def handle_initialized (
341
+ notification : types .InitializedNotification ,
342
+ session : ServerSession | None = None ,
343
+ ):
344
+ nonlocal received_session
345
+ received_session = session
338
346
initialized_received .set ()
339
347
340
348
async def run_server ():
@@ -364,6 +372,7 @@ async def message_handler(
364
372
tg .cancel_scope .cancel ()
365
373
366
374
assert initialized_received .is_set ()
375
+ assert isinstance (received_session , ServerSession )
367
376
368
377
369
378
@pytest .mark .anyio
@@ -374,105 +383,13 @@ async def test_roots_list_changed_notification():
374
383
375
384
server = Server ("test" )
376
385
roots_list_changed_received = asyncio .Event ()
386
+ received_session : ServerSession | None = None
377
387
378
388
@server .roots_list_changed_notification ()
379
389
async def handle_roots_list_changed (
380
390
notification : types .RootsListChangedNotification ,
391
+ session : ServerSession | None = None ,
381
392
):
382
- roots_list_changed_received .set ()
383
-
384
- async def run_server ():
385
- await server .run (
386
- client_to_server_receive ,
387
- server_to_client_send ,
388
- server .create_initialization_options (),
389
- )
390
-
391
- async def message_handler (
392
- message : (RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ),
393
- ) -> None :
394
- if isinstance (message , Exception ):
395
- raise message
396
-
397
- async with (
398
- ClientSession (
399
- server_to_client_receive ,
400
- client_to_server_send ,
401
- message_handler = message_handler ,
402
- ) as client_session ,
403
- anyio .create_task_group () as tg ,
404
- ):
405
- tg .start_soon (run_server )
406
- await client_session .initialize ()
407
- await client_session .send_notification (
408
- types .ClientNotification (
409
- root = types .RootsListChangedNotification (method = "notifications/roots/list_changed" , params = None )
410
- )
411
- )
412
- await roots_list_changed_received .wait ()
413
- tg .cancel_scope .cancel ()
414
-
415
- assert roots_list_changed_received .is_set ()
416
-
417
-
418
- @pytest .mark .anyio
419
- async def test_initialized_notification_with_session ():
420
- """Test that the server receives and handles InitializedNotification with a session."""
421
- server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
422
- client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
423
-
424
- server = Server ("test" )
425
- initialized_received = asyncio .Event ()
426
- received_session = None
427
-
428
- @server .initialized_notification ()
429
- async def handle_initialized (notification : types .InitializedNotification , session : ServerSession ):
430
- nonlocal received_session
431
- received_session = session
432
- initialized_received .set ()
433
-
434
- async def run_server ():
435
- await server .run (
436
- client_to_server_receive ,
437
- server_to_client_send ,
438
- server .create_initialization_options (),
439
- )
440
-
441
- async def message_handler (
442
- message : (RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ),
443
- ) -> None :
444
- if isinstance (message , Exception ):
445
- raise message
446
-
447
- async with (
448
- ClientSession (
449
- server_to_client_receive ,
450
- client_to_server_send ,
451
- message_handler = message_handler ,
452
- ) as client_session ,
453
- anyio .create_task_group () as tg ,
454
- ):
455
- tg .start_soon (run_server )
456
- await client_session .initialize ()
457
- await initialized_received .wait ()
458
- tg .cancel_scope .cancel ()
459
-
460
- assert initialized_received .is_set ()
461
- assert isinstance (received_session , ServerSession )
462
-
463
-
464
- @pytest .mark .anyio
465
- async def test_roots_list_changed_notification_with_session ():
466
- """Test that the server receives and handles RootsListChangedNotification with a session."""
467
- server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
468
- client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
469
-
470
- server = Server ("test" )
471
- roots_list_changed_received = asyncio .Event ()
472
- received_session = None
473
-
474
- @server .roots_list_changed_notification ()
475
- async def handle_roots_list_changed (notification : types .RootsListChangedNotification , session : ServerSession ):
476
393
nonlocal received_session
477
394
received_session = session
478
395
roots_list_changed_received .set ()
0 commit comments