Skip to content

Commit 7216983

Browse files
Merge pull request #14720 from uc4w6c/feat/remove-servername-prefix-mcp_tools
feat: remove server_name prefix from list_tools
2 parents e377e30 + e912b89 commit 7216983

File tree

5 files changed

+271
-29
lines changed

5 files changed

+271
-29
lines changed

litellm/proxy/_experimental/mcp_server/mcp_server_manager.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ async def _get_tools_from_server(
421421
self,
422422
server: MCPServer,
423423
mcp_auth_header: Optional[str] = None,
424+
add_prefix: bool = True,
424425
) -> List[MCPTool]:
425426
"""
426427
Helper method to get tools from a single MCP server with prefixed names.
@@ -445,9 +446,11 @@ async def _get_tools_from_server(
445446

446447
tools = await self._fetch_tools_with_timeout(client, server.name)
447448

448-
prefixed_tools = self._create_prefixed_tools(tools, server)
449+
prefixed_or_original_tools = self._create_prefixed_tools(
450+
tools, server, add_prefix=add_prefix
451+
)
449452

450-
return prefixed_tools
453+
return prefixed_or_original_tools
451454

452455
except Exception as e:
453456
verbose_logger.warning(
@@ -516,7 +519,7 @@ async def _list_tools_task():
516519
return []
517520

518521
def _create_prefixed_tools(
519-
self, tools: List[MCPTool], server: MCPServer
522+
self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True
520523
) -> List[MCPTool]:
521524
"""
522525
Create prefixed tools and update tool mapping.
@@ -534,14 +537,16 @@ def _create_prefixed_tools(
534537
for tool in tools:
535538
prefixed_name = add_server_prefix_to_tool_name(tool.name, prefix)
536539

537-
prefixed_tool = MCPTool(
538-
name=prefixed_name,
540+
name_to_use = prefixed_name if add_prefix else tool.name
541+
542+
tool_obj = MCPTool(
543+
name=name_to_use,
539544
description=tool.description,
540545
inputSchema=tool.inputSchema,
541546
)
542-
prefixed_tools.append(prefixed_tool)
547+
prefixed_tools.append(tool_obj)
543548

544-
# Update tool to server mapping with both original and prefixed names
549+
# Update tool to server mapping for resolution (support both forms)
545550
self.tool_name_to_mcp_server_name_mapping[tool.name] = prefix
546551
self.tool_name_to_mcp_server_name_mapping[prefixed_name] = prefix
547552

litellm/proxy/_experimental/mcp_server/rest_endpoints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ async def _get_tools_for_single_server(server, server_auth_header):
7373
tools = await global_mcp_server_manager._get_tools_from_server(
7474
server=server,
7575
mcp_auth_header=server_auth_header,
76+
add_prefix=False,
7677
)
7778
return _create_tool_response_objects(tools, server.mcp_info)
7879

litellm/proxy/_experimental/mcp_server/server.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,9 @@ async def _get_tools_from_mcp_servers(
384384
allowed_mcp_servers=allowed_mcp_servers,
385385
)
386386

387+
# Decide whether to add prefix based on number of allowed servers
388+
add_prefix = not (len(allowed_mcp_servers) == 1)
389+
387390
# Get tools from each allowed server
388391
all_tools = []
389392
for server_id in allowed_mcp_servers:
@@ -406,6 +409,7 @@ async def _get_tools_from_mcp_servers(
406409
tools = await global_mcp_server_manager._get_tools_from_server(
407410
server=server,
408411
mcp_auth_header=server_auth_header,
412+
add_prefix=add_prefix,
409413
)
410414
all_tools.extend(tools)
411415
verbose_logger.debug(
@@ -637,27 +641,35 @@ def _get_mcp_servers_in_path(path: str) -> Optional[List[str]]:
637641
# Server names can contain slashes (e.g., "custom_solutions/user_123")
638642
mcp_path_match = re.match(r"^/mcp/([^?#]+)(?:\?.*)?(?:#.*)?$", path)
639643
if mcp_path_match:
640-
servers_and_path = mcp_path_match.group(1)
641-
642-
if servers_and_path:
643-
# Check if it contains commas (comma-separated servers)
644-
if ',' in servers_and_path:
645-
# For comma-separated, look for a path at the end
646-
# Common patterns: /tools, /chat/completions, etc.
647-
path_match = re.search(r'/([^/,]+(?:/[^/,]+)*)$', servers_and_path)
648-
if path_match:
649-
# Path found at the end, remove it from servers
650-
path_part = '/' + path_match.group(1)
651-
servers_part = servers_and_path[:-len(path_part)]
652-
mcp_servers_from_path = [s.strip() for s in servers_part.split(',') if s.strip()]
653-
else:
654-
# No path, just comma-separated servers
655-
mcp_servers_from_path = [s.strip() for s in servers_and_path.split(',') if s.strip()]
644+
mcp_servers_str = mcp_path_match.group(1)
645+
optional_path = mcp_path_match.group(2)
646+
647+
if mcp_servers_str:
648+
# First, try to split by comma for comma-separated lists
649+
if "," in mcp_servers_str:
650+
# For comma-separated lists, we need to handle the case where the last item
651+
# might include the path (e.g., "zapier,group1/tools" -> ["zapier", "group1/tools"])
652+
parts = [s.strip() for s in mcp_servers_str.split(",") if s.strip()]
653+
654+
# If there's an optional path AND the last part contains a slash that matches the optional path,
655+
# remove the path portion from the last server name
656+
if optional_path and len(parts) > 0 and "/" in parts[-1]:
657+
last_part = parts[-1]
658+
# Check if the last part ends with the optional path
659+
if optional_path and last_part.endswith(
660+
optional_path.lstrip("/")
661+
):
662+
# Remove the path portion from the last server name
663+
parts[-1] = last_part[: -len(optional_path.lstrip("/"))]
664+
665+
mcp_servers_from_path = parts
656666
else:
657-
# Single server case - use regex approach for server/path separation
658-
# This handles cases like "custom_solutions/user_123/chat/completions"
659-
# where we want to extract "custom_solutions/user_123" as the server name
660-
single_server_match = re.match(r"^([^/]+(?:/[^/]+)?)(?:/.*)?$", servers_and_path)
667+
# For single server, it might be just a name or contain slashes
668+
# We need to determine where the server name ends and the path begins
669+
# This is tricky - let's use the original logic but handle comma cases differently
670+
single_server_match = re.match(
671+
r"^([^/]+(?:/[^/]+)?)(?:/.*)?$", mcp_servers_str
672+
)
661673
if single_server_match:
662674
server_name = single_server_match.group(1)
663675
mcp_servers_from_path = [server_name]

tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ async def test_get_tools_from_mcp_servers_continues_when_one_server_fails():
102102
working_server if server_id == "working_server" else failing_server
103103
)
104104

105-
async def mock_get_tools_from_server(server, mcp_auth_header=None):
105+
async def mock_get_tools_from_server(server, mcp_auth_header=None, add_prefix=True):
106106
if server.name == "working_server":
107107
# Working server returns tools
108108
tool1 = MagicMock()
@@ -184,7 +184,7 @@ async def test_get_tools_from_mcp_servers_handles_all_servers_failing():
184184
failing_server1 if server_id == "failing_server1" else failing_server2
185185
)
186186

187-
async def mock_get_tools_from_server(server, mcp_auth_header=None):
187+
async def mock_get_tools_from_server(server, mcp_auth_header=None, add_prefix=True):
188188
# All servers fail
189189
raise Exception(f"Server {server.name} connection failed")
190190

@@ -448,3 +448,121 @@ async def test_mcp_routing_with_conflicting_alias_and_group_name():
448448
assert (
449449
called_servers[0].server_id == specific_server.server_id
450450
), "Should have contacted the specific server alias, not the group."
451+
452+
453+
@pytest.mark.asyncio
454+
async def test_list_tools_single_server_unprefixed_names():
455+
"""When only one MCP server is allowed, list tools should return unprefixed names."""
456+
try:
457+
from litellm.proxy._experimental.mcp_server.server import (
458+
_get_tools_from_mcp_servers,
459+
set_auth_context,
460+
)
461+
except ImportError:
462+
pytest.skip("MCP server not available")
463+
464+
# Mock user auth
465+
user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
466+
set_auth_context(user_api_key_auth)
467+
468+
# One allowed server
469+
server = MagicMock()
470+
server.server_id = "server1"
471+
server.name = "Zapier MCP"
472+
server.alias = "zapier"
473+
474+
# Mock manager: allow just one server and return a tool based on add_prefix flag
475+
mock_manager = MagicMock()
476+
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"])
477+
mock_manager.get_mcp_server_by_id = (
478+
lambda server_id: server if server_id == "server1" else None
479+
)
480+
481+
async def mock_get_tools_from_server(
482+
server, mcp_auth_header=None, add_prefix=False
483+
):
484+
tool = MagicMock()
485+
tool.name = f"{server.alias}-toolA" if add_prefix else "toolA"
486+
tool.description = "desc"
487+
tool.inputSchema = {}
488+
return [tool]
489+
490+
mock_manager._get_tools_from_server = mock_get_tools_from_server
491+
492+
with patch(
493+
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
494+
mock_manager,
495+
):
496+
tools = await _get_tools_from_mcp_servers(
497+
user_api_key_auth=user_api_key_auth,
498+
mcp_auth_header=None,
499+
mcp_servers=None,
500+
mcp_server_auth_headers=None,
501+
)
502+
503+
# Should be unprefixed since only one server is allowed
504+
assert len(tools) == 1
505+
assert tools[0].name == "toolA"
506+
507+
508+
@pytest.mark.asyncio
509+
async def test_list_tools_multiple_servers_prefixed_names():
510+
"""When multiple MCP servers are allowed, list tools should return prefixed names."""
511+
try:
512+
from litellm.proxy._experimental.mcp_server.server import (
513+
_get_tools_from_mcp_servers,
514+
set_auth_context,
515+
)
516+
except ImportError:
517+
pytest.skip("MCP server not available")
518+
519+
# Mock user auth
520+
user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
521+
set_auth_context(user_api_key_auth)
522+
523+
# Two allowed servers
524+
server1 = MagicMock()
525+
server1.server_id = "server1"
526+
server1.name = "Zapier MCP"
527+
server1.alias = "zapier"
528+
529+
server2 = MagicMock()
530+
server2.server_id = "server2"
531+
server2.name = "Jira MCP"
532+
server2.alias = "jira"
533+
534+
# Mock manager
535+
mock_manager = MagicMock()
536+
mock_manager.get_allowed_mcp_servers = AsyncMock(
537+
return_value=["server1", "server2"]
538+
)
539+
mock_manager.get_mcp_server_by_id = (
540+
lambda server_id: server1 if server_id == "server1" else server2
541+
)
542+
543+
async def mock_get_tools_from_server(
544+
server, mcp_auth_header=None, add_prefix=True
545+
):
546+
tool = MagicMock()
547+
# When multiple servers, add_prefix should be True -> prefixed names
548+
tool.name = f"{server.alias}-toolA" if add_prefix else "toolA"
549+
tool.description = "desc"
550+
tool.inputSchema = {}
551+
return [tool]
552+
553+
mock_manager._get_tools_from_server = mock_get_tools_from_server
554+
555+
with patch(
556+
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
557+
mock_manager,
558+
):
559+
tools = await _get_tools_from_mcp_servers(
560+
user_api_key_auth=user_api_key_auth,
561+
mcp_auth_header=None,
562+
mcp_servers=None,
563+
mcp_server_auth_headers=None,
564+
)
565+
566+
# Should be prefixed since multiple servers are allowed
567+
names = sorted([t.name for t in tools])
568+
assert names == ["jira-toolA", "zapier-toolA"]

tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,112 @@ async def mock_get_tools_from_server(server, mcp_auth_header=None):
420420
assert result["status"] == "healthy"
421421
assert result["tools_count"] == 1
422422

423+
@pytest.mark.asyncio
424+
async def test_get_tools_from_server_add_prefix(self):
425+
"""Verify _get_tools_from_server respects add_prefix True/False."""
426+
manager = MCPServerManager()
427+
428+
# Create a minimal server with alias used as prefix
429+
server = MCPServer(
430+
server_id="zapier",
431+
name="zapier",
432+
transport=MCPTransport.http,
433+
)
434+
435+
# Mock client creation and fetching tools
436+
manager._create_mcp_client = MagicMock(return_value=object())
437+
438+
# Tools returned upstream (unprefixed from provider)
439+
upstream_tool = MagicMock()
440+
upstream_tool.name = "send_email"
441+
upstream_tool.description = "Send an email"
442+
upstream_tool.inputSchema = {}
443+
444+
manager._fetch_tools_with_timeout = AsyncMock(return_value=[upstream_tool])
445+
446+
# Case 1: add_prefix=True (default for multi-server) -> expect prefixed
447+
tools_prefixed = await manager._get_tools_from_server(server, add_prefix=True)
448+
assert len(tools_prefixed) == 1
449+
assert tools_prefixed[0].name == "zapier-send_email"
450+
451+
# Case 2: add_prefix=False (single-server) -> expect unprefixed
452+
tools_unprefixed = await manager._get_tools_from_server(
453+
server, add_prefix=False
454+
)
455+
assert len(tools_unprefixed) == 1
456+
assert tools_unprefixed[0].name == "send_email"
457+
458+
def test_create_prefixed_tools_updates_mapping_for_both_forms(self):
459+
"""_create_prefixed_tools should populate mapping for prefixed and original names even when not adding prefix in output."""
460+
manager = MCPServerManager()
461+
462+
server = MCPServer(
463+
server_id="jira",
464+
name="jira",
465+
transport=MCPTransport.http,
466+
)
467+
468+
# Input tools as would come from upstream
469+
t1 = MagicMock()
470+
t1.name = "create_issue"
471+
t1.description = ""
472+
t1.inputSchema = {}
473+
t2 = MagicMock()
474+
t2.name = "close_issue"
475+
t2.description = ""
476+
t2.inputSchema = {}
477+
478+
# Do not add prefix in returned objects
479+
out_tools = manager._create_prefixed_tools([t1, t2], server, add_prefix=False)
480+
481+
# Returned names should be unprefixed
482+
names = sorted([t.name for t in out_tools])
483+
assert names == ["close_issue", "create_issue"]
484+
485+
# Mapping should include both original and prefixed names -> resolves calls either way
486+
assert manager.tool_name_to_mcp_server_name_mapping["create_issue"] == "jira"
487+
assert (
488+
manager.tool_name_to_mcp_server_name_mapping["jira-create_issue"] == "jira"
489+
)
490+
assert manager.tool_name_to_mcp_server_name_mapping["close_issue"] == "jira"
491+
assert (
492+
manager.tool_name_to_mcp_server_name_mapping["jira-close_issue"] == "jira"
493+
)
494+
495+
def test_get_mcp_server_from_tool_name_with_prefixed_and_unprefixed(self):
496+
"""After mapping is populated, manager resolves both prefixed and unprefixed tool names to the same server."""
497+
manager = MCPServerManager()
498+
499+
server = MCPServer(
500+
server_id="zapier",
501+
name="zapier",
502+
server_name="zapier",
503+
transport=MCPTransport.http,
504+
)
505+
506+
# Register server so resolution can find it
507+
manager.registry = {server.server_id: server}
508+
509+
# Populate mapping (add_prefix value doesn't matter for mapping population)
510+
base_tool = MagicMock()
511+
base_tool.name = "create_zap"
512+
base_tool.description = ""
513+
base_tool.inputSchema = {}
514+
_ = manager._create_prefixed_tools([base_tool], server, add_prefix=False)
515+
516+
# Unprefixed resolution
517+
resolved_server_unpref = manager._get_mcp_server_from_tool_name("create_zap")
518+
print(resolved_server_unpref)
519+
assert resolved_server_unpref is not None
520+
assert resolved_server_unpref.server_id == server.server_id
521+
522+
# Prefixed resolution
523+
resolved_server_pref = manager._get_mcp_server_from_tool_name(
524+
"zapier-create_zap"
525+
)
526+
assert resolved_server_pref is not None
527+
assert resolved_server_pref.server_id == server.server_id
528+
423529

424530
if __name__ == "__main__":
425531
pytest.main([__file__])

0 commit comments

Comments
 (0)