|
10 | 10 | import datetime
|
11 | 11 | import hashlib
|
12 | 12 | import json
|
13 |
| -from typing import Any, Dict, List, Optional, Union, cast |
| 13 | +from typing import Any, Dict, List, Optional, Set, Union, cast |
14 | 14 |
|
15 | 15 | from fastapi import HTTPException
|
16 | 16 | from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
@@ -240,50 +240,64 @@ def remove_server(self, mcp_server: LiteLLM_MCPServerTable):
|
240 | 240 | )
|
241 | 241 |
|
242 | 242 | def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
|
243 |
| - if mcp_server.server_id not in self.get_registry(): |
244 |
| - _mcp_info: MCPInfo = mcp_server.mcp_info or {} |
245 |
| - # Use helper to deserialize environment dictionary |
246 |
| - # Safely access env field which may not exist on Prisma model objects |
247 |
| - env_data = getattr(mcp_server, "env", None) |
248 |
| - env_dict = _deserialize_env_dict(env_data) |
249 |
| - # Use alias for name if present, else server_name |
250 |
| - name_for_prefix = ( |
251 |
| - mcp_server.alias or mcp_server.server_name or mcp_server.server_id |
252 |
| - ) |
253 |
| - # Preserve all custom fields from database while setting defaults for core fields |
254 |
| - mcp_info: MCPInfo = _mcp_info.copy() |
255 |
| - # Set default values for core fields if not present |
256 |
| - if "server_name" not in mcp_info: |
257 |
| - mcp_info["server_name"] = mcp_server.server_name or mcp_server.server_id |
258 |
| - if "description" not in mcp_info and mcp_server.description: |
259 |
| - mcp_info["description"] = mcp_server.description |
| 243 | + try: |
| 244 | + if mcp_server.server_id not in self.get_registry(): |
| 245 | + _mcp_info: MCPInfo = mcp_server.mcp_info or {} |
| 246 | + # Use helper to deserialize environment dictionary |
| 247 | + # Safely access env field which may not exist on Prisma model objects |
| 248 | + env_data = getattr(mcp_server, "env", None) |
| 249 | + env_dict = _deserialize_env_dict(env_data) |
| 250 | + # Use alias for name if present, else server_name |
| 251 | + name_for_prefix = ( |
| 252 | + mcp_server.alias or mcp_server.server_name or mcp_server.server_id |
| 253 | + ) |
| 254 | + # Preserve all custom fields from database while setting defaults for core fields |
| 255 | + mcp_info: MCPInfo = _mcp_info.copy() |
| 256 | + # Set default values for core fields if not present |
| 257 | + if "server_name" not in mcp_info: |
| 258 | + mcp_info["server_name"] = ( |
| 259 | + mcp_server.server_name or mcp_server.server_id |
| 260 | + ) |
| 261 | + if "description" not in mcp_info and mcp_server.description: |
| 262 | + mcp_info["description"] = mcp_server.description |
| 263 | + |
| 264 | + new_server = MCPServer( |
| 265 | + server_id=mcp_server.server_id, |
| 266 | + name=name_for_prefix, |
| 267 | + alias=getattr(mcp_server, "alias", None), |
| 268 | + server_name=getattr(mcp_server, "server_name", None), |
| 269 | + url=mcp_server.url, |
| 270 | + transport=cast(MCPTransportType, mcp_server.transport), |
| 271 | + auth_type=cast(MCPAuthType, mcp_server.auth_type), |
| 272 | + mcp_info=mcp_info, |
| 273 | + extra_headers=getattr(mcp_server, "extra_headers", None), |
| 274 | + # oauth specific fields |
| 275 | + client_id=getattr(mcp_server, "client_id", None), |
| 276 | + client_secret=getattr(mcp_server, "client_secret", None), |
| 277 | + scopes=getattr(mcp_server, "scopes", None), |
| 278 | + authorization_url=getattr(mcp_server, "authorization_url", None), |
| 279 | + token_url=getattr(mcp_server, "token_url", None), |
| 280 | + # Stdio-specific fields |
| 281 | + command=getattr(mcp_server, "command", None), |
| 282 | + args=getattr(mcp_server, "args", None) or [], |
| 283 | + env=env_dict, |
| 284 | + access_groups=getattr(mcp_server, "mcp_access_groups", None), |
| 285 | + allowed_tools=getattr(mcp_server, "allowed_tools", None), |
| 286 | + disallowed_tools=getattr(mcp_server, "disallowed_tools", None), |
| 287 | + ) |
| 288 | + self.registry[mcp_server.server_id] = new_server |
| 289 | + verbose_logger.debug(f"Added MCP Server: {name_for_prefix}") |
260 | 290 |
|
261 |
| - new_server = MCPServer( |
262 |
| - server_id=mcp_server.server_id, |
263 |
| - name=name_for_prefix, |
264 |
| - alias=getattr(mcp_server, "alias", None), |
265 |
| - server_name=getattr(mcp_server, "server_name", None), |
266 |
| - url=mcp_server.url, |
267 |
| - transport=cast(MCPTransportType, mcp_server.transport), |
268 |
| - auth_type=cast(MCPAuthType, mcp_server.auth_type), |
269 |
| - mcp_info=mcp_info, |
270 |
| - extra_headers=getattr(mcp_server, "extra_headers", None), |
271 |
| - # oauth specific fields |
272 |
| - client_id=getattr(mcp_server, "client_id", None), |
273 |
| - client_secret=getattr(mcp_server, "client_secret", None), |
274 |
| - scopes=getattr(mcp_server, "scopes", None), |
275 |
| - authorization_url=getattr(mcp_server, "authorization_url", None), |
276 |
| - token_url=getattr(mcp_server, "token_url", None), |
277 |
| - # Stdio-specific fields |
278 |
| - command=getattr(mcp_server, "command", None), |
279 |
| - args=getattr(mcp_server, "args", None) or [], |
280 |
| - env=env_dict, |
281 |
| - access_groups=getattr(mcp_server, "mcp_access_groups", None), |
282 |
| - allowed_tools=getattr(mcp_server, "allowed_tools", None), |
283 |
| - disallowed_tools=getattr(mcp_server, "disallowed_tools", None), |
284 |
| - ) |
285 |
| - self.registry[mcp_server.server_id] = new_server |
286 |
| - verbose_logger.debug(f"Added MCP Server: {name_for_prefix}") |
| 291 | + except Exception as e: |
| 292 | + verbose_logger.debug(f"Failed to add MCP server: {str(e)}") |
| 293 | + raise e |
| 294 | + |
| 295 | + def get_all_mcp_server_ids(self) -> Set[str]: |
| 296 | + """ |
| 297 | + Get all MCP server IDs |
| 298 | + """ |
| 299 | + all_servers = list(self.get_registry().values()) |
| 300 | + return {server.server_id for server in all_servers} |
287 | 301 |
|
288 | 302 | async def get_allowed_mcp_servers(
|
289 | 303 | self, user_api_key_auth: Optional[UserAPIKeyAuth] = None
|
@@ -1118,25 +1132,23 @@ async def get_all_mcp_servers_with_health_and_teams(
|
1118 | 1132 | if _server_id in allowed_server_ids:
|
1119 | 1133 | list_mcp_servers.append(
|
1120 | 1134 | LiteLLM_MCPServerTable(
|
1121 |
| - server_id=_server_id, |
1122 |
| - server_name=_server_config.name, |
1123 |
| - alias=_server_config.alias, |
1124 |
| - url=_server_config.url, |
1125 |
| - transport=_server_config.transport, |
1126 |
| - auth_type=_server_config.auth_type, |
1127 |
| - created_at=datetime.datetime.now(), |
1128 |
| - updated_at=datetime.datetime.now(), |
1129 |
| - description=( |
1130 |
| - _server_config.mcp_info.get("description") |
1131 |
| - if _server_config.mcp_info |
1132 |
| - else None |
1133 |
| - ), |
1134 |
| - mcp_info=_server_config.mcp_info, |
1135 |
| - mcp_access_groups=_server_config.access_groups or [], |
1136 |
| - # Stdio-specific fields |
1137 |
| - command=getattr(_server_config, "command", None), |
1138 |
| - args=getattr(_server_config, "args", None) or [], |
1139 |
| - env=getattr(_server_config, "env", None) or {}, |
| 1135 | + **{ |
| 1136 | + **_server_config.model_dump(), |
| 1137 | + "created_at": datetime.datetime.now(), |
| 1138 | + "updated_at": datetime.datetime.now(), |
| 1139 | + "description": ( |
| 1140 | + _server_config.mcp_info.get("description") |
| 1141 | + if _server_config.mcp_info |
| 1142 | + else None |
| 1143 | + ), |
| 1144 | + "allowed_tools": _server_config.allowed_tools or [], |
| 1145 | + "mcp_info": _server_config.mcp_info, |
| 1146 | + "mcp_access_groups": _server_config.access_groups or [], |
| 1147 | + "extra_headers": _server_config.extra_headers or [], |
| 1148 | + "command": getattr(_server_config, "command", None), |
| 1149 | + "args": getattr(_server_config, "args", None) or [], |
| 1150 | + "env": getattr(_server_config, "env", None) or {}, |
| 1151 | + } |
1140 | 1152 | )
|
1141 | 1153 | )
|
1142 | 1154 |
|
@@ -1176,44 +1188,19 @@ async def get_all_mcp_servers_with_health_and_teams(
|
1176 | 1188 | }
|
1177 | 1189 | )
|
1178 | 1190 |
|
1179 |
| - # Map servers to their teams and return with health data |
1180 |
| - from typing import cast |
1181 |
| - |
1182 |
| - return [ |
1183 |
| - LiteLLM_MCPServerTable( |
1184 |
| - server_id=server.server_id, |
1185 |
| - server_name=server.server_name, |
1186 |
| - alias=server.alias, |
1187 |
| - description=server.description, |
1188 |
| - url=server.url, |
1189 |
| - transport=server.transport, |
1190 |
| - auth_type=server.auth_type, |
1191 |
| - created_at=server.created_at, |
1192 |
| - created_by=server.created_by, |
1193 |
| - updated_at=server.updated_at, |
1194 |
| - updated_by=server.updated_by, |
1195 |
| - mcp_access_groups=( |
1196 |
| - server.mcp_access_groups |
1197 |
| - if server.mcp_access_groups is not None |
1198 |
| - else [] |
1199 |
| - ), |
1200 |
| - allowed_tools=( |
1201 |
| - server.allowed_tools |
1202 |
| - if server.allowed_tools is not None |
1203 |
| - else [] |
1204 |
| - ), |
1205 |
| - mcp_info=server.mcp_info, |
1206 |
| - teams=cast( |
1207 |
| - List[Dict[str, str | None]], |
1208 |
| - server_to_teams_map.get(server.server_id, []), |
1209 |
| - ), |
1210 |
| - # Stdio-specific fields |
1211 |
| - command=getattr(server, "command", None), |
1212 |
| - args=getattr(server, "args", None) or [], |
1213 |
| - env=getattr(server, "env", None) or {}, |
1214 |
| - ) |
1215 |
| - for server in list_mcp_servers |
1216 |
| - ] |
| 1191 | + ## mark invalid servers w/ reason for being invalid |
| 1192 | + valid_server_ids = self.get_all_mcp_server_ids() |
| 1193 | + for server in list_mcp_servers: |
| 1194 | + if server.server_id not in valid_server_ids: |
| 1195 | + server.status = "unhealthy" |
| 1196 | + ## try adding server to registry to get error |
| 1197 | + try: |
| 1198 | + self.add_update_server(server) |
| 1199 | + except Exception as e: |
| 1200 | + server.health_check_error = str(e) |
| 1201 | + server.health_check_error = "Server is not in in memory registry yet. This could be a temporary sync issue." |
| 1202 | + |
| 1203 | + return list_mcp_servers |
1217 | 1204 |
|
1218 | 1205 | async def reload_servers_from_database(self):
|
1219 | 1206 | """
|
|
0 commit comments