Skip to content

Commit 8c90789

Browse files
committed
add tool and resource caching for mcp servers that support change notifications
1 parent f6807af commit 8c90789

File tree

2 files changed

+229
-12
lines changed

2 files changed

+229
-12
lines changed

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,12 @@ class ServerCapabilities:
232232
tools: bool = False
233233
"""Whether the server offers any tools to call."""
234234

235+
tools_list_changed: bool = False
236+
"""Whether the server will emit notifications when the list of tools changes."""
237+
238+
resources_list_changed: bool = False
239+
"""Whether the server will emit notifications when the list of resources changes."""
240+
235241
completions: bool = False
236242
"""Whether the server offers autocompletion suggestions for prompts and resources."""
237243

@@ -244,12 +250,16 @@ def from_mcp_sdk(cls, mcp_capabilities: mcp_types.ServerCapabilities) -> ServerC
244250
Args:
245251
mcp_capabilities: The MCP SDK ServerCapabilities object.
246252
"""
253+
tools_cap = mcp_capabilities.tools
254+
resources_cap = mcp_capabilities.resources
247255
return cls(
248256
experimental=list(mcp_capabilities.experimental.keys()) if mcp_capabilities.experimental else None,
249257
logging=mcp_capabilities.logging is not None,
250258
prompts=mcp_capabilities.prompts is not None,
251-
resources=mcp_capabilities.resources is not None,
252-
tools=mcp_capabilities.tools is not None,
259+
resources=resources_cap is not None,
260+
tools=tools_cap is not None,
261+
tools_list_changed=bool(tools_cap.listChanged) if tools_cap else False,
262+
resources_list_changed=bool(resources_cap.listChanged) if resources_cap else False,
253263
completions=mcp_capabilities.completions is not None,
254264
)
255265

@@ -332,6 +342,11 @@ class MCPServer(AbstractToolset[Any], ABC):
332342
_server_capabilities: ServerCapabilities
333343
_instructions: str | None
334344

345+
_cached_tools: list[mcp_types.Tool] | None
346+
_tools_cache_valid: bool
347+
_cached_resources: list[Resource] | None
348+
_resources_cache_valid: bool
349+
335350
def __init__(
336351
self,
337352
tool_prefix: str | None = None,
@@ -366,6 +381,10 @@ def __post_init__(self):
366381
self._enter_lock = Lock()
367382
self._running_count = 0
368383
self._exit_stack = None
384+
self._cached_tools = None
385+
self._tools_cache_valid = False
386+
self._cached_resources = None
387+
self._resources_cache_valid = False
369388

370389
@abstractmethod
371390
@asynccontextmanager
@@ -430,13 +449,23 @@ def instructions(self) -> str | None:
430449
async def list_tools(self) -> list[mcp_types.Tool]:
431450
"""Retrieve tools that are currently active on the server.
432451
433-
Note:
434-
- We don't cache tools as they might change.
435-
- We also don't subscribe to the server to avoid complexity.
452+
Tools are cached when the server advertises `tools.listChanged` capability,
453+
with cache invalidation on tool change notifications and reconnection.
436454
"""
437455
async with self: # Ensure server is running
438-
result = await self._client.list_tools()
439-
return result.tools
456+
# Only cache if server supports listChanged notifications
457+
if self._server_capabilities.tools_list_changed:
458+
if self._cached_tools is not None and self._tools_cache_valid:
459+
return self._cached_tools
460+
461+
result = await self._client.list_tools()
462+
self._cached_tools = result.tools
463+
self._tools_cache_valid = True
464+
return result.tools
465+
else:
466+
# Server doesn't support notifications, always fetch fresh
467+
result = await self._client.list_tools()
468+
return result.tools
440469

441470
async def direct_call_tool(
442471
self,
@@ -542,9 +571,8 @@ def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]:
542571
async def list_resources(self) -> list[Resource]:
543572
"""Retrieve resources that are currently present on the server.
544573
545-
Note:
546-
- We don't cache resources as they might change.
547-
- We also don't subscribe to resource changes to avoid complexity.
574+
Resources are cached when the server advertises `resources.listChanged` capability,
575+
with cache invalidation on resource change notifications and reconnection.
548576
549577
Raises:
550578
MCPError: If the server returns an error.
@@ -553,10 +581,21 @@ async def list_resources(self) -> list[Resource]:
553581
if not self.capabilities.resources:
554582
return []
555583
try:
556-
result = await self._client.list_resources()
584+
# caching logic same as list_tools
585+
if self._server_capabilities.resources_list_changed:
586+
if self._cached_resources is not None and self._resources_cache_valid:
587+
return self._cached_resources
588+
589+
result = await self._client.list_resources()
590+
resources = [Resource.from_mcp_sdk(r) for r in result.resources]
591+
self._cached_resources = resources
592+
self._resources_cache_valid = True
593+
return resources
594+
else:
595+
result = await self._client.list_resources()
596+
return [Resource.from_mcp_sdk(r) for r in result.resources]
557597
except mcp_exceptions.McpError as e:
558598
raise MCPError.from_mcp_sdk(e) from e
559-
return [Resource.from_mcp_sdk(r) for r in result.resources]
560599

561600
async def list_resource_templates(self) -> list[ResourceTemplate]:
562601
"""Retrieve resource templates that are currently present on the server.
@@ -619,6 +658,12 @@ async def __aenter__(self) -> Self:
619658
"""
620659
async with self._enter_lock:
621660
if self._running_count == 0:
661+
# Invalidate caches on fresh connection
662+
self._cached_tools = None
663+
self._tools_cache_valid = False
664+
self._cached_resources = None
665+
self._resources_cache_valid = False
666+
622667
async with AsyncExitStack() as exit_stack:
623668
self._read_stream, self._write_stream = await exit_stack.enter_async_context(self.client_streams())
624669
client = ClientSession(
@@ -628,6 +673,7 @@ async def __aenter__(self) -> Self:
628673
elicitation_callback=self.elicitation_callback,
629674
logging_callback=self.log_handler,
630675
read_timeout_seconds=timedelta(seconds=self.read_timeout),
676+
message_handler=self._handle_notification,
631677
)
632678
self._client = await exit_stack.enter_async_context(client)
633679

@@ -680,6 +726,13 @@ async def _sampling_callback(
680726
model=self.sampling_model.model_name,
681727
)
682728

729+
async def _handle_notification(self, message: Any) -> None:
730+
"""Handle notifications from the MCP server, invalidating caches as needed."""
731+
if isinstance(message, mcp_types.ToolListChangedNotification):
732+
self._tools_cache_valid = False
733+
elif isinstance(message, mcp_types.ResourceListChangedNotification):
734+
self._resources_cache_valid = False
735+
683736
async def _map_tool_result_part(
684737
self, part: mcp_types.ContentBlock
685738
) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:

tests/test_mcp.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@
5454
ElicitRequestParams,
5555
ElicitResult,
5656
ImageContent,
57+
ResourceListChangedNotification,
5758
TextContent,
59+
ToolListChangedNotification,
5860
)
5961

6062
from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response
@@ -1987,3 +1989,165 @@ async def test_custom_http_client_not_closed():
19871989
assert len(tools) > 0
19881990

19891991
assert not custom_http_client.is_closed
1992+
1993+
1994+
# ============================================================================
1995+
# Tool and Resource Caching Tests
1996+
# ============================================================================
1997+
1998+
1999+
async def test_tools_caching_with_list_changed_capability(mcp_server: MCPServerStdio) -> None:
2000+
"""Test that list_tools() caches results when server supports listChanged notifications."""
2001+
async with mcp_server:
2002+
# Mock the server capabilities to indicate listChanged is supported
2003+
mcp_server._server_capabilities.tools_list_changed = True # pyright: ignore[reportPrivateUsage]
2004+
2005+
# First call - should fetch from server and cache
2006+
tools1 = await mcp_server.list_tools()
2007+
assert len(tools1) > 0
2008+
assert mcp_server._cached_tools is not None # pyright: ignore[reportPrivateUsage]
2009+
assert mcp_server._tools_cache_valid is True # pyright: ignore[reportPrivateUsage]
2010+
2011+
# Mock _client.list_tools to track if it's called again
2012+
original_list_tools = mcp_server._client.list_tools # pyright: ignore[reportPrivateUsage]
2013+
call_count = 0
2014+
2015+
async def mock_list_tools(): # pragma: no cover
2016+
nonlocal call_count
2017+
call_count += 1
2018+
return await original_list_tools()
2019+
2020+
mcp_server._client.list_tools = mock_list_tools # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue]
2021+
2022+
# Second call - should return cached value without calling server
2023+
tools2 = await mcp_server.list_tools()
2024+
assert tools2 == tools1
2025+
assert call_count == 0 # list_tools should not have been called
2026+
2027+
2028+
async def test_tools_no_caching_without_list_changed_capability(mcp_server: MCPServerStdio) -> None:
2029+
"""Test that list_tools() always fetches fresh when server doesn't support listChanged."""
2030+
async with mcp_server:
2031+
# Verify the server doesn't advertise listChanged by default
2032+
# (this depends on the test MCP server implementation)
2033+
mcp_server._server_capabilities.tools_list_changed = False # pyright: ignore[reportPrivateUsage]
2034+
2035+
# First call
2036+
tools1 = await mcp_server.list_tools()
2037+
assert len(tools1) > 0
2038+
2039+
# Mock _client.list_tools to track calls
2040+
original_list_tools = mcp_server._client.list_tools # pyright: ignore[reportPrivateUsage]
2041+
call_count = 0
2042+
2043+
async def mock_list_tools():
2044+
nonlocal call_count
2045+
call_count += 1
2046+
return await original_list_tools()
2047+
2048+
mcp_server._client.list_tools = mock_list_tools # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue]
2049+
2050+
# Second call - should fetch fresh since no listChanged capability
2051+
tools2 = await mcp_server.list_tools()
2052+
assert tools2 == tools1
2053+
assert call_count == 1 # list_tools should have been called
2054+
2055+
2056+
async def test_tools_cache_invalidation_on_notification(mcp_server: MCPServerStdio) -> None:
2057+
"""Test that tools cache is invalidated when ToolListChangedNotification is received."""
2058+
async with mcp_server:
2059+
# Enable caching
2060+
mcp_server._server_capabilities.tools_list_changed = True # pyright: ignore[reportPrivateUsage]
2061+
2062+
# Populate cache
2063+
await mcp_server.list_tools()
2064+
assert mcp_server._tools_cache_valid is True # pyright: ignore[reportPrivateUsage]
2065+
2066+
# Simulate receiving a tool list changed notification
2067+
notification = ToolListChangedNotification()
2068+
await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage]
2069+
2070+
# Cache should be invalidated
2071+
assert mcp_server._tools_cache_valid is False # pyright: ignore[reportPrivateUsage]
2072+
2073+
# Cached tools are still present but marked invalid
2074+
assert mcp_server._cached_tools is not None # pyright: ignore[reportPrivateUsage]
2075+
2076+
2077+
async def test_resources_caching_with_list_changed_capability(mcp_server: MCPServerStdio) -> None:
2078+
"""Test that list_resources() caches results when server supports listChanged notifications."""
2079+
async with mcp_server:
2080+
# Mock the server capabilities to indicate listChanged is supported
2081+
mcp_server._server_capabilities.resources_list_changed = True # pyright: ignore[reportPrivateUsage]
2082+
2083+
# First call - should fetch from server and cache
2084+
if mcp_server.capabilities.resources: # pragma: no branch
2085+
resources1 = await mcp_server.list_resources()
2086+
assert mcp_server._cached_resources is not None # pyright: ignore[reportPrivateUsage]
2087+
assert mcp_server._resources_cache_valid is True # pyright: ignore[reportPrivateUsage]
2088+
2089+
# Mock _client.list_resources to track if it's called again
2090+
original_list_resources = mcp_server._client.list_resources # pyright: ignore[reportPrivateUsage]
2091+
call_count = 0
2092+
2093+
async def mock_list_resources(): # pragma: no cover
2094+
nonlocal call_count
2095+
call_count += 1
2096+
return await original_list_resources()
2097+
2098+
mcp_server._client.list_resources = mock_list_resources # pyright: ignore[reportPrivateUsage,reportAttributeAccessIssue]
2099+
2100+
# Second call - should return cached value without calling server
2101+
resources2 = await mcp_server.list_resources()
2102+
assert resources2 == resources1
2103+
assert call_count == 0 # list_resources should not have been called
2104+
2105+
2106+
async def test_resources_cache_invalidation_on_notification(mcp_server: MCPServerStdio) -> None:
2107+
"""Test that resources cache is invalidated when ResourceListChangedNotification is received."""
2108+
async with mcp_server:
2109+
# Enable caching
2110+
mcp_server._server_capabilities.resources_list_changed = True # pyright: ignore[reportPrivateUsage]
2111+
2112+
# Populate cache (if server supports resources)
2113+
if mcp_server.capabilities.resources: # pragma: no branch
2114+
await mcp_server.list_resources()
2115+
assert mcp_server._resources_cache_valid is True # pyright: ignore[reportPrivateUsage]
2116+
2117+
# Simulate receiving a resource list changed notification
2118+
notification = ResourceListChangedNotification()
2119+
await mcp_server._handle_notification(notification) # pyright: ignore[reportPrivateUsage]
2120+
2121+
# Cache should be invalidated
2122+
assert mcp_server._resources_cache_valid is False # pyright: ignore[reportPrivateUsage]
2123+
2124+
2125+
async def test_cache_invalidation_on_reconnection() -> None:
2126+
"""Test that caches are cleared when reconnecting to the server."""
2127+
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
2128+
2129+
# First connection
2130+
async with server:
2131+
server._server_capabilities.tools_list_changed = True # pyright: ignore[reportPrivateUsage]
2132+
await server.list_tools()
2133+
assert server._cached_tools is not None # pyright: ignore[reportPrivateUsage]
2134+
assert server._tools_cache_valid is True # pyright: ignore[reportPrivateUsage]
2135+
2136+
# After exiting, the server is no longer running
2137+
# but cache state persists until next connection
2138+
2139+
# Reconnect
2140+
async with server:
2141+
# Cache should be cleared on fresh connection
2142+
assert server._cached_tools is None # pyright: ignore[reportPrivateUsage]
2143+
assert server._tools_cache_valid is False # pyright: ignore[reportPrivateUsage]
2144+
2145+
2146+
async def test_server_capabilities_list_changed_fields() -> None:
2147+
"""Test that ServerCapabilities correctly parses listChanged fields."""
2148+
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
2149+
async with server:
2150+
# Test that capabilities are accessible
2151+
caps = server.capabilities
2152+
assert isinstance(caps.tools_list_changed, bool)
2153+
assert isinstance(caps.resources_list_changed, bool)

0 commit comments

Comments
 (0)