@@ -72,11 +72,12 @@ async def main():
72
72
import warnings
73
73
from collections .abc import AsyncIterator , Awaitable , Callable , Iterable
74
74
from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
75
- from typing import Any , Generic , TypeVar
75
+ from typing import Any , Generic
76
76
77
77
import anyio
78
78
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
79
79
from pydantic import AnyUrl
80
+ from typing_extensions import TypeVar
80
81
81
82
import mcp .types as types
82
83
from mcp .server .lowlevel .helper_types import ReadResourceContents
@@ -85,15 +86,16 @@ async def main():
85
86
from mcp .server .stdio import stdio_server as stdio_server
86
87
from mcp .shared .context import RequestContext
87
88
from mcp .shared .exceptions import McpError
88
- from mcp .shared .message import SessionMessage
89
+ from mcp .shared .message import ServerMessageMetadata , SessionMessage
89
90
from mcp .shared .session import RequestResponder
90
91
91
92
logger = logging .getLogger (__name__ )
92
93
93
94
LifespanResultT = TypeVar ("LifespanResultT" )
95
+ RequestT = TypeVar ("RequestT" , default = Any )
94
96
95
97
# This will be properly typed in each Server instance's context
96
- request_ctx : contextvars .ContextVar [RequestContext [ServerSession , Any ]] = (
98
+ request_ctx : contextvars .ContextVar [RequestContext [ServerSession , Any , Any ]] = (
97
99
contextvars .ContextVar ("request_ctx" )
98
100
)
99
101
@@ -111,7 +113,7 @@ def __init__(
111
113
112
114
113
115
@asynccontextmanager
114
- async def lifespan (server : Server [LifespanResultT ]) -> AsyncIterator [object ]:
116
+ async def lifespan (server : Server [LifespanResultT , RequestT ]) -> AsyncIterator [object ]:
115
117
"""Default lifespan context manager that does nothing.
116
118
117
119
Args:
@@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
123
125
yield {}
124
126
125
127
126
- class Server (Generic [LifespanResultT ]):
128
+ class Server (Generic [LifespanResultT , RequestT ]):
127
129
def __init__ (
128
130
self ,
129
131
name : str ,
130
132
version : str | None = None ,
131
133
instructions : str | None = None ,
132
134
lifespan : Callable [
133
- [Server [LifespanResultT ]], AbstractAsyncContextManager [LifespanResultT ]
135
+ [Server [LifespanResultT , RequestT ]],
136
+ AbstractAsyncContextManager [LifespanResultT ],
134
137
] = lifespan ,
135
138
):
136
139
self .name = name
@@ -215,7 +218,9 @@ def get_capabilities(
215
218
)
216
219
217
220
@property
218
- def request_context (self ) -> RequestContext [ServerSession , LifespanResultT ]:
221
+ def request_context (
222
+ self ,
223
+ ) -> RequestContext [ServerSession , LifespanResultT , RequestT ]:
219
224
"""If called outside of a request context, this will raise a LookupError."""
220
225
return request_ctx .get ()
221
226
@@ -555,6 +560,13 @@ async def _handle_request(
555
560
556
561
token = None
557
562
try :
563
+ # Extract request context from message metadata
564
+ request_data = None
565
+ if message .message_metadata is not None and isinstance (
566
+ message .message_metadata , ServerMessageMetadata
567
+ ):
568
+ request_data = message .message_metadata .request_context
569
+
558
570
# Set our global state that can be retrieved via
559
571
# app.get_request_context()
560
572
token = request_ctx .set (
@@ -563,6 +575,7 @@ async def _handle_request(
563
575
message .request_meta ,
564
576
session ,
565
577
lifespan_context ,
578
+ request = request_data ,
566
579
)
567
580
)
568
581
response = await handler (req )
0 commit comments