1111from dataclasses import field , replace
1212from datetime import timedelta
1313from pathlib import Path
14- from typing import Any
14+ from typing import Any , Literal
1515
1616import anyio
1717import httpx
1818import pydantic_core
1919from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
20+ from pydantic import BaseModel , HttpUrl
2021from typing_extensions import Self , assert_never , deprecated
2122
2223from pydantic_ai .tools import RunContext , ToolDefinition
@@ -838,21 +839,35 @@ def _transport_client(self):
838839"""
839840
840841
842+ class StdioOptions (BaseModel ):
843+ command : str
844+ args : Sequence [str ] = ()
845+ env : dict [str , Any ] | None = None
846+
847+
848+ class ServerOptions (BaseModel ):
849+ type : Literal ['sse' , 'http' , 'streamable-http' , 'streamable_http' ]
850+ url : HttpUrl
851+ headers : dict [str , Any ] | None = None
852+
853+
841854def load_mcp_servers (
842- config_path : str = '' ,
855+ config_path : str | None = None ,
843856 / ,
844857 * ,
845858 mcp_config : dict [str , Any ] = {},
846- server_options : dict [str , Any ] = {},
847- stdio_options : dict [str , Any ] = {},
859+ timeout : float = 5 ,
860+ read_timeout : float = 5 * 60 ,
861+ max_retries : int = 1 ,
848862) -> list [MCPServer ]:
849863 """Load MCP servers from configuration file.
850864
851865 Args:
852866 config_path (str): The path to the MCP configuration file.
853867 mcp_config (dict): A dictionary containing the MCP server configuration. If provided, `config_path` is ignored.
854- server_options (dict): Additional options to pass to the `MCPServerSSE` or `MCPServerStreamableHTTP`.
855- stdio_options (dict): Additional options to pass to the `MCPServerStdio`.
868+ timeout (float): The timeout in seconds to wait for the client to initialize.
869+ read_timeout (float): Maximum time in seconds to wait for new messages before timing out.
870+ max_retries (int): The maximum number of times to retry a tool call.
856871
857872 Returns:
858873 list[MCPServer]: A list of MCP servers.
@@ -866,22 +881,38 @@ def load_mcp_servers(
866881
867882 for name , server in config .get ('mcpServers' , {}).items ():
868883 if 'command' in server :
884+ options = StdioOptions (** server )
869885 mcp_server = MCPServerStdio (
870- command = server ['command' ], args = server .get ('args' , []), env = server .get ('env' ), id = name , ** stdio_options
886+ command = options .command ,
887+ args = options .args ,
888+ env = options .env ,
889+ id = name ,
890+ timeout = timeout ,
891+ read_timeout = read_timeout ,
892+ max_retries = max_retries ,
871893 )
872- elif 'url' in server :
873- if not server .get ('type' ):
874- raise ValueError (f'MCP server type is required for { name !r} ' )
875- elif server .get ('type' ) == 'sse' :
876- mcp_server = MCPServerSSE (url = server ['url' ], headers = server .get ('headers' ), id = name , ** server_options )
877- elif server .get ('type' ) == 'http' :
894+ else :
895+ options = ServerOptions (** server )
896+ if options .type == 'sse' :
897+ mcp_server = MCPServerSSE (
898+ url = options .url ,
899+ headers = options .headers ,
900+ id = name ,
901+ timeout = timeout ,
902+ read_timeout = read_timeout ,
903+ max_retries = max_retries ,
904+ )
905+ elif options .type in ('http' , 'streamable-http' , 'streamable_http' ):
878906 mcp_server = MCPServerStreamableHTTP (
879- url = server ['url' ], headers = server .get ('headers' ), id = name , ** server_options
907+ url = options .url ,
908+ headers = options .headers ,
909+ id = name ,
910+ timeout = timeout ,
911+ read_timeout = read_timeout ,
912+ max_retries = max_retries ,
880913 )
881914 else :
882915 raise ValueError (f'Invalid MCP server type for { name !r} ' )
883- else :
884- raise ValueError (f'Invalid MCP server configuration for { name !r} ' )
885916
886917 mcp_servers .append (mcp_server )
887918
0 commit comments