diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index aa240da7a..8b918b09d 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -491,6 +491,7 @@ async def handle_sse(request: Request) -> None: streams[0], streams[1], self._mcp_server.create_initialization_options(), + extra_metadata={"http_request": request}, ) return Starlette( @@ -501,6 +502,15 @@ async def handle_sse(request: Request) -> None: ], ) + def get_http_request(self) -> Request | None: + ctx = self.get_context() + if (ctx.request_context and + ctx.request_context.meta and + hasattr(ctx.request_context.meta, "extra_metadata")): + req: Request = ctx.request_context.meta.extra_metadata.get("http_request") # type: ignore + return req + return None + async def list_prompts(self) -> list[MCPPrompt]: """List all available prompts.""" prompts = self._prompt_manager.list_prompts() diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index dbaff3051..868aeb06e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -479,6 +479,7 @@ async def run( # but also make tracing exceptions much easier during testing and when using # in-process servers. raise_exceptions: bool = False, + extra_metadata: dict[str, Any] | None = None, ): async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) @@ -489,7 +490,10 @@ async def run( async with anyio.create_task_group() as tg: async for message in session.incoming_messages: logger.debug(f"Received message: {message}") - + if (hasattr(message, "request_meta") and + getattr(message, "request_meta")): + message.request_meta = message.request_meta or types.RequestParams.Meta() # type: ignore + message.request_meta.extra_metadata = extra_metadata # type: ignore tg.start_soon( self._handle_message, message, diff --git a/src/mcp/types.py b/src/mcp/types.py index bd71d51f0..67f8f74be 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -41,6 +41,7 @@ class RequestParams(BaseModel): class Meta(BaseModel): progressToken: ProgressToken | None = None + extra_metadata: dict[str, Any] | None = None """ If specified, the caller requests out-of-band progress notifications for this request (as represented by notifications/progress). The value of this