@@ -470,36 +470,40 @@ async def run(
470470 raise_exceptions : bool = False ,
471471 ):
472472 with warnings .catch_warnings (record = True ) as w :
473- async with self .lifespan (self ) as lifespan_context :
474- async with ServerSession (
475- read_stream , write_stream , initialization_options
476- ) as session :
477- async for message in session .incoming_messages :
478- logger .debug (f"Received message: { message } " )
479-
480- match message :
481- case (
482- RequestResponder (
483- request = types .ClientRequest (root = req )
484- ) as responder
485- ):
486- with responder :
487- await self ._handle_request (
488- message ,
489- req ,
490- session ,
491- lifespan_context ,
492- raise_exceptions ,
493- )
494- case types .ClientNotification (root = notify ):
495- await self ._handle_notification (notify )
496-
497- for warning in w :
498- logger .info (
499- "Warning: %s: %s" ,
500- warning .category .__name__ ,
501- warning .message ,
502- )
473+ from contextlib import AsyncExitStack
474+
475+ async with AsyncExitStack () as stack :
476+ lifespan_context = await stack .enter_async_context (self .lifespan (self ))
477+ session = await stack .enter_async_context (
478+ ServerSession (read_stream , write_stream , initialization_options )
479+ )
480+
481+ async for message in session .incoming_messages :
482+ logger .debug (f"Received message: { message } " )
483+
484+ match message :
485+ case (
486+ RequestResponder (
487+ request = types .ClientRequest (root = req )
488+ ) as responder
489+ ):
490+ with responder :
491+ await self ._handle_request (
492+ message ,
493+ req ,
494+ session ,
495+ lifespan_context ,
496+ raise_exceptions ,
497+ )
498+ case types .ClientNotification (root = notify ):
499+ await self ._handle_notification (notify )
500+
501+ for warning in w :
502+ logger .info (
503+ "Warning: %s: %s" ,
504+ warning .category .__name__ ,
505+ warning .message ,
506+ )
503507
504508 async def _handle_request (
505509 self ,
0 commit comments