Skip to content

Commit 2c5041d

Browse files
authored
feat: add retry logic to MCP server operations (#1554)
## Summary - allow configuring retries with exponential backoff when listing tools or calling a tool on an MCP server via `max_retry_attempts` (supporting `-1` for unlimited retries) and `retry_backoff_seconds_base` - propagate the retry parameters through the stdio, SSE, and streamable HTTP server implementations so callers can tune retries when constructing these servers - test that `call_tool` and `list_tools` retry appropriately ## Testing - `make lint` - `make mypy` - `make test` *(fails: No rule to make target 'test')* - `make tests` - `make old_version_tests` *(fails: Request failed after 3 retries (tunnel error))* ------ https://chatgpt.com/codex/tasks/task_i_68a73cab8b9c8321876f6f3dd1dfcd20
1 parent dcf9cf7 commit 2c5041d

File tree

2 files changed

+120
-5
lines changed

2 files changed

+120
-5
lines changed

src/agents/mcp/server.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import abc
44
import asyncio
55
import inspect
6+
from collections.abc import Awaitable
67
from contextlib import AbstractAsyncContextManager, AsyncExitStack
78
from datetime import timedelta
89
from pathlib import Path
9-
from typing import TYPE_CHECKING, Any, Literal
10+
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
1011

1112
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1213
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
@@ -21,6 +22,8 @@
2122
from ..run_context import RunContextWrapper
2223
from .util import ToolFilter, ToolFilterContext, ToolFilterStatic
2324

25+
T = TypeVar("T")
26+
2427
if TYPE_CHECKING:
2528
from ..agent import AgentBase
2629

@@ -98,6 +101,8 @@ def __init__(
98101
client_session_timeout_seconds: float | None,
99102
tool_filter: ToolFilter = None,
100103
use_structured_content: bool = False,
104+
max_retry_attempts: int = 0,
105+
retry_backoff_seconds_base: float = 1.0,
101106
):
102107
"""
103108
Args:
@@ -115,6 +120,10 @@ def __init__(
115120
include the structured content in the `tool_result.content`, and using it by
116121
default will cause duplicate content. You can set this to True if you know the
117122
server will not duplicate the structured content in the `tool_result.content`.
123+
max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
124+
Defaults to no retries.
125+
retry_backoff_seconds_base: The base delay, in seconds, used for exponential
126+
backoff between retries.
118127
"""
119128
super().__init__(use_structured_content=use_structured_content)
120129
self.session: ClientSession | None = None
@@ -124,6 +133,8 @@ def __init__(
124133
self.server_initialize_result: InitializeResult | None = None
125134

126135
self.client_session_timeout_seconds = client_session_timeout_seconds
136+
self.max_retry_attempts = max_retry_attempts
137+
self.retry_backoff_seconds_base = retry_backoff_seconds_base
127138

128139
# The cache is always dirty at startup, so that we fetch tools at least once
129140
self._cache_dirty = True
@@ -233,6 +244,18 @@ def invalidate_tools_cache(self):
233244
"""Invalidate the tools cache."""
234245
self._cache_dirty = True
235246

247+
async def _run_with_retries(self, func: Callable[[], Awaitable[T]]) -> T:
248+
attempts = 0
249+
while True:
250+
try:
251+
return await func()
252+
except Exception:
253+
attempts += 1
254+
if self.max_retry_attempts != -1 and attempts > self.max_retry_attempts:
255+
raise
256+
backoff = self.retry_backoff_seconds_base * (2 ** (attempts - 1))
257+
await asyncio.sleep(backoff)
258+
236259
async def connect(self):
237260
"""Connect to the server."""
238261
try:
@@ -267,15 +290,17 @@ async def list_tools(
267290
"""List the tools available on the server."""
268291
if not self.session:
269292
raise UserError("Server not initialized. Make sure you call `connect()` first.")
293+
session = self.session
294+
assert session is not None
270295

271296
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
272297
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
273298
tools = self._tools_list
274299
else:
275-
# Reset the cache dirty to False
276-
self._cache_dirty = False
277300
# Fetch the tools from the server
278-
self._tools_list = (await self.session.list_tools()).tools
301+
result = await self._run_with_retries(lambda: session.list_tools())
302+
self._tools_list = result.tools
303+
self._cache_dirty = False
279304
tools = self._tools_list
280305

281306
# Filter tools based on tool_filter
@@ -290,8 +315,10 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
290315
"""Invoke a tool on the server."""
291316
if not self.session:
292317
raise UserError("Server not initialized. Make sure you call `connect()` first.")
318+
session = self.session
319+
assert session is not None
293320

294-
return await self.session.call_tool(tool_name, arguments)
321+
return await self._run_with_retries(lambda: session.call_tool(tool_name, arguments))
295322

296323
async def list_prompts(
297324
self,
@@ -365,6 +392,8 @@ def __init__(
365392
client_session_timeout_seconds: float | None = 5,
366393
tool_filter: ToolFilter = None,
367394
use_structured_content: bool = False,
395+
max_retry_attempts: int = 0,
396+
retry_backoff_seconds_base: float = 1.0,
368397
):
369398
"""Create a new MCP server based on the stdio transport.
370399
@@ -388,12 +417,18 @@ def __init__(
388417
include the structured content in the `tool_result.content`, and using it by
389418
default will cause duplicate content. You can set this to True if you know the
390419
server will not duplicate the structured content in the `tool_result.content`.
420+
max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
421+
Defaults to no retries.
422+
retry_backoff_seconds_base: The base delay, in seconds, for exponential
423+
backoff between retries.
391424
"""
392425
super().__init__(
393426
cache_tools_list,
394427
client_session_timeout_seconds,
395428
tool_filter,
396429
use_structured_content,
430+
max_retry_attempts,
431+
retry_backoff_seconds_base,
397432
)
398433

399434
self.params = StdioServerParameters(
@@ -455,6 +490,8 @@ def __init__(
455490
client_session_timeout_seconds: float | None = 5,
456491
tool_filter: ToolFilter = None,
457492
use_structured_content: bool = False,
493+
max_retry_attempts: int = 0,
494+
retry_backoff_seconds_base: float = 1.0,
458495
):
459496
"""Create a new MCP server based on the HTTP with SSE transport.
460497
@@ -480,12 +517,18 @@ def __init__(
480517
include the structured content in the `tool_result.content`, and using it by
481518
default will cause duplicate content. You can set this to True if you know the
482519
server will not duplicate the structured content in the `tool_result.content`.
520+
max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
521+
Defaults to no retries.
522+
retry_backoff_seconds_base: The base delay, in seconds, for exponential
523+
backoff between retries.
483524
"""
484525
super().__init__(
485526
cache_tools_list,
486527
client_session_timeout_seconds,
487528
tool_filter,
488529
use_structured_content,
530+
max_retry_attempts,
531+
retry_backoff_seconds_base,
489532
)
490533

491534
self.params = params
@@ -547,6 +590,8 @@ def __init__(
547590
client_session_timeout_seconds: float | None = 5,
548591
tool_filter: ToolFilter = None,
549592
use_structured_content: bool = False,
593+
max_retry_attempts: int = 0,
594+
retry_backoff_seconds_base: float = 1.0,
550595
):
551596
"""Create a new MCP server based on the Streamable HTTP transport.
552597
@@ -573,12 +618,18 @@ def __init__(
573618
include the structured content in the `tool_result.content`, and using it by
574619
default will cause duplicate content. You can set this to True if you know the
575620
server will not duplicate the structured content in the `tool_result.content`.
621+
max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
622+
Defaults to no retries.
623+
retry_backoff_seconds_base: The base delay, in seconds, for exponential
624+
backoff between retries.
576625
"""
577626
super().__init__(
578627
cache_tools_list,
579628
client_session_timeout_seconds,
580629
tool_filter,
581630
use_structured_content,
631+
max_retry_attempts,
632+
retry_backoff_seconds_base,
582633
)
583634

584635
self.params = params
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from typing import cast
2+
3+
import pytest
4+
from mcp import ClientSession, Tool as MCPTool
5+
from mcp.types import CallToolResult, ListToolsResult
6+
7+
from agents.mcp.server import _MCPServerWithClientSession
8+
9+
10+
class DummySession:
11+
def __init__(self, fail_call_tool: int = 0, fail_list_tools: int = 0):
12+
self.fail_call_tool = fail_call_tool
13+
self.fail_list_tools = fail_list_tools
14+
self.call_tool_attempts = 0
15+
self.list_tools_attempts = 0
16+
17+
async def call_tool(self, tool_name, arguments):
18+
self.call_tool_attempts += 1
19+
if self.call_tool_attempts <= self.fail_call_tool:
20+
raise RuntimeError("call_tool failure")
21+
return CallToolResult(content=[])
22+
23+
async def list_tools(self):
24+
self.list_tools_attempts += 1
25+
if self.list_tools_attempts <= self.fail_list_tools:
26+
raise RuntimeError("list_tools failure")
27+
return ListToolsResult(tools=[MCPTool(name="tool", inputSchema={})])
28+
29+
30+
class DummyServer(_MCPServerWithClientSession):
31+
def __init__(self, session: DummySession, retries: int):
32+
super().__init__(
33+
cache_tools_list=False,
34+
client_session_timeout_seconds=None,
35+
max_retry_attempts=retries,
36+
retry_backoff_seconds_base=0,
37+
)
38+
self.session = cast(ClientSession, session)
39+
40+
def create_streams(self):
41+
raise NotImplementedError
42+
43+
@property
44+
def name(self) -> str:
45+
return "dummy"
46+
47+
48+
@pytest.mark.asyncio
49+
async def test_call_tool_retries_until_success():
50+
session = DummySession(fail_call_tool=2)
51+
server = DummyServer(session=session, retries=2)
52+
result = await server.call_tool("tool", None)
53+
assert isinstance(result, CallToolResult)
54+
assert session.call_tool_attempts == 3
55+
56+
57+
@pytest.mark.asyncio
58+
async def test_list_tools_unlimited_retries():
59+
session = DummySession(fail_list_tools=3)
60+
server = DummyServer(session=session, retries=-1)
61+
tools = await server.list_tools()
62+
assert len(tools) == 1
63+
assert tools[0].name == "tool"
64+
assert session.list_tools_attempts == 4

0 commit comments

Comments
 (0)