Skip to content

Commit e5b1a91

Browse files
committed
Improve MCPAgent to get system prompt from the MCP server
1 parent 92ce6fa commit e5b1a91

File tree

1 file changed

+32
-4
lines changed

1 file changed

+32
-4
lines changed

coagent/agents/mcp_agent.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import dataclasses
12
from typing import Any, AsyncContextManager, Callable
23
from urllib.parse import urljoin
34

4-
from mcp import ClientSession, Tool
5+
from coagent.core.exceptions import InternalError
6+
from mcp import ClientSession, Tool, McpError
57
from mcp.types import ImageContent, TextContent
68
from mcp.client.sse import sse_client
79
from mcp.client.stdio import stdio_client, StdioServerParameters
@@ -12,22 +14,29 @@
1214
from .model_client import default_model_client, ModelClient
1315

1416

17+
@dataclasses.dataclass
18+
class Prompt:
19+
name: str
20+
arguments: dict[str, str] | None = None
21+
22+
1523
class MCPAgent(ChatAgent):
1624
"""An agent that can use tools provided by MCP (Model Context Protocol) servers."""
1725

1826
def __init__(
1927
self,
20-
system: str = "",
28+
system: Prompt | None = None,
2129
mcp_server_base_url: str = "",
2230
client: ModelClient = default_model_client,
2331
) -> None:
24-
super().__init__(system=system, client=client)
32+
super().__init__(system="", client=client)
2533

2634
self._mcp_server_base_url: str = mcp_server_base_url
2735
self._mcp_client_transport: AsyncContextManager[tuple] | None = None
2836
self._mcp_client_session: ClientSession | None = None
2937

3038
self._mcp_swarm_agent: SwarmAgent | None = None
39+
self._mcp_system_prompt: Prompt | None = system
3140

3241
@property
3342
def mcp_server_base_url(self) -> str:
@@ -76,15 +85,34 @@ async def _handle_data(self) -> None:
7685

7786
async def get_swarm_agent(self) -> SwarmAgent:
7887
if not self._mcp_swarm_agent:
88+
system = await self._get_system_prompt()
7989
tools = await self._get_tools()
8090
self._mcp_swarm_agent = SwarmAgent(
8191
name=self.name,
8292
model=self.client.model,
83-
instructions=self.system,
93+
instructions=system,
8494
functions=[wrap_error(t) for t in tools],
8595
)
8696
return self._mcp_swarm_agent
8797

98+
async def _get_system_prompt(self) -> str:
99+
if not self._mcp_system_prompt:
100+
return ""
101+
102+
try:
103+
prompt = await self._mcp_client_session.get_prompt(
104+
**dataclasses.asdict(self._mcp_system_prompt),
105+
)
106+
except McpError as exc:
107+
raise InternalError(str(exc))
108+
109+
content = prompt.messages[0].content
110+
match content:
111+
case TextContent():
112+
return content.text
113+
case _: # ImageContent() or EmbeddedResource() or other types
114+
return ""
115+
88116
async def _get_tools(self) -> list[Callable]:
89117
result = await self._mcp_client_session.list_tools()
90118
tools = [self._make_tool(t) for t in result.tools]

0 commit comments

Comments
 (0)