Skip to content

Commit 54692c3

Browse files
authored
Expose InitializeResult to middleware (#2516)
Wrap responder.respond() to capture the InitializeResult before it's sent to the write stream, then return it through the middleware chain. This allows middleware (e.g., logging) to access the server's initialize response, not just the client's request.
1 parent d5ef413 commit 54692c3

File tree

3 files changed

+67
-6
lines changed

3 files changed

+67
-6
lines changed

src/fastmcp/server/low_level.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,37 @@ async def _received_request(
5959
from fastmcp.server.middleware.middleware import MiddlewareContext
6060

6161
if isinstance(responder.request.root, mcp.types.InitializeRequest):
62+
# The MCP SDK's ServerSession._received_request() handles the
63+
# initialize request internally by calling responder.respond()
64+
# to send the InitializeResult directly to the write stream, then
65+
# returning None. This bypasses the middleware return path entirely,
66+
# so middleware would only see the request, never the response.
67+
#
68+
# To expose the response to middleware (e.g., for logging server
69+
# capabilities), we wrap responder.respond() to capture the
70+
# InitializeResult before it's sent, then return it from
71+
# call_original_handler so it flows back through the middleware chain.
72+
captured_response: mcp.types.ServerResult | None = None
73+
original_respond = responder.respond
74+
75+
async def capturing_respond(
76+
response: mcp.types.ServerResult,
77+
) -> None:
78+
nonlocal captured_response
79+
captured_response = response
80+
return await original_respond(response)
81+
82+
responder.respond = capturing_respond # type: ignore[method-assign]
6283

6384
async def call_original_handler(
6485
ctx: MiddlewareContext,
65-
) -> None:
66-
return await super(MiddlewareServerSession, self)._received_request(
67-
responder
68-
)
86+
) -> mcp.types.InitializeResult | None:
87+
await super(MiddlewareServerSession, self)._received_request(responder)
88+
if captured_response is not None and isinstance(
89+
captured_response.root, mcp.types.InitializeResult
90+
):
91+
return captured_response.root
92+
return None
6993

7094
async with fastmcp.server.context.Context(
7195
fastmcp=self.fastmcp

src/fastmcp/server/middleware/middleware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ async def on_notification(
150150
async def on_initialize(
151151
self,
152152
context: MiddlewareContext[mt.InitializeRequest],
153-
call_next: CallNext[mt.InitializeRequest, None],
154-
) -> None:
153+
call_next: CallNext[mt.InitializeRequest, mt.InitializeResult | None],
154+
) -> mt.InitializeResult | None:
155155
return await call_next(context)
156156

157157
async def on_call_tool(

tests/server/middleware/test_initialization_middleware.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,40 @@ def test_tool() -> str:
249249
# This test shows the pattern, but actual cross-request state would need
250250
# external storage (Redis, DB, etc.)
251251
# The middleware.tool_state might be None if state doesn't persist
252+
253+
254+
async def test_middleware_can_access_initialize_result():
255+
"""Test that middleware can access the InitializeResult from call_next().
256+
257+
This verifies that the initialize response is returned through the middleware
258+
chain, not just sent directly via the responder (fixes #2504).
259+
"""
260+
server = FastMCP("TestServer")
261+
262+
class ResponseCapturingMiddleware(Middleware):
263+
def __init__(self):
264+
super().__init__()
265+
self.initialize_result: mt.InitializeResult | None = None
266+
267+
async def on_initialize(
268+
self,
269+
context: MiddlewareContext[mt.InitializeRequest],
270+
call_next: CallNext[mt.InitializeRequest, mt.InitializeResult | None],
271+
) -> mt.InitializeResult | None:
272+
# Call next and capture the result
273+
result = await call_next(context)
274+
self.initialize_result = result
275+
return result
276+
277+
middleware = ResponseCapturingMiddleware()
278+
server.add_middleware(middleware)
279+
280+
async with Client(server):
281+
# Middleware should have captured the InitializeResult
282+
assert middleware.initialize_result is not None
283+
assert isinstance(middleware.initialize_result, mt.InitializeResult)
284+
285+
# Verify the result contains expected server info
286+
assert middleware.initialize_result.serverInfo.name == "TestServer"
287+
assert middleware.initialize_result.protocolVersion is not None
288+
assert middleware.initialize_result.capabilities is not None

0 commit comments

Comments
 (0)