Skip to content

Commit 0d9f0b3

Browse files
committed
Add MCPAgent
1 parent 9c11a73 commit 0d9f0b3

File tree

8 files changed

+307
-2
lines changed

8 files changed

+307
-2
lines changed

coagent/agents/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# ruff: noqa: F401
22
from .chat_agent import ChatAgent, confirm, submit, RunContext, StreamChatAgent, tool
33
from .dynamic_triage import DynamicTriage
4+
from .mcp_agent import MCPAgent
45
from .messages import ChatHistory, ChatMessage
56
from .model_client import ModelClient
67
from .parallel import Aggregator, AggregationResult, Parallel

coagent/agents/aswarm/util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,14 @@ def greet(
131131
Then you will get a JSON schema with per-parameter descriptions.
132132
"""
133133

134+
if hasattr(func, "__mcp_tool_schema__"):
135+
# If the function already has a schema, return it.
136+
# This is the case for tools used in MCPAgent.
137+
return dict(
138+
type="function",
139+
function=func.__mcp_tool_schema__
140+
)
141+
134142
# Construct the pydantic mdoel for the _under_fn's function signature parameters.
135143
# 1. Get the function signature.
136144

coagent/agents/chat_agent.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ def system(self) -> str:
234234
def client(self) -> ModelClient:
235235
return self._client
236236

237+
async def get_swarm_agent(self) -> SwarmAgent:
238+
return self._swarm_agent
239+
237240
async def agent(self, agent_type: str) -> AsyncIterator[ChatMessage]:
238241
"""The candidate agent to delegate the conversation to."""
239242
async for chunk in StreamDelegate(self, agent_type).handle(self._history):
@@ -265,8 +268,10 @@ async def _handle_history(
265268
await self.update_user_confirmed(msg)
266269
await self.update_user_submitted(msg)
267270

271+
swarm_agent = await self.get_swarm_agent()
272+
268273
response = self._swarm_client.run_and_stream(
269-
agent=self._swarm_agent,
274+
agent=swarm_agent,
270275
messages=[m.model_dump() for m in msg.messages],
271276
context_variables=msg.extensions,
272277
)

coagent/agents/mcp_agent.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from typing import Any, AsyncContextManager, Callable
2+
from urllib.parse import urljoin
3+
4+
from mcp import ClientSession, Tool
5+
from mcp.types import ImageContent, TextContent
6+
from mcp.client.sse import sse_client
7+
import jsonschema
8+
9+
from .aswarm import Agent as SwarmAgent
10+
from .chat_agent import ChatAgent, wrap_error
11+
from .model_client import default_model_client, ModelClient
12+
13+
14+
class MCPAgent(ChatAgent):
15+
"""An agent that can use tools provided by MCP (Model Context Protocol) servers."""
16+
17+
def __init__(
18+
self,
19+
system: str = "",
20+
mcp_server_base_url: str = "",
21+
client: ModelClient = default_model_client,
22+
) -> None:
23+
super().__init__(system=system, client=client)
24+
25+
self._mcp_server_base_url: str = mcp_server_base_url
26+
self._mcp_sse_client: AsyncContextManager[tuple] | None = None
27+
self._mcp_client_session: ClientSession | None = None
28+
29+
self._mcp_swarm_agent: SwarmAgent | None = None
30+
31+
@property
32+
def mcp_server_base_url(self) -> str:
33+
if not self._mcp_server_base_url:
34+
raise ValueError("MCP server base URL is empty")
35+
return self._mcp_server_base_url
36+
37+
def make_tool(self, t: Tool) -> Callable:
38+
async def tool(**kwargs) -> Any:
39+
# Validate the input against the schema
40+
jsonschema.validate(instance=kwargs, schema=t.inputSchema)
41+
# Actually call the tool.
42+
result = await self._mcp_client_session.call_tool(t.name, arguments=kwargs)
43+
if not result.content:
44+
return ""
45+
content = result.content[0]
46+
47+
if result.isError:
48+
raise ValueError(content.text)
49+
50+
match content:
51+
case TextContent():
52+
return content.text
53+
case ImageContent():
54+
return content.data
55+
case _: # EmbeddedResource() or other types
56+
return ""
57+
58+
tool.__name__ = t.name
59+
tool.__doc__ = t.description
60+
61+
# Attach the schema and arguments to the tool.
62+
tool.__mcp_tool_schema__ = dict(
63+
name=t.name,
64+
description=t.description,
65+
parameters=t.inputSchema,
66+
)
67+
tool.__mcp_tool_args__ = t.inputSchema["properties"].keys()
68+
return tool
69+
70+
async def get_tools(self) -> list[Callable]:
71+
result = await self._mcp_client_session.list_tools()
72+
tools = [self.make_tool(t) for t in result.tools]
73+
return tools
74+
75+
async def get_swarm_agent(self) -> SwarmAgent:
76+
if not self._mcp_swarm_agent:
77+
tools = await self.get_tools()
78+
self._mcp_swarm_agent = SwarmAgent(
79+
name=self.name,
80+
model=self.client.model,
81+
instructions=self.system,
82+
functions=[wrap_error(t) for t in tools],
83+
)
84+
return self._mcp_swarm_agent
85+
86+
async def started(self) -> None:
87+
"""
88+
Combining `started` and `stopped` to achieve the following behavior:
89+
90+
async with sse_client(url=url) as (read, write):
91+
async with ClientSession(read, write) as session:
92+
pass
93+
"""
94+
url = urljoin(self.mcp_server_base_url, "sse")
95+
self._mcp_sse_client = sse_client(url=url)
96+
read, write = await self._mcp_sse_client.__aenter__()
97+
98+
self._mcp_client_session = ClientSession(read, write)
99+
await self._mcp_client_session.__aenter__()
100+
101+
# Initialize the connection
102+
await self._mcp_client_session.initialize()
103+
104+
async def stopped(self) -> None:
105+
await self._mcp_client_session.__aexit__(None, None, None)
106+
await self._mcp_sse_client.__aexit__(None, None, None)
107+
108+
async def _handle_data(self) -> None:
109+
"""Override the method to handle exceptions properly."""
110+
try:
111+
await super()._handle_data()
112+
finally:
113+
# Ensure the resources created in `started` are properly cleaned up.
114+
await self.stopped()

coagent/core/util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def direct_values(self, prefix: str) -> list[Any]:
4343

4444

4545
def get_func_args(func) -> set[str]:
46+
if hasattr(func, "__mcp_tool_args__"):
47+
return set(func.__mcp_tool_args__)
48+
4649
hints = get_type_hints(func)
4750
hints.pop("return", None) # Ignore the return type.
4851
return set(hints.keys())

poetry.lock

Lines changed: 101 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ blinker = "1.9.0"
3030
loguru = "0.7.3"
3131
jq = "1.8.0"
3232
litellm = "1.55.12"
33+
mcp = "1.2.0"
3334

3435
[tool.pyright]
3536
# https://github.com/microsoft/pyright/blob/main/docs/configuration.md

0 commit comments

Comments
 (0)