-
Notifications
You must be signed in to change notification settings - Fork 9.7k
Expand file tree
/
Copy pathmcp.py
More file actions
243 lines (197 loc) · 9.35 KB
/
mcp.py
File metadata and controls
243 lines (197 loc) · 9.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
from contextlib import AsyncExitStack
from typing import Dict, List, Optional
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import ListToolsResult, TextContent
from app.logger import logger
from app.tool.base import BaseTool, ToolResult
from app.tool.tool_collection import ToolCollection
from mcp.client.streamable_http import streamablehttp_client
class MCPClientTool(BaseTool):
    """Represents a tool proxy that can be called on the MCP server from the client side."""

    # Session used to reach the server; stays None until the tool is bound to one.
    session: Optional[ClientSession] = None
    # Identifier of the server connection this tool belongs to.
    server_id: str = ""
    # Name passed to the server in call_tool; kept separately from the
    # tool's own `name` field.
    original_name: str = ""

    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool by making a remote call to the MCP server.

        Args:
            **kwargs: Arguments forwarded unchanged as the tool's input.

        Returns:
            A ToolResult whose ``output`` is the comma-joined text of every
            TextContent item in the response (or a placeholder when there is
            none). ``error`` is set instead when there is no session, when the
            call raises, or when the server flags the result as an error.
        """
        if not self.session:
            return ToolResult(error="Not connected to MCP server")

        try:
            logger.info(f"Executing tool: {self.original_name}")
            result = await self.session.call_tool(self.original_name, kwargs)
            # Only text items are surfaced; other content kinds are ignored.
            content_str = ", ".join(
                item.text for item in result.content if isinstance(item, TextContent)
            )
            # MCP reports tool-level failures through the `isError` flag on a
            # normal response rather than by raising, so it must be checked
            # explicitly or failures would be returned as regular output.
            if getattr(result, "isError", False):
                return ToolResult(
                    error=content_str or "Tool reported an error without details."
                )
            return ToolResult(output=content_str or "No output returned.")
        except Exception as e:
            # Never propagate: callers receive failures through ToolResult.error.
            return ToolResult(error=f"Error executing tool: {str(e)}")
class MCPClients(ToolCollection):
    """
    A collection of tools that connects to multiple MCP servers and manages available tools through the Model Context Protocol.

    Each connection is tracked under a ``server_id`` in three places: the
    session in ``sessions``, the AsyncExitStack owning its transport in
    ``exit_stacks``, and one MCPClientTool per server tool in ``tool_map``.
    """

    # Class-level defaults only; __init__ rebinds both per instance.
    sessions: Dict[str, ClientSession] = {}
    exit_stacks: Dict[str, AsyncExitStack] = {}
    description: str = "MCP client tools for server interaction"

    def __init__(self):
        super().__init__()  # Initialize with empty tools list
        self.name = "mcp"  # Keep name for backward compatibility
        # The class-level dicts above are single shared objects. Without
        # rebinding, every instance would register its connections in the same
        # dicts (assuming ToolCollection is a plain class — TODO confirm).
        self.sessions = {}
        self.exit_stacks = {}

    async def connect_sse(self, server_url: str, server_id: str = "") -> None:
        """Connect to an MCP server using SSE transport.

        Args:
            server_url: URL of the SSE endpoint.
            server_id: Key to register the connection under; defaults to
                ``server_url``.

        Raises:
            ValueError: If ``server_url`` is empty.
            Exception: Any error raised while connecting or initializing is
                re-raised unchanged after the partial connection is rolled back.
        """
        if not server_url:
            raise ValueError("Server URL is required.")

        server_id = server_id or server_url

        # Always ensure clean disconnection before new connection
        if server_id in self.sessions:
            await self.disconnect(server_id)

        exit_stack = AsyncExitStack()
        self.exit_stacks[server_id] = exit_stack

        try:
            streams = await exit_stack.enter_async_context(sse_client(url=server_url))
            session = await exit_stack.enter_async_context(ClientSession(*streams))
            self.sessions[server_id] = session
            await self._initialize_and_list_tools(server_id)
        except Exception:
            # Do not leave an open transport or a half-registered server behind.
            await self._abort_connection(server_id, exit_stack)
            raise

    async def connect_stdio(
        self, command: str, args: List[str], server_id: str = ""
    ) -> None:
        """Connect to an MCP server using stdio transport.

        Args:
            command: Command passed to StdioServerParameters.
            args: Arguments passed to StdioServerParameters.
            server_id: Key to register the connection under; defaults to
                ``command``.

        Raises:
            ValueError: If ``command`` is empty.
            Exception: Any error raised while connecting or initializing is
                re-raised unchanged after the partial connection is rolled back.
        """
        if not command:
            raise ValueError("Server command is required.")

        server_id = server_id or command

        # Always ensure clean disconnection before new connection
        if server_id in self.sessions:
            await self.disconnect(server_id)

        exit_stack = AsyncExitStack()
        self.exit_stacks[server_id] = exit_stack

        try:
            server_params = StdioServerParameters(command=command, args=args)
            read, write = await exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            session = await exit_stack.enter_async_context(ClientSession(read, write))
            self.sessions[server_id] = session
            await self._initialize_and_list_tools(server_id)
        except Exception:
            # Do not leave an open transport or a half-registered server behind.
            await self._abort_connection(server_id, exit_stack)
            raise

    async def connect_streamable_http(
        self,
        server_url: str,
        server_id: str = "",
        headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """Connect to a streamableHttp type MCP server (based on HTTP streaming).

        Args:
            server_url: URL of the streamable HTTP endpoint.
            server_id: Key to register the connection under; defaults to
                ``server_url``.
            headers: Optional headers passed to the transport; an empty dict
                is sent when omitted.

        Raises:
            ValueError: If ``server_url`` is empty.
            RuntimeError: If connecting or initializing fails; the original
                error is attached as ``__cause__``.
        """
        if not server_url:
            raise ValueError("Server URL is required for streamableHttp connection")

        server_id = server_id or server_url  # Use URL as default server_id

        # Disconnect existing connection to avoid conflicts
        if server_id in self.sessions:
            await self.disconnect(server_id)

        exit_stack = AsyncExitStack()
        self.exit_stacks[server_id] = exit_stack

        try:
            # Initialize HTTP streaming connection (with optional custom headers)
            streams_context = streamablehttp_client(
                url=server_url,
                headers=headers or {},
            )

            # Enter context manager to establish streaming connection.
            # The third element of the returned tuple is not needed here.
            read_stream, write_stream, _ = await exit_stack.enter_async_context(
                streams_context
            )
            if read_stream is None or write_stream is None:
                raise ValueError("Invalid streams returned from streamablehttp_client")

            session = await exit_stack.enter_async_context(
                ClientSession(
                    read_stream,
                    write_stream,
                )
            )
            self.sessions[server_id] = session

            # Reuse existing logic to initialize tool list
            await self._initialize_and_list_tools(server_id)
            logger.info(f"Connected to streamableHttp server {server_id}")

        except Exception as e:
            # Clean up resources if connection fails, including the session
            # and any tools that were registered before the failure.
            await self._abort_connection(server_id, exit_stack)
            raise RuntimeError(
                f"Failed to connect to streamableHttp server: {str(e)}"
            ) from e

    async def _abort_connection(
        self, server_id: str, exit_stack: AsyncExitStack
    ) -> None:
        """Roll back a failed connection attempt: close its exit stack and drop its references."""
        try:
            await exit_stack.aclose()
        except Exception as close_error:
            # Best effort: the caller is about to raise the connection failure
            # itself, which must not be masked by a secondary cleanup error.
            logger.warning(
                f"Error while cleaning up failed connection to {server_id}: {close_error}"
            )
        self._forget_server(server_id)

    def _forget_server(self, server_id: str) -> None:
        """Drop the session, exit stack and tools registered under ``server_id``."""
        self.sessions.pop(server_id, None)
        self.exit_stacks.pop(server_id, None)
        # Entries without a server_id attribute are kept untouched.
        self.tool_map = {
            k: v
            for k, v in self.tool_map.items()
            if getattr(v, "server_id", None) != server_id
        }
        self.tools = tuple(self.tool_map.values())

    async def _initialize_and_list_tools(self, server_id: str) -> None:
        """Initialize session and populate tool map.

        Raises:
            RuntimeError: If no session is registered under ``server_id``.
        """
        session = self.sessions.get(server_id)
        if not session:
            raise RuntimeError(f"Session not initialized for server {server_id}")

        await session.initialize()
        response = await session.list_tools()

        # Create proper tool objects for each server tool
        for tool in response.tools:
            original_name = tool.name
            tool_name = self._sanitize_tool_name(f"mcp_{server_id}_{original_name}")

            # Sanitizing and truncating can map distinct tools onto one name;
            # make the overwrite visible instead of silent.
            if tool_name in self.tool_map:
                logger.warning(
                    f"MCP tool name '{tool_name}' is already registered; "
                    f"replacing it with '{original_name}' from server {server_id}"
                )

            server_tool = MCPClientTool(
                name=tool_name,
                # MCP allows a tool to omit its description; fall back to an
                # empty string so the tool can still be constructed.
                description=tool.description or "",
                parameters=tool.inputSchema,
                session=session,
                server_id=server_id,
                original_name=original_name,
            )
            self.tool_map[tool_name] = server_tool

        # Update tools tuple
        self.tools = tuple(self.tool_map.values())
        logger.info(
            f"Connected to server {server_id} with tools: {[tool.name for tool in response.tools]}"
        )

    def _sanitize_tool_name(self, name: str) -> str:
        """Sanitize tool name to match MCPClientTool requirements.

        Characters outside ``[a-zA-Z0-9_-]`` become underscores, runs of
        underscores collapse to one, leading/trailing underscores are removed,
        and the result is cut to at most 64 characters.
        """
        import re

        # Replace invalid characters with underscores
        sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name)

        # Remove consecutive underscores
        sanitized = re.sub(r"_+", "_", sanitized)

        # Remove leading/trailing underscores
        sanitized = sanitized.strip("_")

        # Truncate to 64 characters if needed
        return sanitized[:64]

    async def list_tools(self) -> ListToolsResult:
        """List all available tools by querying every connected session."""
        tools_result = ListToolsResult(tools=[])
        for session in self.sessions.values():
            response = await session.list_tools()
            tools_result.tools.extend(response.tools)
        return tools_result

    async def disconnect(self, server_id: str = "") -> None:
        """Disconnect from a specific MCP server or all servers if no server_id provided.

        Errors raised while closing a connection are logged, not propagated.
        The server's session, exit stack and tools are dropped even when
        closing fails, so no stale entries remain registered.
        """
        if not server_id:
            # Disconnect from all servers in a deterministic order.
            # sorted() snapshots the keys, so removing entries while looping is safe.
            for sid in sorted(self.sessions):
                await self.disconnect(sid)
            self.tool_map = {}
            self.tools = tuple()
            logger.info("Disconnected from all MCP servers")
            return

        if server_id not in self.sessions:
            return

        try:
            exit_stack = self.exit_stacks.get(server_id)

            # Close the exit stack which will handle session cleanup
            if exit_stack:
                try:
                    await exit_stack.aclose()
                except RuntimeError as e:
                    if "cancel scope" in str(e).lower():
                        logger.warning(
                            f"Cancel scope error during disconnect from {server_id}, continuing with cleanup: {e}"
                        )
                    else:
                        raise

            logger.info(f"Disconnected from MCP server {server_id}")
        except Exception as e:
            logger.error(f"Error disconnecting from server {server_id}: {e}")
        finally:
            # Clean up references and remove this server's tools whether or
            # not closing succeeded.
            self._forget_server(server_id)