diff --git a/docs/my-website/docs/mcp.md b/docs/my-website/docs/mcp.md index 6333f7f09131..d143c02ac5c1 100644 --- a/docs/my-website/docs/mcp.md +++ b/docs/my-website/docs/mcp.md @@ -219,6 +219,25 @@ mcp_servers: extra_headers: ["custom_key", "x-custom-header"] # These headers will be forwarded from client ``` +### Static Headers + +Sometimes your MCP server needs specific headers on every request. Maybe it's an API key, maybe it's a custom header the server expects. Instead of configuring auth, you can just set them directly. + +```yaml title="config.yaml" showLineNumbers +mcp_servers: + my_mcp_server: + url: "https://my-mcp-server.com/mcp" + static_headers: + X-API-Key: "abc123" + X-Custom-Header: "some-value" +``` + +These headers get sent with every request to the server. That's it. + +**When to use this:** +- Your server needs custom headers that don't fit the standard auth patterns +- You want full control over exactly what headers are sent +- You're debugging and need to quickly add headers without changing auth configuration ### MCP Aliases diff --git a/litellm/experimental_mcp_client/client.py b/litellm/experimental_mcp_client/client.py index b10ddc9e8126..5a31749d46ac 100644 --- a/litellm/experimental_mcp_client/client.py +++ b/litellm/experimental_mcp_client/client.py @@ -279,10 +279,9 @@ async def list_tools(self) -> List[MCPTool]: await self.disconnect() raise except Exception as e: - verbose_logger.warning(f"MCP client list_tools failed: {str(e)}") + verbose_logger.debug(f"MCP client list_tools failed: {str(e)}") await self.disconnect() - # Return empty list instead of raising to allow graceful degradation - return [] + raise e async def call_tool( self, call_tool_request_params: MCPCallToolRequestParams diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index ffdcfe8679f6..f285deb325e3 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -218,6 +218,7 @@ def load_servers_from_config( disallowed_tools=server_config.get("disallowed_tools", None), allowed_params=server_config.get("allowed_params", None), access_groups=server_config.get("access_groups", None), + static_headers=server_config.get("static_headers", None), ) self.config_mcp_servers[server_id] = new_server @@ -632,6 +633,11 @@ async def _get_tools_from_server( client = None try: + if server.static_headers: + if extra_headers is None: + extra_headers = {} + extra_headers.update(server.static_headers) + client = self._create_mcp_client( server=server, mcp_auth_header=mcp_auth_header, @@ -654,10 +660,10 @@ async def _get_tools_from_server( return prefixed_or_original_tools except Exception as e: - verbose_logger.warning( + verbose_logger.debug( f"Failed to get tools from server {server.name}: {str(e)}" ) - return [] + raise e finally: if client: try: @@ -686,14 +692,14 @@ async def _list_tools_task(): tools = await client.list_tools() verbose_logger.debug(f"Tools from {server_name}: {tools}") return tools - except asyncio.CancelledError: - verbose_logger.warning(f"Client operation cancelled for {server_name}") - return [] + except asyncio.CancelledError as e: + verbose_logger.debug(f"Client operation cancelled for {server_name}") + raise e except Exception as e: - verbose_logger.warning( + verbose_logger.debug( f"Client operation failed for {server_name}: {str(e)}" ) - return [] + raise e finally: try: await client.disconnect() @@ -702,22 +708,22 @@ async def _list_tools_task(): try: return await asyncio.wait_for(_list_tools_task(), timeout=30.0) - except asyncio.TimeoutError: - verbose_logger.warning(f"Timeout while listing tools from {server_name}") - return [] - except asyncio.CancelledError: - verbose_logger.warning( + except asyncio.TimeoutError as e: + verbose_logger.debug(f"Timeout while listing tools from {server_name}") + raise e + except asyncio.CancelledError as e: + verbose_logger.debug( f"Task cancelled while listing tools from {server_name}" ) - return [] + raise e except ConnectionError as e: - verbose_logger.warning( + verbose_logger.debug( f"Connection error while listing tools from {server_name}: {str(e)}" ) - return [] + raise e except Exception as e: - verbose_logger.warning(f"Error listing tools from {server_name}: {str(e)}") - return [] + verbose_logger.debug(f"Error listing tools from {server_name}: {str(e)}") + raise e def _create_prefixed_tools( self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True @@ -1058,14 +1064,7 @@ async def call_tool( raise ValueError(f"Tool {name} not found") # Validate that the server from prefix matches the actual server (if prefix was used) - if server_name_from_prefix: - expected_prefix = get_server_prefix(mcp_server) - if normalize_server_name(server_name_from_prefix) != normalize_server_name( - expected_prefix - ): - raise ValueError( - f"Tool {name} server prefix mismatch: expected {expected_prefix}, got {server_name_from_prefix}" - ) + self._validate_server_prefix_match(name, server_name_from_prefix, mcp_server) ######################################################### # Pre MCP Tool Call Hook @@ -1082,6 +1081,39 @@ async def call_tool( server=mcp_server, ) + # Get server-specific auth header if available + server_auth_header: Optional[Union[Dict[str, str], str]] = None + if mcp_server_auth_headers and mcp_server.alias: + server_auth_header = mcp_server_auth_headers.get(mcp_server.alias) + elif mcp_server_auth_headers and mcp_server.server_name: + server_auth_header = mcp_server_auth_headers.get(mcp_server.server_name) + + # Fall back to deprecated mcp_auth_header if no server-specific header found + if server_auth_header is None: + server_auth_header = mcp_auth_header + + # oauth2 headers + extra_headers: Optional[Dict[str, str]] = None + if mcp_server.auth_type == MCPAuth.oauth2: + extra_headers = oauth2_headers + + if mcp_server.extra_headers and raw_headers: + if extra_headers is None: + extra_headers = {} + for header in mcp_server.extra_headers: + if header in raw_headers: + extra_headers[header] = raw_headers[header] + + if mcp_server.static_headers: + if extra_headers is None: + extra_headers = {} + extra_headers.update(mcp_server.static_headers) + + client = self._create_mcp_client( + server=mcp_server, + mcp_auth_header=server_auth_header, + extra_headers=extra_headers, + ) # Prepare tasks for during hooks tasks = [] if proxy_logging_obj: @@ -1248,6 +1280,32 @@ def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]: return None + def _validate_server_prefix_match( + self, + tool_name: str, + server_name_from_prefix: Optional[str], + mcp_server: MCPServer, + ) -> None: + """ + Validate that the server prefix from the tool name matches the actual server. + + Args: + tool_name: Original tool name provided + server_name_from_prefix: Server name extracted from tool name prefix (if any) + mcp_server: The MCP server that was found for this tool + + Raises: + ValueError: If the server prefix doesn't match the expected server + """ + if server_name_from_prefix: + expected_prefix = get_server_prefix(mcp_server) + if normalize_server_name(server_name_from_prefix) != normalize_server_name( + expected_prefix + ): + raise ValueError( + f"Tool {tool_name} server prefix mismatch: expected {expected_prefix}, got {server_name_from_prefix}" + ) + async def _add_mcp_servers_from_db_to_in_memory_registry(self): from litellm.proxy._experimental.mcp_server.db import get_all_mcp_servers from litellm.proxy.management_endpoints.mcp_management_endpoints import ( diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 77d6abfed62b..d5879289ae16 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -206,10 +206,8 @@ async def list_tools() -> List[MCPTool]: ) return tools except Exception as e: - verbose_logger.exception(f"Error in list_tools endpoint: {str(e)}") - # Return empty list instead of failing completely - # This prevents the HTTP stream from failing and allows the client to get a response - return [] + verbose_logger.debug(f"Error in list_tools endpoint: {str(e)}") + raise e @server.call_tool() async def mcp_server_tool_call( @@ -514,10 +512,11 @@ async def _get_tools_from_mcp_servers( f"Successfully fetched {len(tools)} tools from server {server.name}, {len(filtered_tools)} after filtering" ) except Exception as e: - verbose_logger.exception( + verbose_logger.debug( f"Error getting tools from server {server.name}: {str(e)}" ) # Continue with other servers instead of failing completely + raise e verbose_logger.info( f"Successfully fetched {len(all_tools)} tools total from all MCP servers" diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index 03bee94eb8b9..b0ddc017d174 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -22,7 +22,7 @@ class MCPServer(BaseModel): authentication_token: Optional[str] = None mcp_info: Optional[MCPInfo] = None extra_headers: Optional[List[str]] = ( - None # allow admin to specify which headers to forward to the MCP server + None # allow admin to specify which headers to forward from client to the MCP server ) allowed_tools: Optional[List[str]] = None disallowed_tools: Optional[List[str]] = None @@ -40,4 +40,7 @@ class MCPServer(BaseModel): args: Optional[List[str]] = None env: Optional[Dict[str, str]] = None access_groups: Optional[List[str]] = None + static_headers: Optional[Dict[str, str]] = ( + None # static headers to forward to the MCP server + ) model_config = ConfigDict(arbitrary_types_allowed=True)