44from mcp import ClientSession , Tool
55from mcp .types import ImageContent , TextContent
66from mcp .client .sse import sse_client
7+ from mcp .client .stdio import stdio_client , StdioServerParameters
78import jsonschema
89
910from .aswarm import Agent as SwarmAgent
@@ -23,7 +24,7 @@ def __init__(
2324 super ().__init__ (system = system , client = client )
2425
2526 self ._mcp_server_base_url : str = mcp_server_base_url
26- self ._mcp_sse_client : AsyncContextManager [tuple ] | None = None
27+ self ._mcp_client_transport : AsyncContextManager [tuple ] | None = None
2728 self ._mcp_client_session : ClientSession | None = None
2829
2930 self ._mcp_swarm_agent : SwarmAgent | None = None
@@ -34,7 +35,62 @@ def mcp_server_base_url(self) -> str:
3435 raise ValueError ("MCP server base URL is empty" )
3536 return self ._mcp_server_base_url
3637
37- def make_tool (self , t : Tool ) -> Callable :
38+ def _make_mcp_client_transport (self ) -> AsyncContextManager [tuple ]:
39+ if self .mcp_server_base_url .startswith (("http://" , "https://" )):
40+ url = urljoin (self .mcp_server_base_url , "sse" )
41+ return sse_client (url = url )
42+ else :
43+ # Mainly for testing purposes.
44+ command , arg = self .mcp_server_base_url .split (" " , 1 )
45+ params = StdioServerParameters (command = command , args = [arg ])
46+ return stdio_client (params )
47+
48+ async def started (self ) -> None :
49+ """
50+ Combining `started` and `stopped` to achieve the following behavior:
51+
52+ async with sse_client(url=url) as (read, write):
53+ async with ClientSession(read, write) as session:
54+ pass
55+ """
56+ self ._mcp_client_transport = self ._make_mcp_client_transport ()
57+ read , write = await self ._mcp_client_transport .__aenter__ ()
58+
59+ self ._mcp_client_session = ClientSession (read , write )
60+ await self ._mcp_client_session .__aenter__ ()
61+
62+ # Initialize the connection
63+ await self ._mcp_client_session .initialize ()
64+
65+ async def stopped (self ) -> None :
66+ await self ._mcp_client_session .__aexit__ (None , None , None )
67+ await self ._mcp_client_transport .__aexit__ (None , None , None )
68+
69+ async def _handle_data (self ) -> None :
70+ """Override the method to handle exceptions properly."""
71+ try :
72+ await super ()._handle_data ()
73+ finally :
74+ # Ensure the resources created in `started` are properly cleaned up.
75+ await self .stopped ()
76+
77+ async def get_swarm_agent (self ) -> SwarmAgent :
78+ if not self ._mcp_swarm_agent :
79+ tools = await self ._get_tools ()
80+ self ._mcp_swarm_agent = SwarmAgent (
81+ name = self .name ,
82+ model = self .client .model ,
83+ instructions = self .system ,
84+ functions = [wrap_error (t ) for t in tools ],
85+ )
86+ return self ._mcp_swarm_agent
87+
88+ async def _get_tools (self ) -> list [Callable ]:
89+ result = await self ._mcp_client_session .list_tools ()
90+ tools = [self ._make_tool (t ) for t in result .tools ]
91+ return tools
92+
93+ def _make_tool (self , t : Tool ) -> Callable :
3894 async def tool (** kwargs ) -> Any :
3995 # Validate the input against the schema
4096 jsonschema .validate (instance = kwargs , schema = t .inputSchema )
@@ -64,51 +120,5 @@ async def tool(**kwargs) -> Any:
64120 description = t .description ,
65121 parameters = t .inputSchema ,
66122 )
67- tool .__mcp_tool_args__ = t .inputSchema ["properties" ].keys ()
123+ tool .__mcp_tool_args__ = tuple ( t .inputSchema ["properties" ].keys () )
68124 return tool
69-
70- async def get_tools (self ) -> list [Callable ]:
71- result = await self ._mcp_client_session .list_tools ()
72- tools = [self .make_tool (t ) for t in result .tools ]
73- return tools
74-
75- async def get_swarm_agent (self ) -> SwarmAgent :
76- if not self ._mcp_swarm_agent :
77- tools = await self .get_tools ()
78- self ._mcp_swarm_agent = SwarmAgent (
79- name = self .name ,
80- model = self .client .model ,
81- instructions = self .system ,
82- functions = [wrap_error (t ) for t in tools ],
83- )
84- return self ._mcp_swarm_agent
85-
86- async def started (self ) -> None :
87- """
88- Combining `started` and `stopped` to achieve the following behavior:
89-
90- async with sse_client(url=url) as (read, write):
91- async with ClientSession(read, write) as session:
92- pass
93- """
94- url = urljoin (self .mcp_server_base_url , "sse" )
95- self ._mcp_sse_client = sse_client (url = url )
96- read , write = await self ._mcp_sse_client .__aenter__ ()
97-
98- self ._mcp_client_session = ClientSession (read , write )
99- await self ._mcp_client_session .__aenter__ ()
100-
101- # Initialize the connection
102- await self ._mcp_client_session .initialize ()
103-
104- async def stopped (self ) -> None :
105- await self ._mcp_client_session .__aexit__ (None , None , None )
106- await self ._mcp_sse_client .__aexit__ (None , None , None )
107-
108- async def _handle_data (self ) -> None :
109- """Override the method to handle exceptions properly."""
110- try :
111- await super ()._handle_data ()
112- finally :
113- # Ensure the resources created in `started` are properly cleaned up.
114- await self .stopped ()
0 commit comments