diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index 5216f17a..4c82939a 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -17,11 +17,15 @@ dependencies = [ "unitycatalog-langchain[databricks]>=0.3.0", "databricks-sdk>=0.65.0", "openai>=1.99.9", + "langchain-mcp-adapters>=0.1.13", + "databricks_mcp>=0.4.0" + ] [project.optional-dependencies] dev = [ "pytest", + "pytest-asyncio", "typing_extensions", "ruff==0.6.4", ] diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index 40135d97..3691ed7c 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -20,6 +20,11 @@ from databricks_langchain.chat_models import ChatDatabricks from databricks_langchain.embeddings import DatabricksEmbeddings from databricks_langchain.genie import GenieAgent +from databricks_langchain.multi_server_mcp_client import ( + DatabricksMCPServer, + DatabricksMultiServerMCPClient, + MCPServer, +) from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool from databricks_langchain.vectorstores import DatabricksVectorSearch @@ -34,4 +39,7 @@ "UnityCatalogTool", "DatabricksFunctionClient", "set_uc_function_client", + "DatabricksMultiServerMCPClient", + "DatabricksMCPServer", + "MCPServer", ] diff --git a/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py new file mode 100644 index 00000000..3142e6ea --- /dev/null +++ b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py @@ -0,0 +1,223 @@ +from typing import Any, Callable, List, Union + +from databricks.sdk import WorkspaceClient +from databricks_mcp.oauth_provider import DatabricksOAuthClientProvider +from langchain_mcp_adapters.client import MultiServerMCPClient +from pydantic import BaseModel, ConfigDict, Field + + +class MCPServer(BaseModel): + """ + Base configuration for an MCP server connection using streamable HTTP transport. + + Accepts any additional keyword arguments which are automatically passed through + to LangChain's Connection type, making this forward-compatible with future updates. + + Common optional parameters: + - headers: dict[str, str] - Custom HTTP headers + - timeout: float - Request timeout in seconds + - sse_read_timeout: float - SSE read timeout in seconds + - auth: httpx.Auth - Authentication handler + - httpx_client_factory: Callable - Custom httpx client factory + - terminate_on_close: bool - Terminate connection on close + - session_kwargs: dict - Additional session kwargs + + Example: + ```python + from databricks_langchain import DatabricksMultiServerMCPClient, MCPServer + + # Generic server with custom params - flat API for easy configuration + server = MCPServer( + name="other-server", + url="https://other-server.com/mcp", + headers={"X-API-Key": "secret"}, + timeout=15.0, + handle_tool_error="An error occurred. Please try again.", + ) + + client = DatabricksMultiServerMCPClient([server]) + tools = await client.get_tools() + ``` + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + name: str = Field(..., exclude=True, description="Name to identify this server connection") + url: str + handle_tool_error: Union[bool, str, Callable[[Exception], str], None] = Field( + default=None, + exclude=True, + description=( + "How to handle errors raised by tools from this server. Options:\n" + "- None/False: Raise the error\n" + "- True: Return error message as string\n" + "- str: Return this string when errors occur\n" + "- Callable: Function that takes error and returns error message string" + ), + ) + + def to_connection_dict(self) -> dict[str, Any]: + """ + Convert to connection dictionary for LangChain MultiServerMCPClient. + + Automatically includes all extra fields passed to the constructor, + allowing forward compatibility with new LangChain connection fields. + """ + # Get all model fields including extra fields (name is auto-excluded) + data = self.model_dump() + + # Add transport type (hardcoded to streamable_http) + data["transport"] = "streamable_http" + + return data + + +class DatabricksMCPServer(MCPServer): + """ + MCP server configuration with Databricks authentication. + + Automatically sets up OAuth authentication using the provided WorkspaceClient. + Also accepts any additional connection parameters as keyword arguments. + + Example: + ```python + from databricks.sdk import WorkspaceClient + from databricks_langchain import DatabricksMultiServerMCPClient, DatabricksMCPServer + + # Databricks server with automatic OAuth - just pass params as kwargs! + server = DatabricksMCPServer( + name="databricks-prod", + url="https://your-workspace.databricks.com/mcp", + workspace_client=WorkspaceClient(), + timeout=30.0, + sse_read_timeout=60.0, + handle_tool_error=True, # Return errors as strings instead of raising + ) + + client = DatabricksMultiServerMCPClient([server]) + tools = await client.get_tools() + ``` + """ + + workspace_client: WorkspaceClient | None = Field( + default=None, + description="Databricks WorkspaceClient for authentication. If None, will be auto-initialized.", + exclude=True, + ) + + def model_post_init(self, __context: Any) -> None: + """Initialize DatabricksServer with auth setup.""" + super().model_post_init(__context) + + # Set up Databricks OAuth authentication after initialization + if self.workspace_client is None: + self.workspace_client = WorkspaceClient() + + # Store the auth provider internally + self._auth_provider = DatabricksOAuthClientProvider(self.workspace_client) + + def to_connection_dict(self) -> dict[str, Any]: + """ + Convert to connection dictionary, including Databricks auth. + """ + # Get base connection dict + data = super().to_connection_dict() + + # Add Databricks auth provider + data["auth"] = self._auth_provider + + return data + + +class DatabricksMultiServerMCPClient(MultiServerMCPClient): + """ + MultiServerMCPClient with simplified configuration for Databricks servers. + + This wrapper provides an ergonomic interface similar to LangChain's API while + remaining forward-compatible with future connection parameters. + + Example: + ```python + from databricks.sdk import WorkspaceClient + from databricks_langchain import ( + DatabricksMultiServerMCPClient, + DatabricksMCPServer, + MCPServer, + ) + + client = DatabricksMultiServerMCPClient( + [ + # Databricks server with automatic OAuth - just pass params as kwargs! + DatabricksMCPServer( + name="databricks-prod", + url="https://your-workspace.databricks.com/mcp", + workspace_client=WorkspaceClient(), + timeout=30.0, + sse_read_timeout=60.0, + handle_tool_error=True, # Return errors as strings instead of raising + ), + # Generic server with custom params - same flat API + MCPServer( + name="other-server", + url="https://other-server.com/mcp", + headers={"X-API-Key": "secret"}, + timeout=15.0, + handle_tool_error="An error occurred. Please try again.", + ), + ] + ) + + tools = await client.get_tools() + ``` + """ + + def __init__(self, servers: List[MCPServer], **kwargs): + """ + Initialize the client with a list of server configurations. + + Args: + servers: List of MCPServer or DatabricksMCPServer configurations + **kwargs: Additional arguments to pass to MultiServerMCPClient + """ + # Store server configs for later use (e.g., handle_tool_errors) + self._server_configs = {server.name: server for server in servers} + + # Create connections dict (excluding tool-level params like handle_tool_errors) + connections = {server.name: server.to_connection_dict() for server in servers} + super().__init__(connections=connections, **kwargs) + + async def get_tools(self, server_name: str | None = None): + """ + Get tools from MCP servers, applying handle_tool_error configuration. + + Args: + server_name: Optional server name to get tools from. If None, gets tools from all servers. + + Returns: + List of LangChain tools with handle_tool_error configurations applied. + """ + import asyncio + + # Determine which servers to load from + server_names = [server_name] if server_name is not None else list(self.connections.keys()) + + # Load tools from servers in parallel + load_tool_tasks = [ + asyncio.create_task( + super(DatabricksMultiServerMCPClient, self).get_tools(server_name=name) + ) + for name in server_names + ] + tools_list = await asyncio.gather(*load_tool_tasks) + + # Apply handle_tool_error configurations and collect tools + all_tools = [] + for name, tools in zip(server_names, tools_list, strict=True): + if name in self._server_configs: + server_config = self._server_configs[name] + if server_config.handle_tool_error is not None: + for tool in tools: + tool.handle_tool_error = server_config.handle_tool_error + all_tools.extend(tools) + + return all_tools diff --git a/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py b/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py new file mode 100644 index 00000000..17b469f4 --- /dev/null +++ b/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py @@ -0,0 +1,315 @@ +"""Unit tests for DatabricksMultiServerMCPClient and related classes.""" + +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock, create_autospec, patch + +import pytest +from databricks.sdk import WorkspaceClient + +from databricks_langchain.multi_server_mcp_client import ( + DatabricksMCPServer, + DatabricksMultiServerMCPClient, + MCPServer, +) + + +class TestMCPServer: + """Tests for the MCPServer class.""" + + def test_basic_server_creation(self): + """Test creating a basic server with minimal parameters.""" + server = MCPServer(name="test-server", url="https://example.com/mcp") + + assert server.name == "test-server" + assert server.url == "https://example.com/mcp" + assert server.handle_tool_error is None + + @pytest.mark.parametrize( + "extra_params", + [ + {"timeout": 30.0}, + {"headers": {"X-API-Key": "secret"}}, + {"sse_read_timeout": 60.0}, + {"timeout": 15.0, "headers": {"Authorization": "Bearer token"}}, + {"session_kwargs": {"some_param": "value"}}, + ], + ) + def test_server_accepts_extra_params(self, extra_params: dict[str, Any]): + """Test that MCPServer accepts and preserves extra parameters.""" + server = MCPServer( + name="test-server", + url="https://example.com/mcp", + handle_tool_error=True, + **extra_params, + ) + + connection_dict = server.to_connection_dict() + + # Check that extra params are in connection dict + for key, value in extra_params.items(): + assert connection_dict[key] == value + assert "name" not in connection_dict + assert "handle_tool_error" not in connection_dict + + @pytest.mark.parametrize( + "handle_tool_error_value", + [ + True, + False, + "Custom error message", + lambda e: f"Error: {e}", + None, + ], + ) + def test_server_handle_tool_error_types(self, handle_tool_error_value: Any): + """Test that handle_tool_error accepts various types.""" + server = MCPServer( + name="test-server", + url="https://example.com/mcp", + handle_tool_error=handle_tool_error_value, + ) + + assert server.handle_tool_error == handle_tool_error_value + + +class TestDatabricksMCPServer: + """Tests for the DatabricksMCPServer class.""" + + def test_databricks_server_without_workspace_client(self): + """Test DatabricksMCPServer creates WorkspaceClient automatically.""" + with ( + patch("databricks_langchain.multi_server_mcp_client.WorkspaceClient") as mock_ws, + patch( + "databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider" + ) as mock_auth, + ): + mock_ws_instance = MagicMock() + mock_ws.return_value = mock_ws_instance + mock_auth_instance = MagicMock() + mock_auth.return_value = mock_auth_instance + + server = DatabricksMCPServer(name="databricks", url="https://databricks.com/mcp") + + # Should have created WorkspaceClient + mock_ws.assert_called_once() + # Should have created auth provider + mock_auth.assert_called_once_with(mock_ws_instance) + + def test_databricks_server_with_workspace_client(self): + """Test DatabricksMCPServer uses provided WorkspaceClient.""" + mock_workspace_client = create_autospec(WorkspaceClient, instance=True) + + with patch( + "databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider" + ) as mock_auth: + mock_auth_instance = MagicMock() + mock_auth.return_value = mock_auth_instance + + server = DatabricksMCPServer( + name="databricks", + url="https://databricks.com/mcp", + workspace_client=mock_workspace_client, + ) + + # Should have used provided client + mock_auth.assert_called_once_with(mock_workspace_client) + assert server.workspace_client is mock_workspace_client + + connection_dict = server.to_connection_dict() + assert "workspace_client" not in connection_dict + assert "auth" in connection_dict + assert connection_dict["auth"] is mock_auth_instance + + def test_databricks_server_accepts_extra_params(self): + """Test that DatabricksMCPServer accepts extra connection params.""" + mock_workspace_client = create_autospec(WorkspaceClient, instance=True) + + with patch( + "databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider" + ) as mock_auth: + mock_auth_instance = MagicMock() + mock_auth.return_value = mock_auth_instance + + server = DatabricksMCPServer( + name="databricks", + url="https://databricks.com/mcp", + workspace_client=mock_workspace_client, + timeout=45.0, + headers={"X-Custom": "header"}, + ) + + connection_dict = server.to_connection_dict() + + assert connection_dict["timeout"] == 45.0 + assert connection_dict["headers"] == {"X-Custom": "header"} + + +class TestDatabricksMultiServerMCPClient: + """Tests for the DatabricksMultiServerMCPClient class.""" + + def test_client_initialization_with_multiple_servers(self): + """Test client initialization with multiple servers.""" + with patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__" + ) as mock_init: + mock_init.return_value = None + + servers = [ + MCPServer(name="server1", url="https://server1.com/mcp"), + MCPServer(name="server2", url="https://server2.com/mcp"), + ] + client = DatabricksMultiServerMCPClient(servers) + + # Check that parent __init__ was called + mock_init.assert_called_once() + + # Check connections dict structure + call_kwargs = mock_init.call_args[1] + connections = call_kwargs["connections"] + + assert len(connections) == 2 + assert "server1" in connections + assert "server2" in connections + + assert hasattr(client, "_server_configs") + assert len(client._server_configs) == 2 + assert "server1" in client._server_configs + assert "server2" in client._server_configs + + @pytest.mark.asyncio + async def test_get_tools_all_servers(self): + """Test get_tools without server_name (all servers).""" + servers = [ + MCPServer(name="server1", url="https://server1.com/mcp", handle_tool_error=True), + MCPServer( + name="server2", url="https://server2.com/mcp", handle_tool_error="Custom error" + ), + ] + + # Create mock tools for each server + mock_tool1 = MagicMock() + mock_tool2 = MagicMock() + mock_tool3 = MagicMock() + + # Mock parent get_tools to return different tools for different servers + async def mock_get_tools_side_effect(server_name=None): + if server_name == "server1": + return [mock_tool1, mock_tool2] + elif server_name == "server2": + return [mock_tool3] + return [] + + with ( + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__" + ) as mock_init, + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", + new_callable=AsyncMock, + side_effect=mock_get_tools_side_effect, + ) as mock_parent_get_tools, + ): + mock_init.return_value = None + + client = DatabricksMultiServerMCPClient(servers) + client.connections = { + "server1": servers[0].to_connection_dict(), + "server2": servers[1].to_connection_dict(), + } + + tools = await client.get_tools() + + # Should call parent get_tools for each server + assert mock_parent_get_tools.call_count == 2 + + # Should apply handle_tool_error from respective servers + assert mock_tool1.handle_tool_error is True + assert mock_tool2.handle_tool_error is True + assert mock_tool3.handle_tool_error == "Custom error" + + # Should return all tools + assert len(tools) == 3 + assert mock_tool1 in tools + assert mock_tool2 in tools + assert mock_tool3 in tools + + @pytest.mark.asyncio + async def test_get_tools_parallel_execution(self): + """Test that get_tools executes server requests in parallel.""" + servers = [MCPServer(name=f"server{i}", url=f"https://server{i}.com/mcp") for i in range(5)] + + call_count = 0 + call_times = [] + + async def mock_get_tools_with_delay(server_name=None): + nonlocal call_count + call_count += 1 + call_times.append(asyncio.get_event_loop().time()) + await asyncio.sleep(0.1) # Simulate async work + return [MagicMock()] + + with ( + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__" + ) as mock_init, + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", + new_callable=AsyncMock, + side_effect=mock_get_tools_with_delay, + ) as mock_parent_get_tools, + ): + mock_init.return_value = None + + client = DatabricksMultiServerMCPClient(servers) + client.connections = {server.name: server.to_connection_dict() for server in servers} + + start_time = asyncio.get_event_loop().time() + tools = await client.get_tools() + end_time = asyncio.get_event_loop().time() + + # All 5 servers should be called + assert call_count == 5 + + # Should return tools from all servers + assert len(tools) == 5 + + @pytest.mark.asyncio + async def test_get_tools_with_databricks_server(self): + """Test get_tools with DatabricksMCPServer.""" + mock_workspace_client = create_autospec(WorkspaceClient, instance=True) + mock_tool = MagicMock() + + with ( + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__" + ) as mock_init, + patch( + "databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider" + ) as mock_auth, + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", + new_callable=AsyncMock, + ) as mock_parent_get_tools, + ): + mock_init.return_value = None + mock_auth_instance = MagicMock() + mock_auth.return_value = mock_auth_instance + mock_parent_get_tools.return_value = [mock_tool] + + server = DatabricksMCPServer( + name="databricks", + url="https://databricks.com/mcp", + workspace_client=mock_workspace_client, + handle_tool_error=True, + ) + client = DatabricksMultiServerMCPClient([server]) + client.connections = {"databricks": server.to_connection_dict()} + + tools = await client.get_tools(server_name="databricks") + + # Should apply handle_tool_error + assert mock_tool.handle_tool_error is True + + # Connection should have auth + assert "auth" in client.connections["databricks"]