Skip to content

Commit 3e9540f

Browse files
authored
feat: enable custom fields in mcp_info configuration (#14794)
Allow proxy admins to add arbitrary metadata fields to MCP servers in config.yaml under mcp_servers.<server>.mcp_info, similar to how model_info already works. Changes: - Changed MCPInfo from TypedDict to Dict[str, Any] for flexibility - Updated load_servers_from_config to preserve all custom fields - Updated add_update_server to handle arbitrary fields from database - Added comprehensive unit tests covering all scenarios
1 parent 65532e5 commit 3e9540f

File tree

3 files changed

+231
-20
lines changed

3 files changed

+231
-20
lines changed

litellm/proxy/_experimental/mcp_server/mcp_server_manager.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,13 @@ def load_servers_from_config(
121121
for server_name, server_config in mcp_servers_config.items():
122122
validate_mcp_server_name(server_name)
123123
_mcp_info: Dict[str, Any] = server_config.get("mcp_info", None) or {}
124-
# Convert Dict[str, Any] to MCPInfo properly
125-
mcp_info: MCPInfo = {
126-
"server_name": _mcp_info.get("server_name", server_name),
127-
"description": _mcp_info.get(
128-
"description", server_config.get("description", None)
129-
),
130-
"logo_url": _mcp_info.get("logo_url", None),
131-
"mcp_server_cost_info": _mcp_info.get("mcp_server_cost_info", None),
132-
}
124+
# Preserve all custom fields from config while setting defaults for core fields
125+
mcp_info: MCPInfo = _mcp_info.copy()
126+
# Set default values for core fields if not present
127+
if "server_name" not in mcp_info:
128+
mcp_info["server_name"] = server_name
129+
if "description" not in mcp_info and server_config.get("description"):
130+
mcp_info["description"] = server_config.get("description")
133131

134132
# Use alias for name if present, else server_name
135133
alias = server_config.get("alias", None)
@@ -243,6 +241,14 @@ def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
243241
name_for_prefix = (
244242
mcp_server.alias or mcp_server.server_name or mcp_server.server_id
245243
)
244+
# Preserve all custom fields from database while setting defaults for core fields
245+
mcp_info: MCPInfo = _mcp_info.copy()
246+
# Set default values for core fields if not present
247+
if "server_name" not in mcp_info:
248+
mcp_info["server_name"] = mcp_server.server_name or mcp_server.server_id
249+
if "description" not in mcp_info and mcp_server.description:
250+
mcp_info["description"] = mcp_server.description
251+
246252
new_server = MCPServer(
247253
server_id=mcp_server.server_id,
248254
name=name_for_prefix,
@@ -251,11 +257,7 @@ def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
251257
url=mcp_server.url,
252258
transport=cast(MCPTransportType, mcp_server.transport),
253259
auth_type=cast(MCPAuthType, mcp_server.auth_type),
254-
mcp_info=MCPInfo(
255-
server_name=mcp_server.server_name or mcp_server.server_id,
256-
description=mcp_server.description,
257-
mcp_server_cost_info=_mcp_info.get("mcp_server_cost_info", None),
258-
),
260+
mcp_info=mcp_info,
259261
# Stdio-specific fields
260262
command=getattr(mcp_server, "command", None),
261263
args=getattr(mcp_server, "args", None) or [],

litellm/types/mcp_server/mcp_server_manager.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Optional
1+
from typing import Any, Dict, List, Optional
22

33
from pydantic import BaseModel, ConfigDict
44
from typing_extensions import TypedDict
@@ -7,11 +7,8 @@
77
from litellm.types.mcp import MCPServerCostInfo
88

99

10-
class MCPInfo(TypedDict, total=False):
11-
server_name: str
12-
description: Optional[str]
13-
logo_url: Optional[str]
14-
mcp_server_cost_info: Optional[MCPServerCostInfo]
10+
# MCPInfo now allows arbitrary additional fields for custom metadata
11+
MCPInfo = Dict[str, Any]
1512

1613

1714
class MCPServer(BaseModel):
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""
2+
Test suite for MCP server custom fields functionality.
3+
4+
Tests that mcp_info can accept arbitrary custom fields in addition to predefined ones.
5+
"""
6+
import pytest
7+
import sys
8+
import os
9+
from unittest.mock import Mock, patch
10+
from typing import Dict, Any
11+
12+
# Add the path to find the modules
13+
sys.path.insert(
14+
0, os.path.abspath("../../../..")
15+
) # Adjust the path as needed
16+
17+
from litellm.proxy._experimental.mcp_server.mcp_server_manager import MCPServerManager
18+
from litellm.types.mcp import MCPAuth
19+
from litellm.proxy._types import LiteLLM_MCPServerTable
20+
21+
22+
class TestMCPCustomFields:
23+
"""Test custom fields functionality in MCP server configuration."""
24+
25+
def test_custom_fields_preserved_from_config(self):
26+
"""Test that custom fields in mcp_info are preserved when loading from config."""
27+
manager = MCPServerManager()
28+
29+
# Mock config with custom fields
30+
mock_config = {
31+
"test_server": {
32+
"url": "http://localhost:3000",
33+
"transport": "http",
34+
"auth_type": "bearer_token",
35+
"authentication_token": "test-token",
36+
"mcp_info": {
37+
"server_name": "Test Server",
38+
"description": "A test server",
39+
"custom_field_1": "custom_value_1",
40+
"custom_field_2": {"nested": "value"},
41+
"custom_field_3": ["list", "values"],
42+
"priority": 10,
43+
"tags": ["production", "api"]
44+
}
45+
}
46+
}
47+
48+
# Load servers from config
49+
manager.load_servers_from_config(mock_config)
50+
51+
# Get the loaded server
52+
servers = list(manager.config_mcp_servers.values())
53+
assert len(servers) == 1
54+
55+
server = servers[0]
56+
mcp_info = server.mcp_info
57+
58+
# Verify standard fields are preserved
59+
assert mcp_info["server_name"] == "Test Server"
60+
assert mcp_info["description"] == "A test server"
61+
62+
# Verify custom fields are preserved
63+
assert mcp_info["custom_field_1"] == "custom_value_1"
64+
assert mcp_info["custom_field_2"] == {"nested": "value"}
65+
assert mcp_info["custom_field_3"] == ["list", "values"]
66+
assert mcp_info["priority"] == 10
67+
assert mcp_info["tags"] == ["production", "api"]
68+
69+
def test_custom_fields_preserved_from_database(self):
70+
"""Test that custom fields in mcp_info are preserved when adding from database."""
71+
manager = MCPServerManager()
72+
73+
# Mock database record with custom fields
74+
mock_server = Mock(spec=LiteLLM_MCPServerTable)
75+
mock_server.server_id = "test-server-id"
76+
mock_server.server_name = "Test Server"
77+
mock_server.description = "A test server"
78+
mock_server.url = "http://localhost:3000"
79+
mock_server.transport = "http"
80+
mock_server.auth_type = MCPAuth.bearer_token
81+
mock_server.alias = None
82+
mock_server.mcp_info = {
83+
"server_name": "Test Server",
84+
"description": "A test server",
85+
"custom_db_field": "database_value",
86+
"metadata": {"source": "database"},
87+
"version": "1.0.0"
88+
}
89+
mock_server.command = None
90+
mock_server.args = None
91+
mock_server.env = None
92+
mock_server.mcp_access_groups = None
93+
94+
# Add server to manager
95+
manager.add_update_server(mock_server)
96+
97+
# Get the added server
98+
server = manager.get_mcp_server_by_id("test-server-id")
99+
assert server is not None
100+
101+
mcp_info = server.mcp_info
102+
103+
# Verify standard fields are preserved
104+
assert mcp_info["server_name"] == "Test Server"
105+
assert mcp_info["description"] == "A test server"
106+
107+
# Verify custom fields are preserved
108+
assert mcp_info["custom_db_field"] == "database_value"
109+
assert mcp_info["metadata"] == {"source": "database"}
110+
assert mcp_info["version"] == "1.0.0"
111+
112+
def test_empty_mcp_info_handled_gracefully(self):
113+
"""Test that empty or missing mcp_info is handled gracefully."""
114+
manager = MCPServerManager()
115+
116+
# Config with empty mcp_info
117+
mock_config = {
118+
"test_server": {
119+
"url": "http://localhost:3000",
120+
"transport": "http",
121+
"mcp_info": {}
122+
}
123+
}
124+
125+
manager.load_servers_from_config(mock_config)
126+
127+
servers = list(manager.config_mcp_servers.values())
128+
assert len(servers) == 1
129+
130+
server = servers[0]
131+
mcp_info = server.mcp_info
132+
133+
# Should have default server_name
134+
assert mcp_info["server_name"] == "test_server"
135+
136+
def test_missing_mcp_info_creates_defaults(self):
137+
"""Test that missing mcp_info creates appropriate defaults."""
138+
manager = MCPServerManager()
139+
140+
# Config without mcp_info
141+
mock_config = {
142+
"test_server": {
143+
"url": "http://localhost:3000",
144+
"transport": "http",
145+
"description": "Server description"
146+
}
147+
}
148+
149+
manager.load_servers_from_config(mock_config)
150+
151+
servers = list(manager.config_mcp_servers.values())
152+
assert len(servers) == 1
153+
154+
server = servers[0]
155+
mcp_info = server.mcp_info
156+
157+
# Should have default server_name and description from config
158+
assert mcp_info["server_name"] == "test_server"
159+
assert mcp_info["description"] == "Server description"
160+
161+
def test_config_description_fallback(self):
162+
"""Test that description from config level is used as fallback."""
163+
manager = MCPServerManager()
164+
165+
# Config with description at server level but not in mcp_info
166+
mock_config = {
167+
"test_server": {
168+
"url": "http://localhost:3000",
169+
"transport": "http",
170+
"description": "Config level description",
171+
"mcp_info": {
172+
"custom_field": "custom_value"
173+
}
174+
}
175+
}
176+
177+
manager.load_servers_from_config(mock_config)
178+
179+
servers = list(manager.config_mcp_servers.values())
180+
server = servers[0]
181+
mcp_info = server.mcp_info
182+
183+
# Should use config level description as fallback
184+
assert mcp_info["description"] == "Config level description"
185+
assert mcp_info["custom_field"] == "custom_value"
186+
187+
def test_mcp_info_description_takes_precedence(self):
188+
"""Test that description in mcp_info takes precedence over config level."""
189+
manager = MCPServerManager()
190+
191+
# Config with description at both levels
192+
mock_config = {
193+
"test_server": {
194+
"url": "http://localhost:3000",
195+
"transport": "http",
196+
"description": "Config level description",
197+
"mcp_info": {
198+
"description": "MCP info description",
199+
"custom_field": "custom_value"
200+
}
201+
}
202+
}
203+
204+
manager.load_servers_from_config(mock_config)
205+
206+
servers = list(manager.config_mcp_servers.values())
207+
server = servers[0]
208+
mcp_info = server.mcp_info
209+
210+
# Should use mcp_info description, not config level
211+
assert mcp_info["description"] == "MCP info description"
212+
assert mcp_info["custom_field"] == "custom_value"

0 commit comments

Comments
 (0)