Skip to content

Commit 7f68fab

Browse files
committed
Add unit test test_server_caching_prompts_works
1 parent 90a349f commit 7f68fab

File tree

1 file changed

+51
-2
lines changed

1 file changed

+51
-2
lines changed

tests/mcp/test_caching.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from unittest.mock import AsyncMock, patch
22

33
import pytest
4-
from mcp.types import ListToolsResult, Tool as MCPTool
4+
from mcp.types import ListPromptsResult, ListToolsResult, Prompt, Tool as MCPTool
55

66
from agents import Agent
77
from agents.mcp import MCPServerStdio
@@ -14,7 +14,7 @@
1414
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
1515
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
1616
@patch("mcp.client.session.ClientSession.list_tools")
17-
async def test_server_caching_works(
17+
async def test_server_caching_tools_works(
1818
mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
1919
):
2020
"""Test that if we turn caching on, the list of tools is cached and not fetched from the server
@@ -61,3 +61,52 @@ async def test_server_caching_works(
6161
# Without invalidating the cache, calling list_tools() again should return the cached value
6262
result_tools = await server.list_tools(run_context, agent)
6363
assert result_tools == tools
64+
65+
@pytest.mark.asyncio
66+
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
67+
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
68+
@patch("mcp.client.session.ClientSession.list_tools")
69+
async def test_server_caching_prompts_works(
70+
mock_list_prompts: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
71+
):
72+
"""Test that if we turn caching on, the list of prompts is cached and not fetched from the server
73+
on each call to `list_prompts()`.
74+
"""
75+
server = MCPServerStdio(
76+
params={
77+
"command": tee,
78+
},
79+
cache_prompts_list=True,
80+
)
81+
82+
prompts = [
83+
Prompt(name="prompt1"),
84+
Prompt(name="prompt2"),
85+
]
86+
87+
mock_list_prompts.return_value = ListPromptsResult(prompts=prompts)
88+
89+
async with server:
90+
91+
# Call list_prompts() multiple times
92+
result_prompts = await server.list_prompts()
93+
assert result_prompts == prompts
94+
95+
assert mock_list_prompts.call_count == 1, "list_prompts() should have been called once"
96+
97+
# Call list_prompts() again, should return the cached value
98+
result_prompts = await server.list_prompts()
99+
assert result_prompts == prompts
100+
101+
assert mock_list_prompts.call_count == 1, "list_prompts() should not have been called again"
102+
103+
# Invalidate the cache and call list_prompts() again
104+
server.invalidate_prompts_cache()
105+
result_prompts = await server.list_prompts()
106+
assert result_prompts == prompts
107+
108+
assert mock_list_prompts.call_count == 2, "list_prompts() should be called again"
109+
110+
# Without invalidating the cache, calling list_prompts() again should return the cached value
111+
result_prompts = await server.list_prompts()
112+
assert result_prompts == prompts

0 commit comments

Comments
 (0)