Skip to content

Commit 146f2dd

Browse files
committed
fix router
1 parent 2b3e24d commit 146f2dd

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

src/mcpm/router/router.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, reload_server: bool = False, profile_path: str | None = DEFAU
4242
"""Initialize the router."""
4343
self.server_sessions: t.Dict[str, ServerConnection] = {}
4444
self.capabilities_mapping: t.Dict[str, t.Dict[str, t.Any]] = defaultdict(dict)
45-
self.tool_names: t.Set[str] = set()
45+
self.tool_name_to_server_id: t.Dict[str, str] = {}
4646
self.tools_mapping: t.Dict[str, t.Dict[str, t.Any]] = {}
4747
self.prompts_mapping: t.Dict[str, t.Dict[str, t.Any]] = {}
4848
self.resources_mapping: t.Dict[str, t.Dict[str, t.Any]] = {}
@@ -123,9 +123,9 @@ async def add_server(self, server_id: str, server_config: ServerConfig) -> None:
123123
tools = await client.session.list_tools() # type: ignore
124124
for tool in tools.tools:
125125
# To make sure tool name is unique across all servers
126-
if tool.name in self.tool_names:
126+
if tool.name in self.tool_name_to_server_id:
127127
raise ValueError(f"Tool {tool.name} already exists. Please use unique tool names across all servers.")
128-
self.tool_names.add(tool.name)
128+
self.tool_name_to_server_id[tool.name] = server_id
129129
self.tools_mapping[f"{server_id}{TOOL_SPLITOR}{tool.name}"] = tool.model_dump()
130130

131131
if response.capabilities.prompts:
@@ -274,17 +274,21 @@ async def list_tools(req: types.ListToolsRequest) -> types.ServerResult:
274274
async def call_tool(req: types.CallToolRequest) -> types.ServerResult:
275275
active_servers = get_active_servers(req.params.meta.profile) # type: ignore
276276
logger.info(f"call_tool: {req} with active servers: {active_servers}")
277-
278-
server_id, tool_name = parse_namespaced_id(req.params.name, TOOL_SPLITOR)
279-
if server_id is None or tool_name is None:
277+
278+
tool_name = req.params.name
279+
server_id = self.tool_name_to_server_id.get(tool_name)
280+
if server_id is None:
281+
logger.debug(f"call_tool: {req} with tool_name: {tool_name}. Server ID {server_id} is not found")
280282
return empty_result()
281283
if server_id not in active_servers:
284+
logger.debug(f"call_tool: {req} with tool_name: {tool_name}. Server ID {server_id} is not in active servers")
282285
return empty_result()
283286

284287
try:
285288
result = await self.server_sessions[server_id].session.call_tool(tool_name, req.params.arguments or {})
286289
return types.ServerResult(result)
287290
except Exception as e:
291+
logger.error(f"Error calling tool {tool_name} on server {server_id}: {e}")
288292
return types.ServerResult(
289293
types.CallToolResult(
290294
content=[types.TextContent(type="text", text=str(e))],

0 commit comments

Comments
 (0)