@@ -68,7 +68,8 @@ async def main():
6868import logging
6969import warnings
7070from collections .abc import Awaitable , Callable
71- from typing import Any , Sequence
71+ from contextlib import AbstractAsyncContextManager , asynccontextmanager
72+ from typing import Any , AsyncIterator , Generic , Sequence , TypeVar
7273
7374from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
7475from pydantic import AnyUrl
@@ -101,13 +102,36 @@ def __init__(
101102 self .tools_changed = tools_changed
102103
103104
104- class Server :
105+ LifespanResultT = TypeVar ("LifespanResultT" )
106+
107+
108+ @asynccontextmanager
109+ async def lifespan (server : "Server" ) -> AsyncIterator [object ]:
110+ """Default lifespan context manager that does nothing.
111+
112+ Args:
113+ server: The server instance this lifespan is managing
114+
115+ Returns:
116+ An empty context object
117+ """
118+ yield {}
119+
120+
121+ class Server (Generic [LifespanResultT ]):
105122 def __init__ (
106- self , name : str , version : str | None = None , instructions : str | None = None
123+ self ,
124+ name : str ,
125+ version : str | None = None ,
126+ instructions : str | None = None ,
127+ lifespan : Callable [
128+ ["Server" ], AbstractAsyncContextManager [LifespanResultT ]
129+ ] = lifespan ,
107130 ):
108131 self .name = name
109132 self .version = version
110133 self .instructions = instructions
134+ self .lifespan = lifespan
111135 self .request_handlers : dict [
112136 type , Callable [..., Awaitable [types .ServerResult ]]
113137 ] = {
@@ -446,35 +470,43 @@ async def run(
446470 raise_exceptions : bool = False ,
447471 ):
448472 with warnings .catch_warnings (record = True ) as w :
449- async with ServerSession (
450- read_stream , write_stream , initialization_options
451- ) as session :
452- async for message in session .incoming_messages :
453- logger .debug (f"Received message: { message } " )
454-
455- match message :
456- case (
457- RequestResponder (
458- request = types .ClientRequest (root = req )
459- ) as responder
460- ):
461- with responder :
462- await self ._handle_request (
463- message , req , session , raise_exceptions
464- )
465- case types .ClientNotification (root = notify ):
466- await self ._handle_notification (notify )
467-
468- for warning in w :
469- logger .info (
470- f"Warning: { warning .category .__name__ } : { warning .message } "
471- )
473+ async with self .lifespan (self ) as lifespan_context :
474+ async with ServerSession (
475+ read_stream , write_stream , initialization_options
476+ ) as session :
477+ async for message in session .incoming_messages :
478+ logger .debug (f"Received message: { message } " )
479+
480+ match message :
481+ case (
482+ RequestResponder (
483+ request = types .ClientRequest (root = req )
484+ ) as responder
485+ ):
486+ with responder :
487+ await self ._handle_request (
488+ message ,
489+ req ,
490+ session ,
491+ lifespan_context ,
492+ raise_exceptions ,
493+ )
494+ case types .ClientNotification (root = notify ):
495+ await self ._handle_notification (notify )
496+
497+ for warning in w :
498+ logger .info (
499+ "Warning: %s: %s" ,
500+ warning .category .__name__ ,
501+ warning .message ,
502+ )
472503
473504 async def _handle_request (
474505 self ,
475506 message : RequestResponder ,
476507 req : Any ,
477508 session : ServerSession ,
509+ lifespan_context : object ,
478510 raise_exceptions : bool ,
479511 ):
480512 logger .info (f"Processing request of type { type (req ).__name__ } " )
@@ -491,6 +523,7 @@ async def _handle_request(
491523 message .request_id ,
492524 message .request_meta ,
493525 session ,
526+ lifespan_context ,
494527 )
495528 )
496529 response = await handler (req )
0 commit comments