|
| 1 | +import dataclasses |
1 | 2 | from typing import Any, AsyncContextManager, Callable |
2 | 3 | from urllib.parse import urljoin |
3 | 4 |
|
4 | | -from mcp import ClientSession, Tool |
| 5 | +from coagent.core.exceptions import InternalError |
| 6 | +from mcp import ClientSession, Tool, McpError |
5 | 7 | from mcp.types import ImageContent, TextContent |
6 | 8 | from mcp.client.sse import sse_client |
7 | 9 | from mcp.client.stdio import stdio_client, StdioServerParameters |
|
12 | 14 | from .model_client import default_model_client, ModelClient |
13 | 15 |
|
14 | 16 |
|
| 17 | +@dataclasses.dataclass |
| 18 | +class Prompt: |
| 19 | + name: str |
| 20 | + arguments: dict[str, str] | None = None |
| 21 | + |
| 22 | + |
15 | 23 | class MCPAgent(ChatAgent): |
16 | 24 | """An agent that can use tools provided by MCP (Model Context Protocol) servers.""" |
17 | 25 |
|
18 | 26 | def __init__( |
19 | 27 | self, |
20 | | - system: str = "", |
| 28 | + system: Prompt | None = None, |
21 | 29 | mcp_server_base_url: str = "", |
22 | 30 | client: ModelClient = default_model_client, |
23 | 31 | ) -> None: |
24 | | - super().__init__(system=system, client=client) |
| 32 | + super().__init__(system="", client=client) |
25 | 33 |
|
26 | 34 | self._mcp_server_base_url: str = mcp_server_base_url |
27 | 35 | self._mcp_client_transport: AsyncContextManager[tuple] | None = None |
28 | 36 | self._mcp_client_session: ClientSession | None = None |
29 | 37 |
|
30 | 38 | self._mcp_swarm_agent: SwarmAgent | None = None |
| 39 | + self._mcp_system_prompt: Prompt | None = system |
31 | 40 |
|
32 | 41 | @property |
33 | 42 | def mcp_server_base_url(self) -> str: |
@@ -76,15 +85,34 @@ async def _handle_data(self) -> None: |
76 | 85 |
|
77 | 86 | async def get_swarm_agent(self) -> SwarmAgent: |
78 | 87 | if not self._mcp_swarm_agent: |
| 88 | + system = await self._get_system_prompt() |
79 | 89 | tools = await self._get_tools() |
80 | 90 | self._mcp_swarm_agent = SwarmAgent( |
81 | 91 | name=self.name, |
82 | 92 | model=self.client.model, |
83 | | - instructions=self.system, |
| 93 | + instructions=system, |
84 | 94 | functions=[wrap_error(t) for t in tools], |
85 | 95 | ) |
86 | 96 | return self._mcp_swarm_agent |
87 | 97 |
|
| 98 | + async def _get_system_prompt(self) -> str: |
| 99 | + if not self._mcp_system_prompt: |
| 100 | + return "" |
| 101 | + |
| 102 | + try: |
| 103 | + prompt = await self._mcp_client_session.get_prompt( |
| 104 | + **dataclasses.asdict(self._mcp_system_prompt), |
| 105 | + ) |
| 106 | + except McpError as exc: |
| 107 | + raise InternalError(str(exc)) |
| 108 | + |
| 109 | + content = prompt.messages[0].content |
| 110 | + match content: |
| 111 | + case TextContent(): |
| 112 | + return content.text |
| 113 | + case _: # ImageContent() or EmbeddedResource() or other types |
| 114 | + return "" |
| 115 | + |
88 | 116 | async def _get_tools(self) -> list[Callable]: |
89 | 117 | result = await self._mcp_client_session.list_tools() |
90 | 118 | tools = [self._make_tool(t) for t in result.tools] |
|
0 commit comments