@@ -68,9 +68,10 @@ async def main():
6868import logging
6969import warnings
7070from collections .abc import Awaitable , Callable
71- from contextlib import AbstractAsyncContextManager , asynccontextmanager
71+ from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
7272from typing import Any , AsyncIterator , Generic , Sequence , TypeVar
7373
74+ import anyio
7475from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
7576from pydantic import AnyUrl
7677
@@ -458,6 +459,30 @@ async def handler(req: types.CompleteRequest):
458459
459460 return decorator
460461
462+ async def _handle_message (
463+ self ,
464+ message : RequestResponder [types .ClientRequest , types .ServerResult ]
465+ | types .ClientNotification
466+ | Exception ,
467+ session : ServerSession ,
468+ lifespan_context : LifespanResultT ,
469+ raise_exceptions : bool = False ,
470+ ):
471+ with warnings .catch_warnings (record = True ) as w :
472+ match message :
473+ case (
474+ RequestResponder (request = types .ClientRequest (root = req )) as responder
475+ ):
476+ with responder :
477+ await self ._handle_request (
478+ message , req , session , lifespan_context , raise_exceptions
479+ )
480+ case types .ClientNotification (root = notify ):
481+ await self ._handle_notification (notify )
482+
483+ for warning in w :
484+ logger .info (f"Warning: { warning .category .__name__ } : { warning .message } " )
485+
461486 async def run (
462487 self ,
463488 read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
@@ -469,41 +494,23 @@ async def run(
469494 # in-process servers.
470495 raise_exceptions : bool = False ,
471496 ):
472- with warnings .catch_warnings (record = True ) as w :
473- from contextlib import AsyncExitStack
474-
475- async with AsyncExitStack () as stack :
476- lifespan_context = await stack .enter_async_context (self .lifespan (self ))
477- session = await stack .enter_async_context (
478- ServerSession (read_stream , write_stream , initialization_options )
479- )
497+ async with AsyncExitStack () as stack :
498+ lifespan_context = await stack .enter_async_context (self .lifespan (self ))
499+ session = await stack .enter_async_context (
500+ ServerSession (read_stream , write_stream , initialization_options )
501+ )
480502
503+ async with anyio .create_task_group () as tg :
481504 async for message in session .incoming_messages :
482505 logger .debug (f"Received message: { message } " )
483506
484- match message :
485- case (
486- RequestResponder (
487- request = types .ClientRequest (root = req )
488- ) as responder
489- ):
490- with responder :
491- await self ._handle_request (
492- message ,
493- req ,
494- session ,
495- lifespan_context ,
496- raise_exceptions ,
497- )
498- case types .ClientNotification (root = notify ):
499- await self ._handle_notification (notify )
500-
501- for warning in w :
502- logger .info (
503- "Warning: %s: %s" ,
504- warning .category .__name__ ,
505- warning .message ,
506- )
507+ tg .start_soon (
508+ self ._handle_message ,
509+ message ,
510+ session ,
511+ lifespan_context ,
512+ raise_exceptions ,
513+ )
507514
508515 async def _handle_request (
509516 self ,
0 commit comments