Skip to content

Commit 2b2fd29

Browse files
committed
Add tests for MCPAgent
1 parent 3a6c345 commit 2b2fd29

File tree

3 files changed

+109
-49
lines changed

3 files changed

+109
-49
lines changed

coagent/agents/mcp_agent.py

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from mcp import ClientSession, Tool
55
from mcp.types import ImageContent, TextContent
66
from mcp.client.sse import sse_client
7+
from mcp.client.stdio import stdio_client, StdioServerParameters
78
import jsonschema
89

910
from .aswarm import Agent as SwarmAgent
@@ -23,7 +24,7 @@ def __init__(
2324
super().__init__(system=system, client=client)
2425

2526
self._mcp_server_base_url: str = mcp_server_base_url
26-
self._mcp_sse_client: AsyncContextManager[tuple] | None = None
27+
self._mcp_client_transport: AsyncContextManager[tuple] | None = None
2728
self._mcp_client_session: ClientSession | None = None
2829

2930
self._mcp_swarm_agent: SwarmAgent | None = None
@@ -34,7 +35,62 @@ def mcp_server_base_url(self) -> str:
3435
raise ValueError("MCP server base URL is empty")
3536
return self._mcp_server_base_url
3637

37-
def make_tool(self, t: Tool) -> Callable:
38+
def _make_mcp_client_transport(self) -> AsyncContextManager[tuple]:
39+
if self.mcp_server_base_url.startswith(("http://", "https://")):
40+
url = urljoin(self.mcp_server_base_url, "sse")
41+
return sse_client(url=url)
42+
else:
43+
# Mainly for testing purposes.
44+
command, arg = self.mcp_server_base_url.split(" ", 1)
45+
params = StdioServerParameters(command=command, args=[arg])
46+
return stdio_client(params)
47+
48+
async def started(self) -> None:
49+
"""
50+
Combining `started` and `stopped` to achieve the following behavior:
51+
52+
async with sse_client(url=url) as (read, write):
53+
async with ClientSession(read, write) as session:
54+
pass
55+
"""
56+
self._mcp_client_transport = self._make_mcp_client_transport()
57+
read, write = await self._mcp_client_transport.__aenter__()
58+
59+
self._mcp_client_session = ClientSession(read, write)
60+
await self._mcp_client_session.__aenter__()
61+
62+
# Initialize the connection
63+
await self._mcp_client_session.initialize()
64+
65+
async def stopped(self) -> None:
66+
await self._mcp_client_session.__aexit__(None, None, None)
67+
await self._mcp_client_transport.__aexit__(None, None, None)
68+
69+
async def _handle_data(self) -> None:
70+
"""Override the method to handle exceptions properly."""
71+
try:
72+
await super()._handle_data()
73+
finally:
74+
# Ensure the resources created in `started` are properly cleaned up.
75+
await self.stopped()
76+
77+
async def get_swarm_agent(self) -> SwarmAgent:
78+
if not self._mcp_swarm_agent:
79+
tools = await self._get_tools()
80+
self._mcp_swarm_agent = SwarmAgent(
81+
name=self.name,
82+
model=self.client.model,
83+
instructions=self.system,
84+
functions=[wrap_error(t) for t in tools],
85+
)
86+
return self._mcp_swarm_agent
87+
88+
async def _get_tools(self) -> list[Callable]:
89+
result = await self._mcp_client_session.list_tools()
90+
tools = [self._make_tool(t) for t in result.tools]
91+
return tools
92+
93+
def _make_tool(self, t: Tool) -> Callable:
3894
async def tool(**kwargs) -> Any:
3995
# Validate the input against the schema
4096
jsonschema.validate(instance=kwargs, schema=t.inputSchema)
@@ -64,51 +120,5 @@ async def tool(**kwargs) -> Any:
64120
description=t.description,
65121
parameters=t.inputSchema,
66122
)
67-
tool.__mcp_tool_args__ = t.inputSchema["properties"].keys()
123+
tool.__mcp_tool_args__ = tuple(t.inputSchema["properties"].keys())
68124
return tool
69-
70-
async def get_tools(self) -> list[Callable]:
71-
result = await self._mcp_client_session.list_tools()
72-
tools = [self.make_tool(t) for t in result.tools]
73-
return tools
74-
75-
async def get_swarm_agent(self) -> SwarmAgent:
76-
if not self._mcp_swarm_agent:
77-
tools = await self.get_tools()
78-
self._mcp_swarm_agent = SwarmAgent(
79-
name=self.name,
80-
model=self.client.model,
81-
instructions=self.system,
82-
functions=[wrap_error(t) for t in tools],
83-
)
84-
return self._mcp_swarm_agent
85-
86-
async def started(self) -> None:
87-
"""
88-
Combining `started` and `stopped` to achieve the following behavior:
89-
90-
async with sse_client(url=url) as (read, write):
91-
async with ClientSession(read, write) as session:
92-
pass
93-
"""
94-
url = urljoin(self.mcp_server_base_url, "sse")
95-
self._mcp_sse_client = sse_client(url=url)
96-
read, write = await self._mcp_sse_client.__aenter__()
97-
98-
self._mcp_client_session = ClientSession(read, write)
99-
await self._mcp_client_session.__aenter__()
100-
101-
# Initialize the connection
102-
await self._mcp_client_session.initialize()
103-
104-
async def stopped(self) -> None:
105-
await self._mcp_client_session.__aexit__(None, None, None)
106-
await self._mcp_sse_client.__aexit__(None, None, None)
107-
108-
async def _handle_data(self) -> None:
109-
"""Override the method to handle exceptions properly."""
110-
try:
111-
await super()._handle_data()
112-
finally:
113-
# Ensure the resources created in `started` are properly cleaned up.
114-
await self.stopped()

tests/agents/mcp_server.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from mcp.server.fastmcp import FastMCP
2+
3+
mcp = FastMCP("Weather")
4+
5+
6+
@mcp.tool()
7+
def query_weather(city: str) -> str:
8+
"""Query the weather in the given city."""
9+
return f"The weather in {city} is sunny."
10+
11+
12+
if __name__ == "__main__":
13+
mcp.run()

tests/agents/test_mcp_agent.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
3+
from coagent.agents.mcp_agent import MCPAgent
4+
5+
6+
class TestMCPAgent:
7+
@pytest.mark.asyncio
8+
async def test_get_tools(self):
9+
agent = MCPAgent(mcp_server_base_url="python tests/agents/mcp_server.py")
10+
await agent.started()
11+
12+
tools = await agent._get_tools()
13+
tool = tools[0]
14+
15+
assert tool.__name__ == "query_weather"
16+
assert tool.__doc__ == "Query the weather in the given city."
17+
assert tool.__mcp_tool_schema__ == {
18+
"description": "Query the weather in the given city.",
19+
"name": "query_weather",
20+
"parameters": {
21+
"properties": {
22+
"city": {
23+
"title": "City",
24+
"type": "string",
25+
}
26+
},
27+
"required": ["city"],
28+
"title": "query_weatherArguments",
29+
"type": "object",
30+
},
31+
}
32+
assert tool.__mcp_tool_args__ == ("city",)
33+
34+
result = await tool(city="Beijing")
35+
assert result == "The weather in Beijing is sunny."
36+
37+
await agent.stopped()

0 commit comments

Comments
 (0)