3
3
import abc
4
4
import asyncio
5
5
import inspect
6
+ from collections .abc import Awaitable
6
7
from contextlib import AbstractAsyncContextManager , AsyncExitStack
7
8
from datetime import timedelta
8
9
from pathlib import Path
9
- from typing import TYPE_CHECKING , Any , Literal
10
+ from typing import TYPE_CHECKING , Any , Callable , Literal , TypeVar
10
11
11
12
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
12
13
from mcp import ClientSession , StdioServerParameters , Tool as MCPTool , stdio_client
21
22
from ..run_context import RunContextWrapper
22
23
from .util import ToolFilter , ToolFilterContext , ToolFilterStatic
23
24
25
+ T = TypeVar ("T" )
26
+
24
27
if TYPE_CHECKING :
25
28
from ..agent import AgentBase
26
29
@@ -98,6 +101,8 @@ def __init__(
98
101
client_session_timeout_seconds : float | None ,
99
102
tool_filter : ToolFilter = None ,
100
103
use_structured_content : bool = False ,
104
+ max_retry_attempts : int = 0 ,
105
+ retry_backoff_seconds_base : float = 1.0 ,
101
106
):
102
107
"""
103
108
Args:
@@ -115,6 +120,10 @@ def __init__(
115
120
include the structured content in the `tool_result.content`, and using it by
116
121
default will cause duplicate content. You can set this to True if you know the
117
122
server will not duplicate the structured content in the `tool_result.content`.
123
+ max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
124
+ Defaults to no retries.
125
+ retry_backoff_seconds_base: The base delay, in seconds, used for exponential
126
+ backoff between retries.
118
127
"""
119
128
super ().__init__ (use_structured_content = use_structured_content )
120
129
self .session : ClientSession | None = None
@@ -124,6 +133,8 @@ def __init__(
124
133
self .server_initialize_result : InitializeResult | None = None
125
134
126
135
self .client_session_timeout_seconds = client_session_timeout_seconds
136
+ self .max_retry_attempts = max_retry_attempts
137
+ self .retry_backoff_seconds_base = retry_backoff_seconds_base
127
138
128
139
# The cache is always dirty at startup, so that we fetch tools at least once
129
140
self ._cache_dirty = True
@@ -233,6 +244,18 @@ def invalidate_tools_cache(self):
233
244
"""Invalidate the tools cache."""
234
245
self ._cache_dirty = True
235
246
247
async def _run_with_retries(self, func: Callable[[], Awaitable[T]]) -> T:
    """Await ``func()`` and retry with exponential backoff when it raises.

    Args:
        func: Zero-argument callable that returns a fresh awaitable on every
            invocation; it is called again for each retry.

    Returns:
        Whatever the first successful ``await func()`` produces.

    Raises:
        Exception: The exception from the most recent failed call, once the
            number of failures exceeds ``self.max_retry_attempts``. A value
            of ``-1`` for ``self.max_retry_attempts`` disables the limit, so
            the call is retried until it succeeds.
    """
    failures = 0
    while True:
        try:
            outcome = await func()
        except Exception:
            failures += 1
            retry_limit = self.max_retry_attempts
            # -1 is the "retry forever" sentinel; any other value caps the
            # number of retries that may follow the initial call.
            if retry_limit != -1 and failures > retry_limit:
                raise
            # The delay doubles after each failure: base, 2*base, 4*base, ...
            delay = self.retry_backoff_seconds_base * (2 ** (failures - 1))
            await asyncio.sleep(delay)
        else:
            return outcome
258
+
236
259
async def connect (self ):
237
260
"""Connect to the server."""
238
261
try :
@@ -267,15 +290,17 @@ async def list_tools(
267
290
"""List the tools available on the server."""
268
291
if not self .session :
269
292
raise UserError ("Server not initialized. Make sure you call `connect()` first." )
293
+ session = self .session
294
+ assert session is not None
270
295
271
296
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
272
297
if self .cache_tools_list and not self ._cache_dirty and self ._tools_list :
273
298
tools = self ._tools_list
274
299
else :
275
- # Reset the cache dirty to False
276
- self ._cache_dirty = False
277
300
# Fetch the tools from the server
278
- self ._tools_list = (await self .session .list_tools ()).tools
301
+ result = await self ._run_with_retries (lambda : session .list_tools ())
302
+ self ._tools_list = result .tools
303
+ self ._cache_dirty = False
279
304
tools = self ._tools_list
280
305
281
306
# Filter tools based on tool_filter
@@ -290,8 +315,10 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
290
315
"""Invoke a tool on the server."""
291
316
if not self .session :
292
317
raise UserError ("Server not initialized. Make sure you call `connect()` first." )
318
+ session = self .session
319
+ assert session is not None
293
320
294
- return await self .session .call_tool (tool_name , arguments )
321
+ return await self ._run_with_retries ( lambda : session .call_tool (tool_name , arguments ) )
295
322
296
323
async def list_prompts (
297
324
self ,
@@ -365,6 +392,8 @@ def __init__(
365
392
client_session_timeout_seconds : float | None = 5 ,
366
393
tool_filter : ToolFilter = None ,
367
394
use_structured_content : bool = False ,
395
+ max_retry_attempts : int = 0 ,
396
+ retry_backoff_seconds_base : float = 1.0 ,
368
397
):
369
398
"""Create a new MCP server based on the stdio transport.
370
399
@@ -388,12 +417,18 @@ def __init__(
388
417
include the structured content in the `tool_result.content`, and using it by
389
418
default will cause duplicate content. You can set this to True if you know the
390
419
server will not duplicate the structured content in the `tool_result.content`.
420
+ max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
421
+ Defaults to no retries.
422
+ retry_backoff_seconds_base: The base delay, in seconds, for exponential
423
+ backoff between retries.
391
424
"""
392
425
super ().__init__ (
393
426
cache_tools_list ,
394
427
client_session_timeout_seconds ,
395
428
tool_filter ,
396
429
use_structured_content ,
430
+ max_retry_attempts ,
431
+ retry_backoff_seconds_base ,
397
432
)
398
433
399
434
self .params = StdioServerParameters (
@@ -455,6 +490,8 @@ def __init__(
455
490
client_session_timeout_seconds : float | None = 5 ,
456
491
tool_filter : ToolFilter = None ,
457
492
use_structured_content : bool = False ,
493
+ max_retry_attempts : int = 0 ,
494
+ retry_backoff_seconds_base : float = 1.0 ,
458
495
):
459
496
"""Create a new MCP server based on the HTTP with SSE transport.
460
497
@@ -480,12 +517,18 @@ def __init__(
480
517
include the structured content in the `tool_result.content`, and using it by
481
518
default will cause duplicate content. You can set this to True if you know the
482
519
server will not duplicate the structured content in the `tool_result.content`.
520
+ max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
521
+ Defaults to no retries.
522
+ retry_backoff_seconds_base: The base delay, in seconds, for exponential
523
+ backoff between retries.
483
524
"""
484
525
super ().__init__ (
485
526
cache_tools_list ,
486
527
client_session_timeout_seconds ,
487
528
tool_filter ,
488
529
use_structured_content ,
530
+ max_retry_attempts ,
531
+ retry_backoff_seconds_base ,
489
532
)
490
533
491
534
self .params = params
@@ -547,6 +590,8 @@ def __init__(
547
590
client_session_timeout_seconds : float | None = 5 ,
548
591
tool_filter : ToolFilter = None ,
549
592
use_structured_content : bool = False ,
593
+ max_retry_attempts : int = 0 ,
594
+ retry_backoff_seconds_base : float = 1.0 ,
550
595
):
551
596
"""Create a new MCP server based on the Streamable HTTP transport.
552
597
@@ -573,12 +618,18 @@ def __init__(
573
618
include the structured content in the `tool_result.content`, and using it by
574
619
default will cause duplicate content. You can set this to True if you know the
575
620
server will not duplicate the structured content in the `tool_result.content`.
621
+ max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
622
+ Defaults to no retries.
623
+ retry_backoff_seconds_base: The base delay, in seconds, for exponential
624
+ backoff between retries.
576
625
"""
577
626
super ().__init__ (
578
627
cache_tools_list ,
579
628
client_session_timeout_seconds ,
580
629
tool_filter ,
581
630
use_structured_content ,
631
+ max_retry_attempts ,
632
+ retry_backoff_seconds_base ,
582
633
)
583
634
584
635
self .params = params
0 commit comments