diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 4c866561f707..8eb865b53a5f 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -7,6 +7,7 @@ """ import asyncio +import contextlib import datetime import hashlib import json @@ -479,7 +480,7 @@ async def _get_tools_from_server( verbose_logger.warning( f"Failed to get tools from server {server.name}: {str(e)}" ) - return [] + raise finally: if client: try: @@ -503,43 +504,64 @@ async def _fetch_tools_with_timeout( async def _list_tools_task(): try: - await client.connect() + try: + await client.connect() + except asyncio.CancelledError as e: + verbose_logger.warning( + f"Client operation cancelled for {server_name}: {str(e)}" + ) + raise asyncio.CancelledError( + f"Connection error while listing tools from {server_name}" + ) + except Exception as e: + verbose_logger.warning( + f"Client operation failed for {server_name}: {str(e)}" + ) + raise Exception( + "Connection error while listing tools from {server_name} : {str(e)}" + ) - tools = await client.list_tools() - verbose_logger.debug(f"Tools from {server_name}: {tools}") - return tools - except asyncio.CancelledError: - verbose_logger.warning(f"Client operation cancelled for {server_name}") - return [] - except Exception as e: - verbose_logger.warning( - f"Client operation failed for {server_name}: {str(e)}" - ) - return [] - finally: try: - await client.disconnect() - except Exception: - pass + tools = await client.list_tools() + verbose_logger.debug(f"Tools from {server_name}: {tools}") + return tools + except asyncio.CancelledError as e: + verbose_logger.warning( + f"Failed to list tools from {server_name}: {str(e)}" + ) + raise asyncio.CancelledError( + f"Failed to list tools from {server_name}: {str(e)}" + ) + except Exception as e: + verbose_logger.warning( + f"Failed to list tools from {server_name}: {str(e)}" + ) + raise Exception("list_tools error: {asyncio.CancelledError}") + finally: + with contextlib.suppress(asyncio.CancelledError): + try: + await client.disconnect() + except Exception: + pass try: return await asyncio.wait_for(_list_tools_task(), timeout=30.0) - except asyncio.TimeoutError: + except asyncio.TimeoutError as e: verbose_logger.warning(f"Timeout while listing tools from {server_name}") - return [] - except asyncio.CancelledError: + raise Exception(f"{str(e)}") + except asyncio.CancelledError as e: verbose_logger.warning( f"Task cancelled while listing tools from {server_name}" ) - return [] + raise Exception(f"{str(e)}") except ConnectionError as e: verbose_logger.warning( f"Connection error while listing tools from {server_name}: {str(e)}" ) - return [] + raise except Exception as e: verbose_logger.warning(f"Error listing tools from {server_name}: {str(e)}") - return [] + raise def _create_prefixed_tools( self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True @@ -597,7 +619,6 @@ async def pre_call_tool_check( proxy_logging_obj: ProxyLogging, server: MCPServer, ): - ## check if the tool is allowed or banned for the given server if not self.check_allowed_or_banned_tools(name, server): raise HTTPException( diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index ecb960ebcf39..ccec94800de3 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -1,7 +1,7 @@ import importlib from typing import Dict, List, Optional, Union -from fastapi import APIRouter, Depends, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from litellm._logging import verbose_logger from litellm.proxy._types import UserAPIKeyAuth @@ -145,11 +145,14 @@ async def list_tool_rest_api( verbose_logger.exception( f"Error getting tools from {server.name}: {e}" ) - return { - "tools": [], - "error": "server_error", - "message": f"Failed to get tools from server {server.name}: {str(e)}", - } + raise HTTPException( + status_code=500, + detail={ + "tools": [], + "error": "server_error", + "message": f"An unexpected error occurred: {str(e)}", + }, + ) else: # Query all servers errors = [] @@ -183,15 +186,20 @@ async def list_tool_rest_api( ), } + except HTTPException: + raise except Exception as e: verbose_logger.exception( "Unexpected error in list_tool_rest_api: %s", str(e) ) - return { - "tools": [], - "error": "unexpected_error", - "message": f"An unexpected error occurred: {str(e)}", - } + raise HTTPException( + status_code=500, + detail={ + "tools": [], + "error": "unexpected_error", + "message": f"An unexpected error occurred: {str(e)}", + }, + ) @router.post("/tools/call", dependencies=[Depends(user_api_key_auth)]) async def call_tool_rest_api( diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_rest_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_rest_endpoints.py new file mode 100644 index 000000000000..f096c3e48577 --- /dev/null +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_rest_endpoints.py @@ -0,0 +1,115 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +@pytest.mark.asyncio +async def test_list_tool_rest_api_raises_http_exception_on_server_failure(): + from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api + from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( + MCPRequestHandler, + ) + + mock_request = MagicMock() + mock_request.headers = {} + mock_user_api_key_dict = MagicMock() + + with patch.object( + MCPRequestHandler, "_get_mcp_auth_header_from_headers" + ) as mock_get_auth: + with patch.object( + MCPRequestHandler, "_get_mcp_server_auth_headers_from_headers" + ) as mock_get_server_auth: + mock_get_auth.return_value = None + mock_get_server_auth.return_value = {} + + with patch( + "litellm.proxy._experimental.mcp_server.rest_endpoints.global_mcp_server_manager" + ) as mock_manager: + mock_server = MagicMock() + mock_server.name = "test-server" + mock_server.alias = "test-server" + mock_server.server_name = "test-server" + mock_server.mcp_info = {"server_name": "test-server"} + + mock_manager.get_mcp_server_by_id.return_value = mock_server + + failing_get_tools = AsyncMock(side_effect=Exception("backend failure")) + + with patch( + "litellm.proxy._experimental.mcp_server.rest_endpoints._get_tools_for_single_server", + failing_get_tools, + ): + with pytest.raises(HTTPException) as exc_info: + await list_tool_rest_api( + request=mock_request, + server_id="test-server", + user_api_key_dict=mock_user_api_key_dict, + ) + + exc = exc_info.value + assert exc.status_code == 500 + assert exc.detail["error"] == "server_error" + assert exc.detail["tools"] == [] + assert "An unexpected error occurred" in exc.detail["message"] + + +@pytest.mark.asyncio +async def test_list_tool_rest_api_returns_tools_successfully(): + from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api + from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( + MCPRequestHandler, + ) + from litellm.proxy._experimental.mcp_server.server import ( + ListMCPToolsRestAPIResponseObject, + ) + + mock_request = MagicMock() + mock_request.headers = { + "authorization": "Bearer user_token", + "x-mcp-authorization": "Bearer default_token", + } + mock_user_api_key_dict = MagicMock() + + with patch.object( + MCPRequestHandler, "_get_mcp_auth_header_from_headers" + ) as mock_get_auth: + with patch.object( + MCPRequestHandler, "_get_mcp_server_auth_headers_from_headers" + ) as mock_get_server_auth: + mock_get_auth.return_value = "Bearer default_token" + mock_get_server_auth.return_value = {} + + with patch( + "litellm.proxy._experimental.mcp_server.rest_endpoints.global_mcp_server_manager" + ) as mock_manager: + mock_server = MagicMock() + mock_server.name = "test-server" + mock_server.alias = "test-server" + mock_server.server_name = "test-server" + mock_server.mcp_info = {"server_name": "test-server"} + + mock_manager.get_mcp_server_by_id.return_value = mock_server + + tool = ListMCPToolsRestAPIResponseObject( + name="send_email", + description="Send an email", + inputSchema={"type": "object"}, + mcp_info={"server_name": "test-server"}, + ) + + with patch( + "litellm.proxy._experimental.mcp_server.rest_endpoints._get_tools_for_single_server", + AsyncMock(return_value=[tool]), + ) as mock_get_tools: + result = await list_tool_rest_api( + request=mock_request, + server_id="test-server", + user_api_key_dict=mock_user_api_key_dict, + ) + + assert result["error"] is None + assert result["message"] == "Successfully retrieved tools" + assert len(result["tools"]) == 1 + assert result["tools"][0].name == "send_email" + mock_get_tools.assert_awaited_once() diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py index 80ea95a22102..e7ffee8ed2ca 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py @@ -1,3 +1,4 @@ +import asyncio import sys from datetime import datetime from unittest.mock import AsyncMock, MagicMock @@ -688,6 +689,64 @@ async def test_get_tools_from_server_add_prefix(self): assert len(tools_unprefixed) == 1 assert tools_unprefixed[0].name == "send_email" + @pytest.mark.asyncio + async def test_fetch_tools_with_timeout_propagates_cancelled_error(self): + manager = MCPServerManager() + + client = AsyncMock() + client.connect.side_effect = asyncio.CancelledError( + "connect cancelled" + ) + client.disconnect = AsyncMock() + + with pytest.raises(Exception) as exc_info: + await manager._fetch_tools_with_timeout(client, "test-server") + + assert ( + str(exc_info.value) + == "Connection error while listing tools from test-server" + ) + assert client.disconnect.await_count == 1 + + @pytest.mark.asyncio + async def test_fetch_tools_with_timeout_propagates_list_tools_error(self): + manager = MCPServerManager() + + client = AsyncMock() + client.connect = AsyncMock(return_value=None) + client.list_tools = AsyncMock(side_effect=Exception("list failure")) + client.disconnect = AsyncMock() + + with pytest.raises(Exception) as exc_info: + await manager._fetch_tools_with_timeout(client, "test-server") + + assert str(exc_info.value) == "list_tools error: {asyncio.CancelledError}" + assert client.disconnect.await_count == 1 + + @pytest.mark.asyncio + async def test_get_tools_from_server_propagates_fetch_failure(self): + manager = MCPServerManager() + + server = MagicMock() + server.name = "failing-server" + + client = AsyncMock() + client.disconnect = AsyncMock() + + manager._create_mcp_client = MagicMock(return_value=client) + manager._create_prefixed_tools = MagicMock() + + manager._fetch_tools_with_timeout = AsyncMock( + side_effect=Exception("fetch failure") + ) + + with pytest.raises(Exception) as exc_info: + await manager._get_tools_from_server(server) + + assert str(exc_info.value) == "fetch failure" + assert client.disconnect.await_count == 1 + manager._create_prefixed_tools.assert_not_called() + def test_create_prefixed_tools_updates_mapping_for_both_forms(self): """_create_prefixed_tools should populate mapping for prefixed and original names even when not adding prefix in output.""" manager = MCPServerManager() diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index b16543b8f7ad..fea631854b6c 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -5812,8 +5812,8 @@ export const listMCPTools = async (accessToken: string, serverId: string, authVa if (!response.ok) { // If the server returned an error response, use it - if (data.error && data.message) { - throw new Error(data.message); + if (data.detail.error && data.detail.message) { + throw new Error(data.detail.message); } // Otherwise use a generic error throw new Error("Failed to fetch MCP tools");