Skip to content

Commit 204701c

Browse files
committed
Add prompts caching + working examples
1 parent eafd8df commit 204701c

File tree

4 files changed

+192
-15
lines changed

4 files changed

+192
-15
lines changed

examples/mcp/caching/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Caching Example
2+
3+
This example shows how to integrate tools and prompts caching using a Streamable HTTP server in [server.py](server.py).
4+
5+
Run the example via:
6+
7+
```
8+
uv run python examples/mcp/caching/main.py
9+
```
10+
11+
## Details
12+
13+
The example uses the `MCPServerStreamableHttp` class from `agents.mcp`. The server runs in a subprocess at `http://localhost:8000/mcp`.

examples/mcp/caching/main.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import asyncio
2+
import os
3+
import shutil
4+
import subprocess
5+
import time
6+
from typing import Any
7+
8+
from agents import gen_trace_id, trace
9+
from agents.mcp import MCPServerStreamableHttp
10+
11+
12+
async def run(mcp_server: MCPServerStreamableHttp):
    """Show the tools and prompts caches before and after the first fetch.

    The caches (`_tools_list` / `_prompts_list`) start as `None` and are only
    populated by the first `list_tools()` / `list_prompts()` call.
    """
    # Before the first call to list_tools(), the tools cache is empty.
    print("Cached tools before invoking list_tools")
    print(mcp_server._tools_list)
    await mcp_server.list_tools()
    print("Cached tool names after invoking list_tools")
    cached_tools_list = mcp_server._tools_list
    for tool in cached_tools_list:
        print(f"name: {tool.name}")

    # Same idea for prompts: the cache fills on the first list_prompts() call.
    print("Cached prompts before invoking list_prompts")
    print(mcp_server._prompts_list)
    await mcp_server.list_prompts()
    print("\nCached prompts after invoking list_prompts")
    cached_prompts_list = mcp_server._prompts_list
    for prompt in cached_prompts_list.prompts:
        print(f"name: {prompt.name}")
28+
29+
async def main():
    """Run the caching example against the local server, inside a trace."""
    connection_params = {
        "url": "http://localhost:8000/mcp",
    }
    async with MCPServerStreamableHttp(
        name="Streamable HTTP Python Server",
        cache_tools_list=True,
        cache_prompts_list=True,
        params=connection_params,
    ) as server:
        trace_id = gen_trace_id()
        with trace(workflow_name="Caching Example", trace_id=trace_id):
            print(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}\n")
            await run(server)
42+
43+
44+
if __name__ == "__main__":
    # This example shells out to `uv`, so fail fast if it isn't installed.
    if shutil.which("uv") is None:
        raise RuntimeError(
            "uv is not installed. Please install it: https://docs.astral.sh/uv/getting-started/installation/"
        )

    # We'll run the Streamable HTTP server in a subprocess. Usually this would be a remote server, but for this
    # demo, we'll run it locally at http://localhost:8000/mcp
    process: subprocess.Popen[Any] | None = None
    try:
        server_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "server.py")

        print("Starting Streamable HTTP server at http://localhost:8000/mcp ...")

        # Launch the server with `uv run server.py`, then give it a moment to boot.
        process = subprocess.Popen(["uv", "run", server_file])
        time.sleep(3)

        print("Streamable HTTP server started. Running example...\n\n")
    except Exception as e:
        print(f"Error starting Streamable HTTP server: {e}")
        exit(1)

    # Always terminate the child server, even if the example raises.
    try:
        asyncio.run(main())
    finally:
        if process:
            process.terminate()

examples/mcp/caching/server.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import random
2+
3+
import requests
4+
from mcp.server.fastmcp import FastMCP
5+
6+
# Create server
7+
mcp = FastMCP("Echo Server")
8+
9+
10+
@mcp.tool()
def add(a: int, b: int) -> int:
    """Return the sum of two integers."""
    print(f"[debug-server] add({a}, {b})")
    result = a + b
    return result
15+
16+
17+
@mcp.tool()
def get_secret_word() -> str:
    """Return one randomly chosen secret word."""
    print("[debug-server] get_secret_word()")
    words = ["apple", "banana", "cherry"]
    return random.choice(words)
21+
22+
23+
@mcp.tool()
def get_current_weather(city: str) -> str:
    """Return the current weather report for `city` from wttr.in."""
    print(f"[debug-server] get_current_weather({city})")

    endpoint = "https://wttr.in"
    # A timeout keeps the tool call from hanging indefinitely if wttr.in is
    # slow or unreachable; requests.get blocks forever without one.
    response = requests.get(f"{endpoint}/{city}", timeout=10)
    return response.text
30+
31+
@mcp.prompt()
def system_prompt() -> str:
    """System prompt directing the model to rely on the available tools."""
    return "Use the tools to answer the questions."
34+
35+
36+
if __name__ == "__main__":
    # Serve over the streamable HTTP transport so clients connect at /mcp.
    mcp.run(transport="streamable-http")

src/agents/mcp/server.py

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
8484
def __init__(
8585
self,
8686
cache_tools_list: bool,
87+
cache_prompts_list: bool,
8788
client_session_timeout_seconds: float | None,
8889
tool_filter: ToolFilter = None,
8990
):
@@ -96,20 +97,30 @@ def __init__(
9697
server will not change its tools list, because it can drastically improve latency
9798
(by avoiding a round-trip to the server every time).
9899
100+
cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list will be
101+
cached and only fetched from the server once. If `False`, the prompts list will be
102+
fetched from the server on each call to `list_prompts()`. The cache can be invalidated
103+
by calling `invalidate_prompts_cache()`. You should set this to `True` if you know the
104+
server will not change its prompts list, because it can drastically improve latency
105+
(by avoiding a round-trip to the server every time).
106+
99107
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
100108
tool_filter: The tool filter to use for filtering tools.
101109
"""
102110
self.session: ClientSession | None = None
103111
self.exit_stack: AsyncExitStack = AsyncExitStack()
104112
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
105113
self.cache_tools_list = cache_tools_list
114+
self.cache_prompts_list = cache_prompts_list
106115
self.server_initialize_result: InitializeResult | None = None
107116

108117
self.client_session_timeout_seconds = client_session_timeout_seconds
109118

110-
# The cache is always dirty at startup, so that we fetch tools at least once
111-
self._cache_dirty = True
119+
# The cache is always dirty at startup, so that we fetch tools and prompts at least once
120+
self._cache_dirty_tools = True
112121
self._tools_list: list[MCPTool] | None = None
122+
self._cache_dirty_prompts = True
123+
self._prompts_list: ListPromptsResult | None = None
113124

114125
self.tool_filter = tool_filter
115126

@@ -213,7 +224,11 @@ async def __aexit__(self, exc_type, exc_value, traceback):
213224

214225
def invalidate_tools_cache(self):
215226
"""Invalidate the tools cache."""
216-
self._cache_dirty = True
227+
self._cache_dirty_tools = True
228+
229+
def invalidate_prompts_cache(self):
230+
"""Invalidate the prompts cache."""
231+
self._cache_dirty_prompts = True
217232

218233
async def connect(self):
219234
"""Connect to the server."""
@@ -251,11 +266,11 @@ async def list_tools(
251266
raise UserError("Server not initialized. Make sure you call `connect()` first.")
252267

253268
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
254-
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
269+
if self.cache_tools_list and not self._cache_dirty_tools and self._tools_list:
255270
tools = self._tools_list
256271
else:
257272
# Reset the cache dirty to False
258-
self._cache_dirty = False
273+
self._cache_dirty_tools = False
259274
# Fetch the tools from the server
260275
self._tools_list = (await self.session.list_tools()).tools
261276
tools = self._tools_list
async def list_prompts(self) -> "ListPromptsResult":
    """List the prompts available on the server.

    When `cache_prompts_list` is `True`, the result is served from
    `self._prompts_list` until `invalidate_prompts_cache()` marks the cache
    dirty again.

    Raises:
        UserError: If `connect()` has not been called yet.
    """
    if not self.session:
        raise UserError("Server not initialized. Make sure you call `connect()` first.")

    # Return from cache if caching is enabled, we have prompts, and the cache is not dirty
    if self.cache_prompts_list and not self._cache_dirty_prompts and self._prompts_list:
        prompts = self._prompts_list
    else:
        # Reset the cache dirty to False
        self._cache_dirty_prompts = False
        # Fetch the prompts from the server
        self._prompts_list = await self.session.list_prompts()
        # Fix: return the freshly fetched PROMPTS; the previous code returned
        # self._tools_list here, handing callers the tools cache instead.
        prompts = self._prompts_list

    return prompts
286310

287311
async def get_prompt(
288312
self, name: str, arguments: dict[str, Any] | None = None
@@ -343,6 +367,7 @@ def __init__(
343367
self,
344368
params: MCPServerStdioParams,
345369
cache_tools_list: bool = False,
370+
cache_prompts_list: bool = False,
346371
name: str | None = None,
347372
client_session_timeout_seconds: float | None = 5,
348373
tool_filter: ToolFilter = None,
@@ -354,21 +379,31 @@ def __init__(
354379
start the server, the args to pass to the command, the environment variables to
355380
set for the server, the working directory to use when spawning the process, and
356381
the text encoding used when sending/receiving messages to the server.
382+
357383
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
358384
cached and only fetched from the server once. If `False`, the tools list will be
359385
fetched from the server on each call to `list_tools()`. The cache can be
360386
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
361387
if you know the server will not change its tools list, because it can drastically
362388
improve latency (by avoiding a round-trip to the server every time).
389+
390+
cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list will be
391+
cached and only fetched from the server once. If `False`, the prompts list will be
392+
fetched from the server on each call to `list_prompts()`. The cache can be invalidated
393+
by calling `invalidate_prompts_cache()`. You should set this to `True` if you know the
394+
server will not change its prompts list, because it can drastically improve latency
395+
(by avoiding a round-trip to the server every time).
396+
363397
name: A readable name for the server. If not provided, we'll create one from the
364398
command.
365399
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
366400
tool_filter: The tool filter to use for filtering tools.
367401
"""
368402
super().__init__(
369-
cache_tools_list,
370-
client_session_timeout_seconds,
371-
tool_filter,
403+
cache_tools_list=cache_tools_list,
404+
cache_prompts_list=cache_prompts_list,
405+
client_session_timeout_seconds=client_session_timeout_seconds,
406+
tool_filter=tool_filter,
372407
)
373408

374409
self.params = StdioServerParameters(
@@ -426,6 +461,7 @@ def __init__(
426461
self,
427462
params: MCPServerSseParams,
428463
cache_tools_list: bool = False,
464+
cache_prompts_list: bool = False,
429465
name: str | None = None,
430466
client_session_timeout_seconds: float | None = 5,
431467
tool_filter: ToolFilter = None,
@@ -444,16 +480,24 @@ def __init__(
444480
if you know the server will not change its tools list, because it can drastically
445481
improve latency (by avoiding a round-trip to the server every time).
446482
483+
cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list will be
484+
cached and only fetched from the server once. If `False`, the prompts list will be
485+
fetched from the server on each call to `list_prompts()`. The cache can be invalidated
486+
by calling `invalidate_prompts_cache()`. You should set this to `True` if you know the
487+
server will not change its prompts list, because it can drastically improve latency
488+
(by avoiding a round-trip to the server every time).
489+
447490
name: A readable name for the server. If not provided, we'll create one from the
448491
URL.
449492
450493
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
451494
tool_filter: The tool filter to use for filtering tools.
452495
"""
453496
super().__init__(
454-
cache_tools_list,
455-
client_session_timeout_seconds,
456-
tool_filter,
497+
cache_tools_list=cache_tools_list,
498+
cache_prompts_list=cache_prompts_list,
499+
client_session_timeout_seconds=client_session_timeout_seconds,
500+
tool_filter=tool_filter,
457501
)
458502

459503
self.params = params
@@ -511,6 +555,7 @@ def __init__(
511555
self,
512556
params: MCPServerStreamableHttpParams,
513557
cache_tools_list: bool = False,
558+
cache_prompts_list: bool = False,
514559
name: str | None = None,
515560
client_session_timeout_seconds: float | None = 5,
516561
tool_filter: ToolFilter = None,
@@ -530,16 +575,24 @@ def __init__(
530575
if you know the server will not change its tools list, because it can drastically
531576
improve latency (by avoiding a round-trip to the server every time).
532577
578+
cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list will be
579+
cached and only fetched from the server once. If `False`, the prompts list will be
580+
fetched from the server on each call to `list_prompts()`. The cache can be invalidated
581+
by calling `invalidate_prompts_cache()`. You should set this to `True` if you know the
582+
server will not change its prompts list, because it can drastically improve latency
583+
(by avoiding a round-trip to the server every time).
584+
533585
name: A readable name for the server. If not provided, we'll create one from the
534586
URL.
535587
536588
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
537589
tool_filter: The tool filter to use for filtering tools.
538590
"""
539591
super().__init__(
540-
cache_tools_list,
541-
client_session_timeout_seconds,
542-
tool_filter,
592+
cache_tools_list=cache_tools_list,
593+
cache_prompts_list=cache_prompts_list,
594+
client_session_timeout_seconds=client_session_timeout_seconds,
595+
tool_filter=tool_filter,
543596
)
544597

545598
self.params = params

0 commit comments

Comments
 (0)