@@ -68,7 +68,8 @@ async def main():
6868import logging
6969import warnings
7070from collections .abc import Awaitable , Callable
71- from typing import Any , Sequence
71+ from contextlib import AbstractAsyncContextManager , asynccontextmanager
72+ from typing import Any , AsyncIterator , Generic , Sequence , TypeVar
7273
7374from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
7475from pydantic import AnyUrl
@@ -84,7 +85,10 @@ async def main():
8485
8586logger = logging .getLogger (__name__ )
8687
87- request_ctx : contextvars .ContextVar [RequestContext [ServerSession ]] = (
88+ LifespanResultT = TypeVar ("LifespanResultT" )
89+
90+ # This will be properly typed in each Server instance's context
91+ request_ctx : contextvars .ContextVar [RequestContext [ServerSession , Any ]] = (
8892 contextvars .ContextVar ("request_ctx" )
8993)
9094
@@ -101,13 +105,33 @@ def __init__(
101105 self .tools_changed = tools_changed
102106
103107
104- class Server :
108+ @asynccontextmanager
109+ async def lifespan (server : "Server" ) -> AsyncIterator [object ]:
110+ """Default lifespan context manager that does nothing.
111+
112+ Args:
113+ server: The server instance this lifespan is managing
114+
115+ Returns:
116+ An empty context object
117+ """
118+ yield {}
119+
120+
121+ class Server (Generic [LifespanResultT ]):
105122 def __init__ (
106- self , name : str , version : str | None = None , instructions : str | None = None
123+ self ,
124+ name : str ,
125+ version : str | None = None ,
126+ instructions : str | None = None ,
127+ lifespan : Callable [
128+ ["Server" ], AbstractAsyncContextManager [LifespanResultT ]
129+ ] = lifespan ,
107130 ):
108131 self .name = name
109132 self .version = version
110133 self .instructions = instructions
134+ self .lifespan = lifespan
111135 self .request_handlers : dict [
112136 type , Callable [..., Awaitable [types .ServerResult ]]
113137 ] = {
@@ -188,7 +212,7 @@ def get_capabilities(
188212 )
189213
190214 @property
191- def request_context (self ) -> RequestContext [ServerSession ]:
215+ def request_context (self ) -> RequestContext [ServerSession , LifespanResultT ]:
192216 """If called outside of a request context, this will raise a LookupError."""
193217 return request_ctx .get ()
194218
@@ -446,9 +470,14 @@ async def run(
446470 raise_exceptions : bool = False ,
447471 ):
448472 with warnings .catch_warnings (record = True ) as w :
449- async with ServerSession (
450- read_stream , write_stream , initialization_options
451- ) as session :
473+ from contextlib import AsyncExitStack
474+
475+ async with AsyncExitStack () as stack :
476+ lifespan_context = await stack .enter_async_context (self .lifespan (self ))
477+ session = await stack .enter_async_context (
478+ ServerSession (read_stream , write_stream , initialization_options )
479+ )
480+
452481 async for message in session .incoming_messages :
453482 logger .debug (f"Received message: { message } " )
454483
@@ -460,21 +489,28 @@ async def run(
460489 ):
461490 with responder :
462491 await self ._handle_request (
463- message , req , session , raise_exceptions
492+ message ,
493+ req ,
494+ session ,
495+ lifespan_context ,
496+ raise_exceptions ,
464497 )
465498 case types .ClientNotification (root = notify ):
466499 await self ._handle_notification (notify )
467500
468501 for warning in w :
469502 logger .info (
470- f"Warning: { warning .category .__name__ } : { warning .message } "
503+ "Warning: %s: %s" ,
504+ warning .category .__name__ ,
505+ warning .message ,
471506 )
472507
473508 async def _handle_request (
474509 self ,
475510 message : RequestResponder ,
476511 req : Any ,
477512 session : ServerSession ,
513+ lifespan_context : LifespanResultT ,
478514 raise_exceptions : bool ,
479515 ):
480516 logger .info (f"Processing request of type { type (req ).__name__ } " )
@@ -491,6 +527,7 @@ async def _handle_request(
491527 message .request_id ,
492528 message .request_meta ,
493529 session ,
530+ lifespan_context ,
494531 )
495532 )
496533 response = await handler (req )
0 commit comments