Skip to content

Commit a7152f3

Browse files
Merge pull request #15185 from BerriAI/litellm_dev_10_03_2025_p1
(MCP - feat) UI - show health status of MCP servers, allow setting extra headers on the UI, allow editing allowed tools on the UI
2 parents 171b6b1 + 2d1000c commit a7152f3

File tree

16 files changed

+423
-166
lines changed

16 files changed

+423
-166
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-- AlterTable
2+
ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN "allowed_tools" TEXT[] DEFAULT ARRAY[]::TEXT[];
3+
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-- AlterTable
2+
ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN "extra_headers" TEXT[] DEFAULT ARRAY[]::TEXT[];
3+

litellm-proxy-extras/litellm_proxy_extras/schema.prisma

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ model LiteLLM_MCPServerTable {
179179
mcp_info Json? @default("{}")
180180
mcp_access_groups String[]
181181
allowed_tools String[] @default([])
182+
extra_headers String[] @default([])
182183
// Health check status
183184
status String? @default("unknown")
184185
last_health_check DateTime?

litellm/proxy/_experimental/mcp_server/db.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from litellm._uuid import uuid
21
from typing import Any, Dict, Iterable, List, Optional, Set, Union
32

43
from litellm._logging import verbose_proxy_logger
4+
from litellm._uuid import uuid
55
from litellm.proxy._types import (
66
LiteLLM_MCPServerTable,
77
LiteLLM_ObjectPermissionTable,
@@ -30,7 +30,7 @@ def _prepare_mcp_server_data(
3030
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
3131

3232
# Convert model to dict
33-
data_dict = data.model_dump()
33+
data_dict = data.model_dump(exclude_none=True)
3434
# Ensure alias is always present in the dict (even if None)
3535
if "alias" not in data_dict:
3636
data_dict["alias"] = getattr(data, "alias", None)

litellm/proxy/_experimental/mcp_server/mcp_server_manager.py

Lines changed: 88 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import datetime
1111
import hashlib
1212
import json
13-
from typing import Any, Dict, List, Optional, Union, cast
13+
from typing import Any, Dict, List, Optional, Set, Union, cast
1414

1515
from fastapi import HTTPException
1616
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
@@ -240,50 +240,64 @@ def remove_server(self, mcp_server: LiteLLM_MCPServerTable):
240240
)
241241

242242
def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
243-
if mcp_server.server_id not in self.get_registry():
244-
_mcp_info: MCPInfo = mcp_server.mcp_info or {}
245-
# Use helper to deserialize environment dictionary
246-
# Safely access env field which may not exist on Prisma model objects
247-
env_data = getattr(mcp_server, "env", None)
248-
env_dict = _deserialize_env_dict(env_data)
249-
# Use alias for name if present, else server_name
250-
name_for_prefix = (
251-
mcp_server.alias or mcp_server.server_name or mcp_server.server_id
252-
)
253-
# Preserve all custom fields from database while setting defaults for core fields
254-
mcp_info: MCPInfo = _mcp_info.copy()
255-
# Set default values for core fields if not present
256-
if "server_name" not in mcp_info:
257-
mcp_info["server_name"] = mcp_server.server_name or mcp_server.server_id
258-
if "description" not in mcp_info and mcp_server.description:
259-
mcp_info["description"] = mcp_server.description
243+
try:
244+
if mcp_server.server_id not in self.get_registry():
245+
_mcp_info: MCPInfo = mcp_server.mcp_info or {}
246+
# Use helper to deserialize environment dictionary
247+
# Safely access env field which may not exist on Prisma model objects
248+
env_data = getattr(mcp_server, "env", None)
249+
env_dict = _deserialize_env_dict(env_data)
250+
# Use alias for name if present, else server_name
251+
name_for_prefix = (
252+
mcp_server.alias or mcp_server.server_name or mcp_server.server_id
253+
)
254+
# Preserve all custom fields from database while setting defaults for core fields
255+
mcp_info: MCPInfo = _mcp_info.copy()
256+
# Set default values for core fields if not present
257+
if "server_name" not in mcp_info:
258+
mcp_info["server_name"] = (
259+
mcp_server.server_name or mcp_server.server_id
260+
)
261+
if "description" not in mcp_info and mcp_server.description:
262+
mcp_info["description"] = mcp_server.description
263+
264+
new_server = MCPServer(
265+
server_id=mcp_server.server_id,
266+
name=name_for_prefix,
267+
alias=getattr(mcp_server, "alias", None),
268+
server_name=getattr(mcp_server, "server_name", None),
269+
url=mcp_server.url,
270+
transport=cast(MCPTransportType, mcp_server.transport),
271+
auth_type=cast(MCPAuthType, mcp_server.auth_type),
272+
mcp_info=mcp_info,
273+
extra_headers=getattr(mcp_server, "extra_headers", None),
274+
# oauth specific fields
275+
client_id=getattr(mcp_server, "client_id", None),
276+
client_secret=getattr(mcp_server, "client_secret", None),
277+
scopes=getattr(mcp_server, "scopes", None),
278+
authorization_url=getattr(mcp_server, "authorization_url", None),
279+
token_url=getattr(mcp_server, "token_url", None),
280+
# Stdio-specific fields
281+
command=getattr(mcp_server, "command", None),
282+
args=getattr(mcp_server, "args", None) or [],
283+
env=env_dict,
284+
access_groups=getattr(mcp_server, "mcp_access_groups", None),
285+
allowed_tools=getattr(mcp_server, "allowed_tools", None),
286+
disallowed_tools=getattr(mcp_server, "disallowed_tools", None),
287+
)
288+
self.registry[mcp_server.server_id] = new_server
289+
verbose_logger.debug(f"Added MCP Server: {name_for_prefix}")
260290

261-
new_server = MCPServer(
262-
server_id=mcp_server.server_id,
263-
name=name_for_prefix,
264-
alias=getattr(mcp_server, "alias", None),
265-
server_name=getattr(mcp_server, "server_name", None),
266-
url=mcp_server.url,
267-
transport=cast(MCPTransportType, mcp_server.transport),
268-
auth_type=cast(MCPAuthType, mcp_server.auth_type),
269-
mcp_info=mcp_info,
270-
extra_headers=getattr(mcp_server, "extra_headers", None),
271-
# oauth specific fields
272-
client_id=getattr(mcp_server, "client_id", None),
273-
client_secret=getattr(mcp_server, "client_secret", None),
274-
scopes=getattr(mcp_server, "scopes", None),
275-
authorization_url=getattr(mcp_server, "authorization_url", None),
276-
token_url=getattr(mcp_server, "token_url", None),
277-
# Stdio-specific fields
278-
command=getattr(mcp_server, "command", None),
279-
args=getattr(mcp_server, "args", None) or [],
280-
env=env_dict,
281-
access_groups=getattr(mcp_server, "mcp_access_groups", None),
282-
allowed_tools=getattr(mcp_server, "allowed_tools", None),
283-
disallowed_tools=getattr(mcp_server, "disallowed_tools", None),
284-
)
285-
self.registry[mcp_server.server_id] = new_server
286-
verbose_logger.debug(f"Added MCP Server: {name_for_prefix}")
291+
except Exception as e:
292+
verbose_logger.debug(f"Failed to add MCP server: {str(e)}")
293+
raise e
294+
295+
def get_all_mcp_server_ids(self) -> Set[str]:
296+
"""
297+
Get all MCP server IDs
298+
"""
299+
all_servers = list(self.get_registry().values())
300+
return {server.server_id for server in all_servers}
287301

288302
async def get_allowed_mcp_servers(
289303
self, user_api_key_auth: Optional[UserAPIKeyAuth] = None
@@ -1118,25 +1132,23 @@ async def get_all_mcp_servers_with_health_and_teams(
11181132
if _server_id in allowed_server_ids:
11191133
list_mcp_servers.append(
11201134
LiteLLM_MCPServerTable(
1121-
server_id=_server_id,
1122-
server_name=_server_config.name,
1123-
alias=_server_config.alias,
1124-
url=_server_config.url,
1125-
transport=_server_config.transport,
1126-
auth_type=_server_config.auth_type,
1127-
created_at=datetime.datetime.now(),
1128-
updated_at=datetime.datetime.now(),
1129-
description=(
1130-
_server_config.mcp_info.get("description")
1131-
if _server_config.mcp_info
1132-
else None
1133-
),
1134-
mcp_info=_server_config.mcp_info,
1135-
mcp_access_groups=_server_config.access_groups or [],
1136-
# Stdio-specific fields
1137-
command=getattr(_server_config, "command", None),
1138-
args=getattr(_server_config, "args", None) or [],
1139-
env=getattr(_server_config, "env", None) or {},
1135+
**{
1136+
**_server_config.model_dump(),
1137+
"created_at": datetime.datetime.now(),
1138+
"updated_at": datetime.datetime.now(),
1139+
"description": (
1140+
_server_config.mcp_info.get("description")
1141+
if _server_config.mcp_info
1142+
else None
1143+
),
1144+
"allowed_tools": _server_config.allowed_tools or [],
1145+
"mcp_info": _server_config.mcp_info,
1146+
"mcp_access_groups": _server_config.access_groups or [],
1147+
"extra_headers": _server_config.extra_headers or [],
1148+
"command": getattr(_server_config, "command", None),
1149+
"args": getattr(_server_config, "args", None) or [],
1150+
"env": getattr(_server_config, "env", None) or {},
1151+
}
11401152
)
11411153
)
11421154

@@ -1176,44 +1188,19 @@ async def get_all_mcp_servers_with_health_and_teams(
11761188
}
11771189
)
11781190

1179-
# Map servers to their teams and return with health data
1180-
from typing import cast
1181-
1182-
return [
1183-
LiteLLM_MCPServerTable(
1184-
server_id=server.server_id,
1185-
server_name=server.server_name,
1186-
alias=server.alias,
1187-
description=server.description,
1188-
url=server.url,
1189-
transport=server.transport,
1190-
auth_type=server.auth_type,
1191-
created_at=server.created_at,
1192-
created_by=server.created_by,
1193-
updated_at=server.updated_at,
1194-
updated_by=server.updated_by,
1195-
mcp_access_groups=(
1196-
server.mcp_access_groups
1197-
if server.mcp_access_groups is not None
1198-
else []
1199-
),
1200-
allowed_tools=(
1201-
server.allowed_tools
1202-
if server.allowed_tools is not None
1203-
else []
1204-
),
1205-
mcp_info=server.mcp_info,
1206-
teams=cast(
1207-
List[Dict[str, str | None]],
1208-
server_to_teams_map.get(server.server_id, []),
1209-
),
1210-
# Stdio-specific fields
1211-
command=getattr(server, "command", None),
1212-
args=getattr(server, "args", None) or [],
1213-
env=getattr(server, "env", None) or {},
1214-
)
1215-
for server in list_mcp_servers
1216-
]
1191+
## mark invalid servers w/ reason for being invalid
1192+
valid_server_ids = self.get_all_mcp_server_ids()
1193+
for server in list_mcp_servers:
1194+
if server.server_id not in valid_server_ids:
1195+
server.status = "unhealthy"
1196+
## try adding server to registry to get error
1197+
try:
1198+
self.add_update_server(server)
1199+
except Exception as e:
1200+
server.health_check_error = str(e)
1201+
server.health_check_error = "Server is not in in memory registry yet. This could be a temporary sync issue."
1202+
1203+
return list_mcp_servers
12171204

12181205
async def reload_servers_from_database(self):
12191206
"""

litellm/proxy/_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
925925
mcp_info: Optional[MCPInfo] = None
926926
mcp_access_groups: List[str] = Field(default_factory=list)
927927
allowed_tools: Optional[List[str]] = None
928+
extra_headers: Optional[List[str]] = None
928929
# Stdio-specific fields
929930
command: Optional[str] = None
930931
args: List[str] = Field(default_factory=list)
@@ -994,9 +995,10 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase):
994995
teams: List[Dict[str, Optional[str]]] = Field(default_factory=list)
995996
mcp_access_groups: List[str] = Field(default_factory=list)
996997
allowed_tools: List[str] = Field(default_factory=list)
998+
extra_headers: List[str] = Field(default_factory=list)
997999
mcp_info: Optional[MCPInfo] = None
9981000
# Health check status
999-
status: Optional[str] = Field(
1001+
status: Optional[Literal["healthy", "unhealthy", "unknown"]] = Field(
10001002
default="unknown",
10011003
description="Health status: 'healthy', 'unhealthy', 'unknown'",
10021004
)

litellm/proxy/schema.prisma

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ model LiteLLM_MCPServerTable {
179179
mcp_info Json? @default("{}")
180180
mcp_access_groups String[]
181181
allowed_tools String[] @default([])
182+
extra_headers String[] @default([])
182183
// Health check status
183184
status String? @default("unknown")
184185
last_health_check DateTime?

schema.prisma

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ model LiteLLM_MCPServerTable {
179179
mcp_info Json? @default("{}")
180180
mcp_access_groups String[]
181181
allowed_tools String[] @default([])
182+
extra_headers String[] @default([])
182183
// Health check status
183184
status String? @default("unknown")
184185
last_health_check DateTime?

tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,7 @@ async def test_pre_call_tool_check_allowed_tools_takes_precedence(self):
654654
"Tool tool3 is not allowed for server test-server"
655655
in exc_info.value.detail["error"]
656656
)
657+
657658
async def test_get_tools_from_server_add_prefix(self):
658659
"""Verify _get_tools_from_server respects add_prefix True/False."""
659660
manager = MCPServerManager()
@@ -909,6 +910,39 @@ async def test_rest_endpoint_shows_all_when_allowed_tools_is_empty_list(self):
909910
assert "tool_1" in tool_names
910911
assert "tool_2" in tool_names
911912

913+
def test_add_db_mcp_server_to_registry(self):
914+
"""Test that add_db_mcp_server_to_registry adds a MCP server to the registry"""
915+
manager = MCPServerManager()
916+
server = LiteLLM_MCPServerTable(
917+
**{
918+
"server_id": "4c679a81-acd9-4954-9f84-30b739362498",
919+
"server_name": "edc_mcp_server",
920+
"alias": "edc_mcp_server",
921+
"description": None,
922+
"url": "fake_mcp_url",
923+
"transport": "http",
924+
"auth_type": "none",
925+
"created_at": "2025-09-30T08:28:31.353000Z",
926+
"created_by": "a1248959",
927+
"updated_at": "2025-09-30T08:28:31.353000Z",
928+
"updated_by": "a1248959",
929+
"teams": [],
930+
"mcp_access_groups": [],
931+
"mcp_info": {
932+
"server_name": "edc_mcp_server",
933+
"mcp_server_cost_info": None,
934+
},
935+
"status": "unknown",
936+
"last_health_check": None,
937+
"health_check_error": None,
938+
"command": None,
939+
"args": [],
940+
"env": {},
941+
},
942+
)
943+
manager.add_update_server(server)
944+
assert server.server_id in manager.get_registry()
945+
912946

913947
if __name__ == "__main__":
914948
pytest.main([__file__])

0 commit comments

Comments
 (0)