Skip to content

Commit dae7d08

Browse files
committed
Revert "Revert "Merge pull request #14720 from uc4w6c/feat/remove-servername-prefix-mcp_tools""
This reverts commit a88d774.
1 parent 208cd5f commit dae7d08

File tree

5 files changed

+240
-9
lines changed

5 files changed

+240
-9
lines changed

litellm/proxy/_experimental/mcp_server/mcp_server_manager.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ async def _get_tools_from_server(
443443
server: MCPServer,
444444
mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None,
445445
extra_headers: Optional[Dict[str, str]] = None,
446+
add_prefix: bool = True,
446447
) -> List[MCPTool]:
447448
"""
448449
Helper method to get tools from a single MCP server with prefixed names.
@@ -468,9 +469,11 @@ async def _get_tools_from_server(
468469

469470
tools = await self._fetch_tools_with_timeout(client, server.name)
470471

471-
prefixed_tools = self._create_prefixed_tools(tools, server)
472+
prefixed_or_original_tools = self._create_prefixed_tools(
473+
tools, server, add_prefix=add_prefix
474+
)
472475

473-
return prefixed_tools
476+
return prefixed_or_original_tools
474477

475478
except Exception as e:
476479
verbose_logger.warning(
@@ -539,7 +542,7 @@ async def _list_tools_task():
539542
return []
540543

541544
def _create_prefixed_tools(
542-
self, tools: List[MCPTool], server: MCPServer
545+
self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True
543546
) -> List[MCPTool]:
544547
"""
545548
Create prefixed tools and update tool mapping.
@@ -557,14 +560,16 @@ def _create_prefixed_tools(
557560
for tool in tools:
558561
prefixed_name = add_server_prefix_to_tool_name(tool.name, prefix)
559562

560-
prefixed_tool = MCPTool(
561-
name=prefixed_name,
563+
name_to_use = prefixed_name if add_prefix else tool.name
564+
565+
tool_obj = MCPTool(
566+
name=name_to_use,
562567
description=tool.description,
563568
inputSchema=tool.inputSchema,
564569
)
565-
prefixed_tools.append(prefixed_tool)
570+
prefixed_tools.append(tool_obj)
566571

567-
# Update tool to server mapping with both original and prefixed names
572+
# Update tool to server mapping for resolution (support both forms)
568573
self.tool_name_to_mcp_server_name_mapping[tool.name] = prefix
569574
self.tool_name_to_mcp_server_name_mapping[prefixed_name] = prefix
570575

litellm/proxy/_experimental/mcp_server/rest_endpoints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ async def _get_tools_for_single_server(server, server_auth_header):
7373
tools = await global_mcp_server_manager._get_tools_from_server(
7474
server=server,
7575
mcp_auth_header=server_auth_header,
76+
add_prefix=False,
7677
)
7778
return _create_tool_response_objects(tools, server.mcp_info)
7879

litellm/proxy/_experimental/mcp_server/server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,9 @@ async def _get_tools_from_mcp_servers(
414414
allowed_mcp_servers=allowed_mcp_servers,
415415
)
416416

417+
# Decide whether to add prefix based on number of allowed servers
418+
add_prefix = not (len(allowed_mcp_servers) == 1)
419+
417420
# Get tools from each allowed server
418421
all_tools = []
419422
for server_id in allowed_mcp_servers:
@@ -448,6 +451,7 @@ async def _get_tools_from_mcp_servers(
448451
server=server,
449452
mcp_auth_header=server_auth_header,
450453
extra_headers=extra_headers,
454+
add_prefix=add_prefix,
451455
)
452456
all_tools.extend(filter_tools_by_allowed_tools(tools, server))
453457
verbose_logger.debug(

tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ async def test_get_tools_from_mcp_servers_continues_when_one_server_fails():
103103
)
104104

105105
async def mock_get_tools_from_server(
106-
server, mcp_auth_header=None, extra_headers=None
106+
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
107107
):
108108
if server.name == "working_server":
109109
# Working server returns tools
@@ -187,7 +187,7 @@ async def test_get_tools_from_mcp_servers_handles_all_servers_failing():
187187
)
188188

189189
async def mock_get_tools_from_server(
190-
server, mcp_auth_header=None, extra_headers=None
190+
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
191191
):
192192
# All servers fail
193193
raise Exception(f"Server {server.name} connection failed")
@@ -564,3 +564,120 @@ async def mock_fetch_tools_with_timeout(client, server_name):
564564
captured_client_args["extra_headers"]["Authorization"]
565565
== "Bearer github_oauth_token_12345"
566566
)
567+
568+
@pytest.mark.asyncio
569+
async def test_list_tools_single_server_unprefixed_names():
570+
"""When only one MCP server is allowed, list tools should return unprefixed names."""
571+
try:
572+
from litellm.proxy._experimental.mcp_server.server import (
573+
_get_tools_from_mcp_servers,
574+
set_auth_context,
575+
)
576+
except ImportError:
577+
pytest.skip("MCP server not available")
578+
579+
# Mock user auth
580+
user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
581+
set_auth_context(user_api_key_auth)
582+
583+
# One allowed server
584+
server = MagicMock()
585+
server.server_id = "server1"
586+
server.name = "Zapier MCP"
587+
server.alias = "zapier"
588+
589+
# Mock manager: allow just one server and return a tool based on add_prefix flag
590+
mock_manager = MagicMock()
591+
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"])
592+
mock_manager.get_mcp_server_by_id = (
593+
lambda server_id: server if server_id == "server1" else None
594+
)
595+
596+
async def mock_get_tools_from_server(
597+
server, mcp_auth_header=None, extra_headers=None, add_prefix=False
598+
):
599+
tool = MagicMock()
600+
tool.name = f"{server.alias}-toolA" if add_prefix else "toolA"
601+
tool.description = "desc"
602+
tool.inputSchema = {}
603+
return [tool]
604+
605+
mock_manager._get_tools_from_server = mock_get_tools_from_server
606+
607+
with patch(
608+
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
609+
mock_manager,
610+
):
611+
tools = await _get_tools_from_mcp_servers(
612+
user_api_key_auth=user_api_key_auth,
613+
mcp_auth_header=None,
614+
mcp_servers=None,
615+
mcp_server_auth_headers=None,
616+
)
617+
618+
# Should be unprefixed since only one server is allowed
619+
assert len(tools) == 1
620+
assert tools[0].name == "toolA"
621+
622+
623+
@pytest.mark.asyncio
624+
async def test_list_tools_multiple_servers_prefixed_names():
625+
"""When multiple MCP servers are allowed, list tools should return prefixed names."""
626+
try:
627+
from litellm.proxy._experimental.mcp_server.server import (
628+
_get_tools_from_mcp_servers,
629+
set_auth_context,
630+
)
631+
except ImportError:
632+
pytest.skip("MCP server not available")
633+
634+
# Mock user auth
635+
user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
636+
set_auth_context(user_api_key_auth)
637+
638+
# Two allowed servers
639+
server1 = MagicMock()
640+
server1.server_id = "server1"
641+
server1.name = "Zapier MCP"
642+
server1.alias = "zapier"
643+
644+
server2 = MagicMock()
645+
server2.server_id = "server2"
646+
server2.name = "Jira MCP"
647+
server2.alias = "jira"
648+
649+
# Mock manager
650+
mock_manager = MagicMock()
651+
mock_manager.get_allowed_mcp_servers = AsyncMock(
652+
return_value=["server1", "server2"]
653+
)
654+
mock_manager.get_mcp_server_by_id = (
655+
lambda server_id: server1 if server_id == "server1" else server2
656+
)
657+
658+
async def mock_get_tools_from_server(
659+
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
660+
):
661+
tool = MagicMock()
662+
# When multiple servers, add_prefix should be True -> prefixed names
663+
tool.name = f"{server.alias}-toolA" if add_prefix else "toolA"
664+
tool.description = "desc"
665+
tool.inputSchema = {}
666+
return [tool]
667+
668+
mock_manager._get_tools_from_server = mock_get_tools_from_server
669+
670+
with patch(
671+
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
672+
mock_manager,
673+
):
674+
tools = await _get_tools_from_mcp_servers(
675+
user_api_key_auth=user_api_key_auth,
676+
mcp_auth_header=None,
677+
mcp_servers=None,
678+
mcp_server_auth_headers=None,
679+
)
680+
681+
# Should be prefixed since multiple servers are allowed
682+
names = sorted([t.name for t in tools])
683+
assert names == ["jira-toolA", "zapier-toolA"]

tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,110 @@ async def test_pre_call_tool_check_allowed_tools_takes_precedence(self):
654654
"Tool tool3 is not allowed for server test-server"
655655
in exc_info.value.detail["error"]
656656
)
657+
async def test_get_tools_from_server_add_prefix(self):
658+
"""Verify _get_tools_from_server respects add_prefix True/False."""
659+
manager = MCPServerManager()
660+
661+
# Create a minimal server with alias used as prefix
662+
server = MCPServer(
663+
server_id="zapier",
664+
name="zapier",
665+
transport=MCPTransport.http,
666+
)
667+
668+
# Mock client creation and fetching tools
669+
manager._create_mcp_client = MagicMock(return_value=object())
670+
671+
# Tools returned upstream (unprefixed from provider)
672+
upstream_tool = MagicMock()
673+
upstream_tool.name = "send_email"
674+
upstream_tool.description = "Send an email"
675+
upstream_tool.inputSchema = {}
676+
677+
manager._fetch_tools_with_timeout = AsyncMock(return_value=[upstream_tool])
678+
679+
# Case 1: add_prefix=True (default for multi-server) -> expect prefixed
680+
tools_prefixed = await manager._get_tools_from_server(server, add_prefix=True)
681+
assert len(tools_prefixed) == 1
682+
assert tools_prefixed[0].name == "zapier-send_email"
683+
684+
# Case 2: add_prefix=False (single-server) -> expect unprefixed
685+
tools_unprefixed = await manager._get_tools_from_server(
686+
server, add_prefix=False
687+
)
688+
assert len(tools_unprefixed) == 1
689+
assert tools_unprefixed[0].name == "send_email"
690+
691+
def test_create_prefixed_tools_updates_mapping_for_both_forms(self):
692+
"""_create_prefixed_tools should populate mapping for prefixed and original names even when not adding prefix in output."""
693+
manager = MCPServerManager()
694+
695+
server = MCPServer(
696+
server_id="jira",
697+
name="jira",
698+
transport=MCPTransport.http,
699+
)
700+
701+
# Input tools as would come from upstream
702+
t1 = MagicMock()
703+
t1.name = "create_issue"
704+
t1.description = ""
705+
t1.inputSchema = {}
706+
t2 = MagicMock()
707+
t2.name = "close_issue"
708+
t2.description = ""
709+
t2.inputSchema = {}
710+
711+
# Do not add prefix in returned objects
712+
out_tools = manager._create_prefixed_tools([t1, t2], server, add_prefix=False)
713+
714+
# Returned names should be unprefixed
715+
names = sorted([t.name for t in out_tools])
716+
assert names == ["close_issue", "create_issue"]
717+
718+
# Mapping should include both original and prefixed names -> resolves calls either way
719+
assert manager.tool_name_to_mcp_server_name_mapping["create_issue"] == "jira"
720+
assert (
721+
manager.tool_name_to_mcp_server_name_mapping["jira-create_issue"] == "jira"
722+
)
723+
assert manager.tool_name_to_mcp_server_name_mapping["close_issue"] == "jira"
724+
assert (
725+
manager.tool_name_to_mcp_server_name_mapping["jira-close_issue"] == "jira"
726+
)
727+
728+
def test_get_mcp_server_from_tool_name_with_prefixed_and_unprefixed(self):
729+
"""After mapping is populated, manager resolves both prefixed and unprefixed tool names to the same server."""
730+
manager = MCPServerManager()
731+
732+
server = MCPServer(
733+
server_id="zapier",
734+
name="zapier",
735+
server_name="zapier",
736+
transport=MCPTransport.http,
737+
)
738+
739+
# Register server so resolution can find it
740+
manager.registry = {server.server_id: server}
741+
742+
# Populate mapping (add_prefix value doesn't matter for mapping population)
743+
base_tool = MagicMock()
744+
base_tool.name = "create_zap"
745+
base_tool.description = ""
746+
base_tool.inputSchema = {}
747+
_ = manager._create_prefixed_tools([base_tool], server, add_prefix=False)
748+
749+
# Unprefixed resolution
750+
resolved_server_unpref = manager._get_mcp_server_from_tool_name("create_zap")
751+
print(resolved_server_unpref)
752+
assert resolved_server_unpref is not None
753+
assert resolved_server_unpref.server_id == server.server_id
754+
755+
# Prefixed resolution
756+
resolved_server_pref = manager._get_mcp_server_from_tool_name(
757+
"zapier-create_zap"
758+
)
759+
assert resolved_server_pref is not None
760+
assert resolved_server_pref.server_id == server.server_id
657761

658762

659763
if __name__ == "__main__":

0 commit comments

Comments
 (0)