
Commit 009fa8b

matthicksj and claude committed
fix: tool cache refresh with nested handler invocation
Fix MCP SDK issue #1298, where tool handlers fail to execute properly when a streaming context is present in the parent process. The fix stores a direct reference to the list_tools function to avoid nested handler invocation, which can disrupt async execution flow in streaming contexts.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 9a8592e commit 009fa8b
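
In short, cache refresh no longer re-enters the registered ListToolsRequest handler; it calls the stored list_tools function directly and rebuilds the cache from its result. A minimal standalone sketch of the new refresh path, condensed from the server.py diff below (the refresh_tool_cache helper is illustrative only and not part of the SDK):

from collections.abc import Awaitable, Callable

from mcp import types


async def refresh_tool_cache(
    list_tools_func: Callable[[], Awaitable[list[types.Tool]]] | None,
    tool_cache: dict[str, types.Tool],
) -> None:
    """Rebuild the tool cache by calling list_tools directly,
    instead of re-entering the registered handler via
    `await self.request_handlers[types.ListToolsRequest](None)`.
    """
    if list_tools_func is None:
        return
    # Direct call: no nested handler invocation, so the async execution
    # flow of the surrounding (possibly streaming) request is not disrupted.
    tools = await list_tools_func()
    tool_cache.clear()
    for tool in tools:
        tool_cache[tool.name] = tool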

File tree

2 files changed: +223, -3 lines changed


src/mcp/server/lowlevel/server.py

Lines changed: 15 additions & 3 deletions
@@ -150,6 +150,8 @@ def __init__(
         }
         self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
         self._tool_cache: dict[str, types.Tool] = {}
+        # Store direct reference to list_tools function to avoid nested handler calls
+        self._list_tools_func: Callable[[], Awaitable[list[types.Tool]]] | None = None
         logger.debug("Initializing server %r", name)
 
     def create_initialization_options(
@@ -383,6 +385,11 @@ async def handler(req: types.UnsubscribeRequest):
     def list_tools(self):
         def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
             logger.debug("Registering handler for ListToolsRequest")
+
+            # Store direct reference to the function for cache refresh.
+            # This avoids nested handler invocation which can disrupt
+            # async execution flow in streaming contexts.
+            self._list_tools_func = func
 
             async def handler(_: Any):
                 tools = await func()
@@ -412,9 +419,15 @@ async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None
         Returns the Tool object if found, None otherwise.
         """
         if tool_name not in self._tool_cache:
-            if types.ListToolsRequest in self.request_handlers:
+            # Use direct function reference to avoid nested handler invocation
+            # which can disrupt async flow in streaming contexts
+            if self._list_tools_func is not None:
                 logger.debug("Tool cache miss for %s, refreshing cache", tool_name)
-                await self.request_handlers[types.ListToolsRequest](None)
+                tools = await self._list_tools_func()
+                # Refresh the tool cache
+                self._tool_cache.clear()
+                for tool in tools:
+                    self._tool_cache[tool.name] = tool
 
         tool = self._tool_cache.get(tool_name)
         if tool is None:
@@ -458,7 +471,6 @@ async def handler(req: types.CallToolRequest):
                     except jsonschema.ValidationError as e:
                         return self._make_error_result(f"Input validation error: {e.message}")
 
-                # tool call
                 results = await func(tool_name, arguments)
 
                 # output normalization

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
"""Test for tool cache refresh bug with nested handler invocation (issue #1298).

This test verifies that cache refresh doesn't use nested handler invocation,
which can disrupt async execution in streaming contexts.
"""

from typing import Any

import anyio
import pytest

from mcp.client.session import ClientSession
from mcp.server.lowlevel import Server
from mcp.types import ListToolsRequest, TextContent, Tool


@pytest.mark.anyio
async def test_no_nested_handler_invocation_on_cache_refresh():
    """Verify that cache refresh doesn't use nested handler invocation.

    Issue #1298: Tool handlers can fail when cache refresh triggers
    nested handler invocation via self.request_handlers[ListToolsRequest](None),
    which disrupts async execution flow in streaming contexts.

    This test verifies the fix by detecting whether nested handler
    invocation occurs during cache refresh.
    """
    server = Server("test-server")

    # Track handler invocations
    handler_invocations = []

    @server.list_tools()
    async def list_tools():
        # Normal tool listing
        await anyio.sleep(0.001)
        return [
            Tool(
                name="test_tool",
                description="Test tool",
                inputSchema={"type": "object", "properties": {}}
            )
        ]

    @server.call_tool()
    async def call_tool(name: str, arguments: dict[str, Any]):
        # Simple tool implementation
        return [TextContent(type="text", text="Tool result")]

    # Intercept the ListToolsRequest handler to detect nested invocation
    original_handler = None

    def setup_handler_interceptor():
        nonlocal original_handler
        original_handler = server.request_handlers.get(ListToolsRequest)

        async def interceptor(req):
            # Track the invocation:
            # req is None for nested invocations (the problematic pattern),
            # req is a proper request object for normal invocations.
            if req is None:
                handler_invocations.append("nested")
            else:
                handler_invocations.append("normal")

            # Call the original handler
            if original_handler:
                return await original_handler(req)
            return None

        server.request_handlers[ListToolsRequest] = interceptor

    # Set up the interceptor after decorators have run
    setup_handler_interceptor()

    # Setup communication channels
    from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
    from mcp.shared.message import SessionMessage

    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)

    async def run_server():
        await server.run(
            client_to_server_receive,
            server_to_client_send,
            server.create_initialization_options()
        )

    async with anyio.create_task_group() as tg:
        tg.start_soon(run_server)

        async with ClientSession(server_to_client_receive, client_to_server_send) as session:
            await session.initialize()

            # Clear the cache to force a refresh on next tool call
            server._tool_cache.clear()

            # Make a tool call - this should trigger cache refresh
            result = await session.call_tool("test_tool", {})

            # Verify the tool call succeeded
            assert result is not None
            assert not result.isError
            assert result.content[0].text == "Tool result"

            # Check if nested handler invocation occurred
            has_nested_invocation = "nested" in handler_invocations

            # The bug is present if nested handler invocation occurs
            assert not has_nested_invocation, (
                "Nested handler invocation detected during cache refresh. "
                "This pattern (calling request_handlers[ListToolsRequest](None)) "
                "can disrupt async execution in streaming contexts (issue #1298)."
            )

            tg.cancel_scope.cancel()


@pytest.mark.anyio
async def test_concurrent_cache_refresh_safety():
    """Verify that concurrent tool calls with cache refresh work correctly.

    Multiple concurrent tool calls that all trigger cache refresh should
    not cause issues or result in nested handler invocations.
    """
    server = Server("test-server")

    # Track concurrent handler invocations
    nested_invocations = 0

    @server.list_tools()
    async def list_tools():
        await anyio.sleep(0.01)  # Simulate some async work
        return [
            Tool(
                name=f"tool_{i}",
                description=f"Tool {i}",
                inputSchema={"type": "object", "properties": {}}
            )
            for i in range(3)
        ]

    @server.call_tool()
    async def call_tool(name: str, arguments: dict[str, Any]):
        await anyio.sleep(0.001)
        return [TextContent(type="text", text=f"Result from {name}")]

    # Intercept handler to detect nested invocations
    original_handler = server.request_handlers.get(ListToolsRequest)

    async def interceptor(req):
        nonlocal nested_invocations
        if req is None:
            nested_invocations += 1
        if original_handler:
            return await original_handler(req)
        return None

    if original_handler:
        server.request_handlers[ListToolsRequest] = interceptor

    # Setup communication
    from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
    from mcp.shared.message import SessionMessage

    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)

    async def run_server():
        await server.run(
            client_to_server_receive,
            server_to_client_send,
            server.create_initialization_options()
        )

    async with anyio.create_task_group() as tg:
        tg.start_soon(run_server)

        async with ClientSession(server_to_client_receive, client_to_server_send) as session:
            await session.initialize()

            # Clear cache to force refresh
            server._tool_cache.clear()

            # Make concurrent tool calls
            import asyncio
            results = await asyncio.gather(
                session.call_tool("tool_0", {}),
                session.call_tool("tool_1", {}),
                session.call_tool("tool_2", {}),
                return_exceptions=True
            )

            # Verify all calls succeeded
            for i, result in enumerate(results):
                assert not isinstance(result, Exception), f"Tool {i} failed: {result}"
                assert not result.isError
                assert f"tool_{i}" in result.content[0].text

            # Verify no nested invocations occurred
            assert nested_invocations == 0, (
                f"Detected {nested_invocations} nested handler invocations "
                "during concurrent cache refresh. This indicates the bug from "
                "issue #1298 is present."
            )

            tg.cancel_scope.cancel()
