diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 8c459383c..58d47afbf 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -150,6 +150,8 @@ def __init__( } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self._tool_cache: dict[str, types.Tool] = {} + # Store direct reference to list_tools function to avoid nested handler calls + self._list_tools_func: Callable[[], Awaitable[list[types.Tool]]] | None = None logger.debug("Initializing server %r", name) def create_initialization_options( @@ -384,6 +386,11 @@ def list_tools(self): def decorator(func: Callable[[], Awaitable[list[types.Tool]]]): logger.debug("Registering handler for ListToolsRequest") + # Store direct reference to the function for cache refresh. + # This avoids nested handler invocation which can disrupt + # async execution flow in streaming contexts. + self._list_tools_func = func + async def handler(_: Any): tools = await func() # Refresh the tool cache @@ -412,9 +419,15 @@ async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None Returns the Tool object if found, None otherwise. """ if tool_name not in self._tool_cache: - if types.ListToolsRequest in self.request_handlers: + # Use direct function reference to avoid nested handler invocation + # which can disrupt async flow in streaming contexts + if self._list_tools_func is not None: logger.debug("Tool cache miss for %s, refreshing cache", tool_name) - await self.request_handlers[types.ListToolsRequest](None) + tools = await self._list_tools_func() + # Refresh the tool cache + self._tool_cache.clear() + for tool in tools: + self._tool_cache[tool.name] = tool tool = self._tool_cache.get(tool_name) if tool is None: @@ -458,7 +471,6 @@ async def handler(req: types.CallToolRequest): except jsonschema.ValidationError as e: return self._make_error_result(f"Input validation error: {e.message}") - # tool call results = await func(tool_name, arguments) # output normalization diff --git a/tests/server/test_tool_cache_refresh_bug.py b/tests/server/test_tool_cache_refresh_bug.py new file mode 100644 index 000000000..11d296707 --- /dev/null +++ b/tests/server/test_tool_cache_refresh_bug.py @@ -0,0 +1,195 @@ +"""Test for tool cache refresh bug with nested handler invocation (issue #1298). + +This test verifies that cache refresh doesn't use nested handler invocation, +which can disrupt async execution in streaming contexts. +""" + +from typing import Any, cast + +import anyio +import pytest + +from mcp.client.session import ClientSession +from mcp.server.lowlevel import Server +from mcp.types import CallToolResult, ListToolsRequest, TextContent, Tool + + +@pytest.mark.anyio +async def test_no_nested_handler_invocation_on_cache_refresh(): + """Verify that cache refresh doesn't use nested handler invocation. + + Issue #1298: Tool handlers can fail when cache refresh triggers + nested handler invocation via self.request_handlers[ListToolsRequest](None), + which disrupts async execution flow in streaming contexts. + + This test verifies the fix by detecting whether nested handler + invocation occurs during cache refresh. + """ + server = Server("test-server") + + # Track handler invocations + handler_invocations: list[str] = [] + + @server.list_tools() + async def list_tools(): + # Normal tool listing + await anyio.sleep(0.001) + return [Tool(name="test_tool", description="Test tool", inputSchema={"type": "object", "properties": {}})] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]): + # Simple tool implementation + return [TextContent(type="text", text="Tool result")] + + # Intercept the ListToolsRequest handler to detect nested invocation + original_handler = None + + def setup_handler_interceptor(): + nonlocal original_handler + original_handler = server.request_handlers.get(ListToolsRequest) + + async def interceptor(req: Any) -> Any: + # Track the invocation + # req is None for nested invocations (the problematic pattern) + # req is a proper request object for normal invocations + if req is None: + handler_invocations.append("nested") + else: + handler_invocations.append("normal") + + # Call the original handler + if original_handler: + return await original_handler(req) + return None + + server.request_handlers[ListToolsRequest] = cast(Any, interceptor) + + # Set up the interceptor after decorators have run + setup_handler_interceptor() + + # Setup communication channels + from mcp.shared.message import SessionMessage + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run(client_to_server_receive, server_to_client_send, server.create_initialization_options()) + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + + async with ClientSession(server_to_client_receive, client_to_server_send) as session: + await session.initialize() + + # Clear the cache to force a refresh on next tool call + server._tool_cache.clear() + + # Make a tool call - this should trigger cache refresh + result = await session.call_tool("test_tool", {}) + + # Verify the tool call succeeded + assert result is not None + assert not result.isError + content = result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Tool result" + + # Check if nested handler invocation occurred + has_nested_invocation = "nested" in handler_invocations + + # The bug is present if nested handler invocation occurs + assert not has_nested_invocation, ( + "Nested handler invocation detected during cache refresh. " + "This pattern (calling request_handlers[ListToolsRequest](None)) " + "can disrupt async execution in streaming contexts (issue #1298)." + ) + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_concurrent_cache_refresh_safety(): + """Verify that concurrent tool calls with cache refresh work correctly. + + Multiple concurrent tool calls that all trigger cache refresh should + not cause issues or result in nested handler invocations. + """ + server = Server("test-server") + + # Track concurrent handler invocations + nested_invocations = 0 + + @server.list_tools() + async def list_tools(): + await anyio.sleep(0.01) # Simulate some async work + return [ + Tool(name=f"tool_{i}", description=f"Tool {i}", inputSchema={"type": "object", "properties": {}}) + for i in range(3) + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]): + await anyio.sleep(0.001) + return [TextContent(type="text", text=f"Result from {name}")] + + # Intercept handler to detect nested invocations + original_handler = server.request_handlers.get(ListToolsRequest) + + async def interceptor(req: Any) -> Any: + nonlocal nested_invocations + if req is None: + nested_invocations += 1 + if original_handler: + return await original_handler(req) + return None + + if original_handler: + server.request_handlers[ListToolsRequest] = cast(Any, interceptor) + + # Setup communication + from mcp.shared.message import SessionMessage + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run(client_to_server_receive, server_to_client_send, server.create_initialization_options()) + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + + async with ClientSession(server_to_client_receive, client_to_server_send) as session: + await session.initialize() + + # Clear cache to force refresh + server._tool_cache.clear() + + # Make concurrent tool calls + import asyncio + + results = await asyncio.gather( + session.call_tool("tool_0", {}), + session.call_tool("tool_1", {}), + session.call_tool("tool_2", {}), + return_exceptions=True, + ) + + # Verify all calls succeeded + for i, result in enumerate(results): + assert not isinstance(result, Exception), f"Tool {i} failed: {result}" + # Type narrowing: result is CallToolResult at this point, not Exception + assert isinstance(result, CallToolResult) + assert not result.isError + content = result.content[0] + assert isinstance(content, TextContent) + assert f"tool_{i}" in content.text + + # Verify no nested invocations occurred + assert nested_invocations == 0, ( + f"Detected {nested_invocations} nested handler invocations " + "during concurrent cache refresh. This indicates the bug from " + "issue #1298 is present." + ) + + tg.cancel_scope.cancel()