1 | 1 | from unittest.mock import AsyncMock, patch |
2 | 2 |
3 | 3 | import pytest |
4 | | -from mcp.types import ListToolsResult, Tool as MCPTool |
| 4 | +from mcp.types import ListPromptsResult, ListToolsResult, Prompt, Tool as MCPTool |
5 | 5 |
6 | 6 | from agents import Agent |
7 | 7 | from agents.mcp import MCPServerStdio |
|
14 | 14 | @patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) |
15 | 15 | @patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) |
16 | 16 | @patch("mcp.client.session.ClientSession.list_tools") |
17 | | -async def test_server_caching_works( |
| 17 | +async def test_server_caching_tools_works( |
18 | 18 | mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client |
19 | 19 | ): |
20 | 20 | """Test that if we turn caching on, the list of tools is cached and not fetched from the server |
@@ -61,3 +61,52 @@ async def test_server_caching_works( |
61 | 61 | # Without invalidating the cache, calling list_tools() again should return the cached value |
62 | 62 | result_tools = await server.list_tools(run_context, agent) |
63 | 63 | assert result_tools == tools |
| 64 | + |
| 65 | +@pytest.mark.asyncio |
| 66 | +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) |
| 67 | +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) |
| 68 | +@patch("mcp.client.session.ClientSession.list_tools") |
| 69 | +async def test_server_caching_prompts_works( |
| 70 | + mock_list_prompts: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client |
| 71 | +): |
| 72 | + """Test that if we turn caching on, the list of prompts is cached and not fetched from the server |
| 73 | + on each call to `list_prompts()`. |
| 74 | + """ |
| 75 | + server = MCPServerStdio( |
| 76 | + params={ |
| 77 | + "command": tee, |
| 78 | + }, |
| 79 | + cache_prompts_list=True, |
| 80 | + ) |
| 81 | + |
| 82 | + prompts = [ |
| 83 | + Prompt(name="prompt1"), |
| 84 | + Prompt(name="prompt2"), |
| 85 | + ] |
| 86 | + |
| 87 | + mock_list_prompts.return_value = ListPromptsResult(prompts=prompts) |
| 88 | + |
| 89 | + async with server: |
| 90 | + |
| 91 | + # Call list_prompts() multiple times |
| 92 | + result_prompts = await server.list_prompts() |
| 93 | + assert result_prompts == prompts |
| 94 | + |
| 95 | + assert mock_list_prompts.call_count == 1, "list_prompts() should have been called once" |
| 96 | + |
| 97 | + # Call list_prompts() again, should return the cached value |
| 98 | + result_prompts = await server.list_prompts() |
| 99 | + assert result_prompts == prompts |
| 100 | + |
| 101 | + assert mock_list_prompts.call_count == 1, "list_prompts() should not have been called again" |
| 102 | + |
| 103 | + # Invalidate the cache and call list_prompts() again |
| 104 | + server.invalidate_prompts_cache() |
| 105 | + result_prompts = await server.list_prompts() |
| 106 | + assert result_prompts == prompts |
| 107 | + |
| 108 | + assert mock_list_prompts.call_count == 2, "list_prompts() should be called again" |
| 109 | + |
| 110 | + # Without invalidating the cache, calling list_prompts() again should return the cached value |
| 111 | + result_prompts = await server.list_prompts() |
| 112 | + assert result_prompts == prompts |