Skip to content

Commit 8397b17

Browse files
committed
Fixing unit tests
1 parent 74ae4f7 commit 8397b17

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

examples/mcp/caching/main.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import time
66
from typing import Any
77

8+
from mcp.types import ListPromptsResult
9+
810
from agents import gen_trace_id, trace
911
from agents.mcp import MCPServerStreamableHttp
1012

@@ -25,8 +27,11 @@ async def run(mcp_server: MCPServerStreamableHttp):
2527
print("Cached prompts after invoking list_prompts")
2628
await mcp_server.list_prompts()
2729
cached_prompts_list = mcp_server._prompts_list
28-
for prompt in cached_prompts_list.prompts:
29-
print(f"name: {prompt.name}")
30+
if isinstance(cached_prompts_list, ListPromptsResult):
31+
for prompt in cached_prompts_list.prompts:
32+
print(f"name: {prompt.name}")
33+
else:
34+
print("Failed to cache list_prompts")
3035

3136
async def main():
3237
async with MCPServerStreamableHttp(

tests/mcp/test_caching.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,32 +84,33 @@ async def test_server_caching_prompts_works(
8484
Prompt(name="prompt2"),
8585
]
8686

87-
mock_list_prompts.return_value = ListPromptsResult(prompts=prompts)
87+
list_prompts = ListPromptsResult(prompts=prompts)
88+
mock_list_prompts.return_value = list_prompts
8889

8990
async with server:
9091

9192
# Call list_prompts() multiple times
9293
result_prompts = await server.list_prompts()
93-
assert result_prompts == prompts
94+
assert result_prompts == list_prompts
9495

9596
assert mock_list_prompts.call_count == 1, "list_prompts() should have been called once"
9697

9798
# Call list_prompts() again, should return the cached value
9899
result_prompts = await server.list_prompts()
99-
assert result_prompts == prompts
100+
assert result_prompts == list_prompts
100101

101102
assert mock_list_prompts.call_count == 1, ("list_prompts() "
102103
"should not have been called again")
103104

104105
# Invalidate the cache and call list_prompts() again
105106
server.invalidate_prompts_cache()
106107
result_prompts = await server.list_prompts()
107-
assert result_prompts == prompts
108+
assert result_prompts == list_prompts
108109

109110
assert mock_list_prompts.call_count == 2, ("list_prompts() "
110111
"should be called again")
111112

112113
# Without invalidating the cache, calling list_prompts()
113114
# again should return the cached value
114115
result_prompts = await server.list_prompts()
115-
assert result_prompts == prompts
116+
assert result_prompts == list_prompts

0 commit comments

Comments
 (0)