Skip to content

Commit a32d83a

Browse files
committed
Add tests for ChatAgent
1 parent 9bed96e commit a32d83a

File tree

2 files changed

+103
-2
lines changed

2 files changed

+103
-2
lines changed

coagent/core/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ async def run(
293293
msg: RawMessage,
294294
stream: bool = False,
295295
session_id: str = "",
296+
request: bool = True,
296297
timeout: float = 0.5,
297298
) -> AsyncIterator[RawMessage] | RawMessage | None:
298299
"""Create an agent and run it with the given message."""
@@ -303,7 +304,7 @@ async def run(
303304
addr = Address(name=self.name, id=session_id)
304305

305306
return await self.__runtime.channel.publish(
306-
addr, msg, stream=stream, request=True, timeout=timeout
307+
addr, msg, stream=stream, request=request, timeout=timeout
307308
)
308309

309310

tests/agents/test_chat_agent.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1+
import sys
12
from typing import AsyncIterator
23

34
from pydantic import Field
45
import pytest
56

6-
from coagent.agents.chat_agent import wrap_error
7+
from coagent.agents import ChatAgent
8+
from coagent.agents.chat_agent import (
9+
CallTool,
10+
CallToolResult,
11+
ListToolsResult,
12+
MCPTextContent,
13+
MCPTool,
14+
wrap_error,
15+
)
16+
from coagent.core import Address, RawMessage
17+
from coagent.core.runtime import NopChannel
18+
import jsonschema
719

820

921
@pytest.mark.asyncio
@@ -37,3 +49,91 @@ async def func(
3749

3850
result = await func(a=1, b=0)
3951
assert await anext(result) == "Error: division by zero"
52+
53+
54+
class MCPServerTestChannel(NopChannel):
55+
async def publish(
56+
self,
57+
addr: Address,
58+
msg: RawMessage,
59+
stream: bool = False,
60+
request: bool = False,
61+
reply: str = "",
62+
timeout: float = 0.5,
63+
probe: bool = True,
64+
) -> AsyncIterator[RawMessage] | RawMessage | None:
65+
match msg.header.type:
66+
case "ListTools":
67+
return ListToolsResult(
68+
tools=[
69+
MCPTool(
70+
name="query_weather",
71+
description="Query the weather in the given city.",
72+
inputSchema={
73+
"title": "query_weatherArguments",
74+
"type": "object",
75+
"properties": {
76+
"city": {
77+
"title": "City",
78+
"type": "string",
79+
},
80+
},
81+
"required": ["city"],
82+
},
83+
)
84+
]
85+
).encode()
86+
87+
case "CallTool":
88+
call_tool = CallTool.decode(msg)
89+
city = call_tool.arguments["city"]
90+
return CallToolResult(
91+
content=[
92+
MCPTextContent(
93+
type="text", text=f"The weather in {city} is sunny."
94+
)
95+
],
96+
).encode()
97+
98+
99+
class TestChatAgent:
100+
@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows.")
101+
@pytest.mark.asyncio
102+
async def test_get_mcp_tools(self):
103+
agent = ChatAgent()
104+
addr = Address(name="test", id="0")
105+
agent.init(MCPServerTestChannel(), addr)
106+
107+
tools = await agent._get_mcp_tools(["server1"])
108+
assert len(tools) == 1
109+
# Tool query_weather
110+
tool = tools[0]
111+
112+
# Validate the tool
113+
assert tool.__name__ == "query_weather"
114+
assert tool.__doc__ == "Query the weather in the given city."
115+
assert tool.__mcp_tool_schema__ == {
116+
"description": "Query the weather in the given city.",
117+
"name": "query_weather",
118+
"parameters": {
119+
"properties": {
120+
"city": {
121+
"title": "City",
122+
"type": "string",
123+
}
124+
},
125+
"required": ["city"],
126+
"title": "query_weatherArguments",
127+
"type": "object",
128+
},
129+
}
130+
assert tool.__mcp_tool_args__ == ("city",)
131+
132+
# Call the tool with no arguments
133+
with pytest.raises(jsonschema.exceptions.ValidationError) as exc:
134+
await tool()
135+
assert str(exc.value).startswith("'city' is a required property")
136+
137+
# Call the tool with required arguments
138+
result = await tool(city="Beijing")
139+
assert result == "The weather in Beijing is sunny."

0 commit comments

Comments
 (0)