|
| 1 | +import sys |
1 | 2 | from typing import AsyncIterator |
2 | 3 |
|
3 | 4 | from pydantic import Field |
4 | 5 | import pytest |
5 | 6 |
|
6 | | -from coagent.agents.chat_agent import wrap_error |
| 7 | +from coagent.agents import ChatAgent |
| 8 | +from coagent.agents.chat_agent import ( |
| 9 | + CallTool, |
| 10 | + CallToolResult, |
| 11 | + ListToolsResult, |
| 12 | + MCPTextContent, |
| 13 | + MCPTool, |
| 14 | + wrap_error, |
| 15 | +) |
| 16 | +from coagent.core import Address, RawMessage |
| 17 | +from coagent.core.runtime import NopChannel |
| 18 | +import jsonschema |
7 | 19 |
|
8 | 20 |
|
9 | 21 | @pytest.mark.asyncio |
@@ -37,3 +49,91 @@ async def func( |
37 | 49 |
|
38 | 50 | result = await func(a=1, b=0) |
39 | 51 | assert await anext(result) == "Error: division by zero" |
| 52 | + |
| 53 | + |
| 54 | +class MCPServerTestChannel(NopChannel): |
| 55 | + async def publish( |
| 56 | + self, |
| 57 | + addr: Address, |
| 58 | + msg: RawMessage, |
| 59 | + stream: bool = False, |
| 60 | + request: bool = False, |
| 61 | + reply: str = "", |
| 62 | + timeout: float = 0.5, |
| 63 | + probe: bool = True, |
| 64 | + ) -> AsyncIterator[RawMessage] | RawMessage | None: |
| 65 | + match msg.header.type: |
| 66 | + case "ListTools": |
| 67 | + return ListToolsResult( |
| 68 | + tools=[ |
| 69 | + MCPTool( |
| 70 | + name="query_weather", |
| 71 | + description="Query the weather in the given city.", |
| 72 | + inputSchema={ |
| 73 | + "title": "query_weatherArguments", |
| 74 | + "type": "object", |
| 75 | + "properties": { |
| 76 | + "city": { |
| 77 | + "title": "City", |
| 78 | + "type": "string", |
| 79 | + }, |
| 80 | + }, |
| 81 | + "required": ["city"], |
| 82 | + }, |
| 83 | + ) |
| 84 | + ] |
| 85 | + ).encode() |
| 86 | + |
| 87 | + case "CallTool": |
| 88 | + call_tool = CallTool.decode(msg) |
| 89 | + city = call_tool.arguments["city"] |
| 90 | + return CallToolResult( |
| 91 | + content=[ |
| 92 | + MCPTextContent( |
| 93 | + type="text", text=f"The weather in {city} is sunny." |
| 94 | + ) |
| 95 | + ], |
| 96 | + ).encode() |
| 97 | + |
| 98 | + |
| 99 | +class TestChatAgent: |
| 100 | + @pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows.") |
| 101 | + @pytest.mark.asyncio |
| 102 | + async def test_get_mcp_tools(self): |
| 103 | + agent = ChatAgent() |
| 104 | + addr = Address(name="test", id="0") |
| 105 | + agent.init(MCPServerTestChannel(), addr) |
| 106 | + |
| 107 | + tools = await agent._get_mcp_tools(["server1"]) |
| 108 | + assert len(tools) == 1 |
| 109 | + # Tool query_weather |
| 110 | + tool = tools[0] |
| 111 | + |
| 112 | + # Validate the tool |
| 113 | + assert tool.__name__ == "query_weather" |
| 114 | + assert tool.__doc__ == "Query the weather in the given city." |
| 115 | + assert tool.__mcp_tool_schema__ == { |
| 116 | + "description": "Query the weather in the given city.", |
| 117 | + "name": "query_weather", |
| 118 | + "parameters": { |
| 119 | + "properties": { |
| 120 | + "city": { |
| 121 | + "title": "City", |
| 122 | + "type": "string", |
| 123 | + } |
| 124 | + }, |
| 125 | + "required": ["city"], |
| 126 | + "title": "query_weatherArguments", |
| 127 | + "type": "object", |
| 128 | + }, |
| 129 | + } |
| 130 | + assert tool.__mcp_tool_args__ == ("city",) |
| 131 | + |
| 132 | + # Call the tool with no arguments |
| 133 | + with pytest.raises(jsonschema.exceptions.ValidationError) as exc: |
| 134 | + await tool() |
| 135 | + assert str(exc.value).startswith("'city' is a required property") |
| 136 | + |
| 137 | + # Call the tool with required arguments |
| 138 | + result = await tool(city="Beijing") |
| 139 | + assert result == "The weather in Beijing is sunny." |
0 commit comments