Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/my-website/docs/mcp.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,25 @@ mcp_servers:
extra_headers: ["custom_key", "x-custom-header"] # These headers will be forwarded from client
```

### Static Headers

Sometimes your MCP server needs specific headers on every request. Maybe it's an API key, maybe it's a custom header the server expects. Instead of configuring auth, you can just set them directly.

```yaml title="config.yaml" showLineNumbers
mcp_servers:
my_mcp_server:
url: "https://my-mcp-server.com/mcp"
static_headers:
X-API-Key: "abc123"
X-Custom-Header: "some-value"
```

These headers get sent with every request to the server. That's it.

**When to use this:**
- Your server needs custom headers that don't fit the standard auth patterns
- You want full control over exactly what headers are sent
- You're debugging and need to quickly add headers without changing auth configuration

### MCP Aliases

Expand Down
5 changes: 2 additions & 3 deletions litellm/experimental_mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,9 @@ async def list_tools(self) -> List[MCPTool]:
await self.disconnect()
raise
except Exception as e:
verbose_logger.warning(f"MCP client list_tools failed: {str(e)}")
verbose_logger.debug(f"MCP client list_tools failed: {str(e)}")
await self.disconnect()
# Return empty list instead of raising to allow graceful degradation
return []
raise e

async def call_tool(
self, call_tool_request_params: MCPCallToolRequestParams
Expand Down
108 changes: 83 additions & 25 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def load_servers_from_config(
disallowed_tools=server_config.get("disallowed_tools", None),
allowed_params=server_config.get("allowed_params", None),
access_groups=server_config.get("access_groups", None),
static_headers=server_config.get("static_headers", None),
)
self.config_mcp_servers[server_id] = new_server

Expand Down Expand Up @@ -632,6 +633,11 @@ async def _get_tools_from_server(
client = None

try:
if server.static_headers:
if extra_headers is None:
extra_headers = {}
extra_headers.update(server.static_headers)

client = self._create_mcp_client(
server=server,
mcp_auth_header=mcp_auth_header,
Expand All @@ -654,10 +660,10 @@ async def _get_tools_from_server(
return prefixed_or_original_tools

except Exception as e:
verbose_logger.warning(
verbose_logger.debug(
f"Failed to get tools from server {server.name}: {str(e)}"
)
return []
raise e
finally:
if client:
try:
Expand Down Expand Up @@ -686,14 +692,14 @@ async def _list_tools_task():
tools = await client.list_tools()
verbose_logger.debug(f"Tools from {server_name}: {tools}")
return tools
except asyncio.CancelledError:
verbose_logger.warning(f"Client operation cancelled for {server_name}")
return []
except asyncio.CancelledError as e:
verbose_logger.debug(f"Client operation cancelled for {server_name}")
raise e
except Exception as e:
verbose_logger.warning(
verbose_logger.debug(
f"Client operation failed for {server_name}: {str(e)}"
)
return []
raise e
finally:
try:
await client.disconnect()
Expand All @@ -702,22 +708,22 @@ async def _list_tools_task():

try:
return await asyncio.wait_for(_list_tools_task(), timeout=30.0)
except asyncio.TimeoutError:
verbose_logger.warning(f"Timeout while listing tools from {server_name}")
return []
except asyncio.CancelledError:
verbose_logger.warning(
except asyncio.TimeoutError as e:
verbose_logger.debug(f"Timeout while listing tools from {server_name}")
raise e
except asyncio.CancelledError as e:
verbose_logger.debug(
f"Task cancelled while listing tools from {server_name}"
)
return []
raise e
except ConnectionError as e:
verbose_logger.warning(
verbose_logger.debug(
f"Connection error while listing tools from {server_name}: {str(e)}"
)
return []
raise e
except Exception as e:
verbose_logger.warning(f"Error listing tools from {server_name}: {str(e)}")
return []
verbose_logger.debug(f"Error listing tools from {server_name}: {str(e)}")
raise e

def _create_prefixed_tools(
self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True
Expand Down Expand Up @@ -1058,14 +1064,7 @@ async def call_tool(
raise ValueError(f"Tool {name} not found")

# Validate that the server from prefix matches the actual server (if prefix was used)
if server_name_from_prefix:
expected_prefix = get_server_prefix(mcp_server)
if normalize_server_name(server_name_from_prefix) != normalize_server_name(
expected_prefix
):
raise ValueError(
f"Tool {name} server prefix mismatch: expected {expected_prefix}, got {server_name_from_prefix}"
)
self._validate_server_prefix_match(name, server_name_from_prefix, mcp_server)

#########################################################
# Pre MCP Tool Call Hook
Expand All @@ -1082,6 +1081,39 @@ async def call_tool(
server=mcp_server,
)

# Get server-specific auth header if available
server_auth_header: Optional[Union[Dict[str, str], str]] = None
if mcp_server_auth_headers and mcp_server.alias:
server_auth_header = mcp_server_auth_headers.get(mcp_server.alias)
elif mcp_server_auth_headers and mcp_server.server_name:
server_auth_header = mcp_server_auth_headers.get(mcp_server.server_name)

# Fall back to deprecated mcp_auth_header if no server-specific header found
if server_auth_header is None:
server_auth_header = mcp_auth_header

# oauth2 headers
extra_headers: Optional[Dict[str, str]] = None
if mcp_server.auth_type == MCPAuth.oauth2:
extra_headers = oauth2_headers

if mcp_server.extra_headers and raw_headers:
if extra_headers is None:
extra_headers = {}
for header in mcp_server.extra_headers:
if header in raw_headers:
extra_headers[header] = raw_headers[header]

if mcp_server.static_headers:
if extra_headers is None:
extra_headers = {}
extra_headers.update(mcp_server.static_headers)

client = self._create_mcp_client(
server=mcp_server,
mcp_auth_header=server_auth_header,
extra_headers=extra_headers,
)
# Prepare tasks for during hooks
tasks = []
if proxy_logging_obj:
Expand Down Expand Up @@ -1248,6 +1280,32 @@ def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]:

return None

def _validate_server_prefix_match(
self,
tool_name: str,
server_name_from_prefix: Optional[str],
mcp_server: MCPServer,
) -> None:
"""
Validate that the server prefix from the tool name matches the actual server.

Args:
tool_name: Original tool name provided
server_name_from_prefix: Server name extracted from tool name prefix (if any)
mcp_server: The MCP server that was found for this tool

Raises:
ValueError: If the server prefix doesn't match the expected server
"""
if server_name_from_prefix:
expected_prefix = get_server_prefix(mcp_server)
if normalize_server_name(server_name_from_prefix) != normalize_server_name(
expected_prefix
):
raise ValueError(
f"Tool {tool_name} server prefix mismatch: expected {expected_prefix}, got {server_name_from_prefix}"
)

async def _add_mcp_servers_from_db_to_in_memory_registry(self):
from litellm.proxy._experimental.mcp_server.db import get_all_mcp_servers
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
Expand Down
9 changes: 4 additions & 5 deletions litellm/proxy/_experimental/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,8 @@ async def list_tools() -> List[MCPTool]:
)
return tools
except Exception as e:
verbose_logger.exception(f"Error in list_tools endpoint: {str(e)}")
# Return empty list instead of failing completely
# This prevents the HTTP stream from failing and allows the client to get a response
return []
verbose_logger.debug(f"Error in list_tools endpoint: {str(e)}")
raise e

@server.call_tool()
async def mcp_server_tool_call(
Expand Down Expand Up @@ -514,10 +512,11 @@ async def _get_tools_from_mcp_servers(
f"Successfully fetched {len(tools)} tools from server {server.name}, {len(filtered_tools)} after filtering"
)
except Exception as e:
verbose_logger.exception(
verbose_logger.debug(
f"Error getting tools from server {server.name}: {str(e)}"
)
# Continue with other servers instead of failing completely
raise e

verbose_logger.info(
f"Successfully fetched {len(all_tools)} tools total from all MCP servers"
Expand Down
5 changes: 4 additions & 1 deletion litellm/types/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MCPServer(BaseModel):
authentication_token: Optional[str] = None
mcp_info: Optional[MCPInfo] = None
extra_headers: Optional[List[str]] = (
None # allow admin to specify which headers to forward to the MCP server
None # allow admin to specify which headers to forward from client to the MCP server
)
allowed_tools: Optional[List[str]] = None
disallowed_tools: Optional[List[str]] = None
Expand All @@ -40,4 +40,7 @@ class MCPServer(BaseModel):
args: Optional[List[str]] = None
env: Optional[Dict[str, str]] = None
access_groups: Optional[List[str]] = None
static_headers: Optional[Dict[str, str]] = (
None # static headers to forward to the MCP server
)
model_config = ConfigDict(arbitrary_types_allowed=True)
Loading