1313from mcp .shared .context import RequestContext
1414from mcp .shared .message import SessionMessage
1515from mcp .shared .progress import progress
16- from mcp .shared .session import BaseSession , RequestResponder , SessionMessage
16+ from mcp .shared .session import BaseSession , RequestResponder
1717
1818
1919@pytest .mark .anyio
@@ -62,6 +62,7 @@ async def handle_progress(
6262 progress : float ,
6363 total : float | None ,
6464 message : str | None ,
65+ session : ServerSession | None ,
6566 ):
6667 server_progress_updates .append (
6768 {
@@ -228,6 +229,7 @@ async def handle_progress(
228229 progress : float ,
229230 total : float | None ,
230231 message : str | None ,
232+ session : ServerSession | None ,
231233 ):
232234 server_progress_updates .append (
233235 {"token" : progress_token , "progress" : progress , "total" : total , "message" : message }
@@ -332,9 +334,15 @@ async def test_initialized_notification():
332334
333335 server = Server ("test" )
334336 initialized_received = asyncio .Event ()
337+ received_session : ServerSession | None = None
335338
336339 @server .initialized_notification ()
337- async def handle_initialized (notification : types .InitializedNotification ):
340+ async def handle_initialized (
341+ notification : types .InitializedNotification ,
342+ session : ServerSession | None = None ,
343+ ):
344+ nonlocal received_session
345+ received_session = session
338346 initialized_received .set ()
339347
340348 async def run_server ():
@@ -364,6 +372,7 @@ async def message_handler(
364372 tg .cancel_scope .cancel ()
365373
366374 assert initialized_received .is_set ()
375+ assert isinstance (received_session , ServerSession )
367376
368377
369378@pytest .mark .anyio
@@ -374,105 +383,13 @@ async def test_roots_list_changed_notification():
374383
375384 server = Server ("test" )
376385 roots_list_changed_received = asyncio .Event ()
386+ received_session : ServerSession | None = None
377387
378388 @server .roots_list_changed_notification ()
379389 async def handle_roots_list_changed (
380390 notification : types .RootsListChangedNotification ,
391+ session : ServerSession | None = None ,
381392 ):
382- roots_list_changed_received .set ()
383-
384- async def run_server ():
385- await server .run (
386- client_to_server_receive ,
387- server_to_client_send ,
388- server .create_initialization_options (),
389- )
390-
391- async def message_handler (
392- message : (RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ),
393- ) -> None :
394- if isinstance (message , Exception ):
395- raise message
396-
397- async with (
398- ClientSession (
399- server_to_client_receive ,
400- client_to_server_send ,
401- message_handler = message_handler ,
402- ) as client_session ,
403- anyio .create_task_group () as tg ,
404- ):
405- tg .start_soon (run_server )
406- await client_session .initialize ()
407- await client_session .send_notification (
408- types .ClientNotification (
409- root = types .RootsListChangedNotification (method = "notifications/roots/list_changed" , params = None )
410- )
411- )
412- await roots_list_changed_received .wait ()
413- tg .cancel_scope .cancel ()
414-
415- assert roots_list_changed_received .is_set ()
416-
417-
418- @pytest .mark .anyio
419- async def test_initialized_notification_with_session ():
420- """Test that the server receives and handles InitializedNotification with a session."""
421- server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
422- client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
423-
424- server = Server ("test" )
425- initialized_received = asyncio .Event ()
426- received_session = None
427-
428- @server .initialized_notification ()
429- async def handle_initialized (notification : types .InitializedNotification , session : ServerSession ):
430- nonlocal received_session
431- received_session = session
432- initialized_received .set ()
433-
434- async def run_server ():
435- await server .run (
436- client_to_server_receive ,
437- server_to_client_send ,
438- server .create_initialization_options (),
439- )
440-
441- async def message_handler (
442- message : (RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ),
443- ) -> None :
444- if isinstance (message , Exception ):
445- raise message
446-
447- async with (
448- ClientSession (
449- server_to_client_receive ,
450- client_to_server_send ,
451- message_handler = message_handler ,
452- ) as client_session ,
453- anyio .create_task_group () as tg ,
454- ):
455- tg .start_soon (run_server )
456- await client_session .initialize ()
457- await initialized_received .wait ()
458- tg .cancel_scope .cancel ()
459-
460- assert initialized_received .is_set ()
461- assert isinstance (received_session , ServerSession )
462-
463-
464- @pytest .mark .anyio
465- async def test_roots_list_changed_notification_with_session ():
466- """Test that the server receives and handles RootsListChangedNotification with a session."""
467- server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
468- client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
469-
470- server = Server ("test" )
471- roots_list_changed_received = asyncio .Event ()
472- received_session = None
473-
474- @server .roots_list_changed_notification ()
475- async def handle_roots_list_changed (notification : types .RootsListChangedNotification , session : ServerSession ):
476393 nonlocal received_session
477394 received_session = session
478395 roots_list_changed_received .set ()
0 commit comments