diff --git a/README.md b/README.md
index 7b7fe50..9e49e6b 100644
--- a/README.md
+++ b/README.md
@@ -83,6 +83,7 @@ the following tools:
 * **UC Functions**: for each UC function, the server exposes a tool with the same name, arguments, and return type as the function
 * **Vector search indexes**: for each vector search index, the server exposes a tool for querying that vector search index
 * **Genie spaces**: for each Genie space, the server exposes tools for managing conversations and sending questions to the space
+* **Tables**: the server exposes tools to list catalogs, schemas, and tables, and to get table details such as the table schema and table properties
 
 ### Deploying UC MCP server on Databricks Apps
 
diff --git a/src/databricks/labs/mcp/servers/unity_catalog/tools/__init__.py b/src/databricks/labs/mcp/servers/unity_catalog/tools/__init__.py
index 39b9af0..58f5e60 100644
--- a/src/databricks/labs/mcp/servers/unity_catalog/tools/__init__.py
+++ b/src/databricks/labs/mcp/servers/unity_catalog/tools/__init__.py
@@ -9,6 +9,7 @@
 
 from databricks.labs.mcp._version import __version__ as VERSION
 from databricks.labs.mcp.servers.unity_catalog.cli import get_settings
+from databricks.labs.mcp.servers.unity_catalog.tools.base_tool import BaseTool
 from databricks.labs.mcp.servers.unity_catalog.tools.genie import (
     GenieTool,
     list_genie_tools,
@@ -21,15 +22,19 @@
     VectorSearchTool,
     list_vector_search_tools,
 )
+from databricks.labs.mcp.servers.unity_catalog.tools.tables import (
+    TableBaseTool,
+    list_table_tools,
+)
 from databricks.labs.mcp.utils import logger
 
 Content: TypeAlias = Union[TextContent, ImageContent, EmbeddedResource]
-AvailableTool = UCFunctionTool | VectorSearchTool | GenieTool
+AvailableTool = UCFunctionTool | VectorSearchTool | GenieTool | BaseTool | TableBaseTool
 
 
 def list_all_tools(settings) -> list[AvailableTool]:
     """
-    Returns a list of all available tools, including Genie tools, UC functions, and vector search tools.
+    Returns a list of all available tools, including Genie tools, UC functions, vector search tools, and table tools.
 
     This function aggregates tools from different sources and returns them in a single list.
     """
@@ -37,6 +42,7 @@ def list_all_tools(settings) -> list[AvailableTool]:
         list_genie_tools(settings)
         + list_vector_search_tools(settings)
         + list_uc_function_tools(settings)
+        + list_table_tools(settings)
     )
 
 
diff --git a/src/databricks/labs/mcp/servers/unity_catalog/tools/tables.py b/src/databricks/labs/mcp/servers/unity_catalog/tools/tables.py
new file mode 100644
index 0000000..308a122
--- /dev/null
+++ b/src/databricks/labs/mcp/servers/unity_catalog/tools/tables.py
@@ -0,0 +1,225 @@
+"""Unity Catalog metadata tools for the MCP server.
+
+Provides tools that list catalogs, schemas, and tables, and fetch the details
+of a single table, through the Databricks SDK ``WorkspaceClient``.
+"""
+
+import json
+from typing import Any, Optional
+from pydantic import BaseModel
+from databricks.sdk import WorkspaceClient
+from databricks.labs.mcp.servers.unity_catalog.tools.base_tool import BaseTool
+from databricks.labs.mcp.servers.unity_catalog.cli import CliSettings
+from mcp.types import TextContent, Tool as ToolSpec
+
+
+# The input models are documented with comments rather than docstrings because
+# pydantic copies a model's docstring into the JSON schema that is published
+# as the tool's ``inputSchema``.
+
+
+# Arguments of ``list_tables``; catalog and schema fall back to the settings.
+class ListTablesInput(BaseModel):
+    catalog_name: Optional[str] = None
+    schema_name: Optional[str] = None
+    max_results: Optional[int] = None
+
+
+# Arguments of ``get_table``.
+class GetTableInput(BaseModel):
+    full_name: str
+
+
+# Arguments of ``list_table_summaries``; catalog falls back to the settings.
+class ListTableSummariesInput(BaseModel):
+    catalog_name: Optional[str] = None
+    schema_name_pattern: Optional[str] = None
+    table_name_pattern: Optional[str] = None
+    max_results: Optional[int] = None
+
+
+# Arguments of ``list_catalogs``.
+class ListCatalogsInput(BaseModel):
+    max_results: Optional[int] = None
+
+
+# Arguments of ``list_schemas``; catalog falls back to the settings.
+class ListSchemasInput(BaseModel):
+    catalog_name: Optional[str] = None
+    max_results: Optional[int] = None
+
+
+class TableBaseTool(BaseTool):
+    """Base class of the Unity Catalog table, catalog, and schema tools.
+
+    Args:
+        tool_spec: MCP tool specification advertised for this tool.
+        settings: Server settings used for the default catalog/schema. When
+            omitted, settings are loaded at the moment a default is needed.
+    """
+
+    def __init__(self, tool_spec: ToolSpec, settings: Optional[CliSettings] = None):
+        self.tool_spec = tool_spec
+        self.settings = settings
+
+    def _get_settings(self) -> CliSettings:
+        """Return the injected settings, or load them when none were injected."""
+        return self.settings if self.settings is not None else CliSettings()
+
+    @staticmethod
+    def _as_text_content(payload: Any) -> list[TextContent]:
+        """Serialize ``payload`` to indented JSON wrapped in a single text item."""
+        return [TextContent(type="text", text=json.dumps(payload, indent=2))]
+
+
+class ListTablesTool(TableBaseTool):
+    """Tool that lists the tables of one Unity Catalog schema."""
+
+    def __init__(self, settings: Optional[CliSettings] = None):
+        tool_spec = ToolSpec(
+            name="list_tables",
+            description="List tables in a Unity Catalog schema. Returns detailed information about tables including columns, properties, and metadata.",
+            inputSchema=ListTablesInput.model_json_schema(),
+        )
+        super().__init__(tool_spec, settings)
+
+    def execute(self, **kwargs):
+        """List tables and return them as JSON text.
+
+        Raises:
+            ValueError: If the arguments are invalid, or no catalog/schema is
+                given and none is configured in the settings.
+        """
+        model = ListTablesInput.model_validate(kwargs)
+
+        catalog_name = model.catalog_name
+        schema_name = model.schema_name
+        if not catalog_name or not schema_name:
+            # Only consult the settings for the values the caller left out.
+            settings = self._get_settings()
+            catalog_name = catalog_name or settings.get_catalog_name()
+            schema_name = schema_name or settings.get_schema_name()
+
+        # Validate before creating the client so a missing argument is
+        # reported without needing workspace credentials.
+        if not catalog_name or not schema_name:
+            raise ValueError(
+                "catalog_name and schema_name must be provided or configured in settings"
+            )
+
+        tables = WorkspaceClient().tables.list(
+            catalog_name=catalog_name,
+            schema_name=schema_name,
+            max_results=model.max_results,
+        )
+        return self._as_text_content([table.as_dict() for table in tables])
+
+
+class GetTableTool(TableBaseTool):
+    """Tool that returns the details of a single table."""
+
+    def __init__(self, settings: Optional[CliSettings] = None):
+        tool_spec = ToolSpec(
+            name="get_table",
+            description="Get detailed information about a specific table including its schema, properties, and metadata.",
+            inputSchema=GetTableInput.model_json_schema(),
+        )
+        super().__init__(tool_spec, settings)
+
+    def execute(self, **kwargs):
+        """Fetch the table named by ``full_name`` and return it as JSON text."""
+        model = GetTableInput.model_validate(kwargs)
+        table = WorkspaceClient().tables.get(full_name=model.full_name)
+        return self._as_text_content(table.as_dict())
+
+
+class ListTableSummariesTool(TableBaseTool):
+    """Tool that lists table summaries of a catalog, optionally filtered by pattern."""
+
+    def __init__(self, settings: Optional[CliSettings] = None):
+        tool_spec = ToolSpec(
+            name="list_table_summaries",
+            description="List table summaries for a catalog and schema. Returns concise information about tables including name, type, and basic metadata.",
+            inputSchema=ListTableSummariesInput.model_json_schema(),
+        )
+        super().__init__(tool_spec, settings)
+
+    def execute(self, **kwargs):
+        """List table summaries and return them as JSON text.
+
+        Raises:
+            ValueError: If the arguments are invalid, or no catalog is given
+                and none is configured in the settings.
+        """
+        model = ListTableSummariesInput.model_validate(kwargs)
+
+        catalog_name = model.catalog_name or self._get_settings().get_catalog_name()
+        if not catalog_name:
+            raise ValueError("catalog_name must be provided or configured in settings")
+
+        summaries = WorkspaceClient().tables.list_summaries(
+            catalog_name=catalog_name,
+            schema_name_pattern=model.schema_name_pattern,
+            table_name_pattern=model.table_name_pattern,
+            max_results=model.max_results,
+        )
+        return self._as_text_content([summary.as_dict() for summary in summaries])
+
+
+class ListCatalogsTool(TableBaseTool):
+    """Tool that lists the catalogs of the Unity Catalog metastore."""
+
+    def __init__(self, settings: Optional[CliSettings] = None):
+        tool_spec = ToolSpec(
+            name="list_catalogs",
+            description="List all catalogs in the Unity Catalog metastore that the user has access to.",
+            inputSchema=ListCatalogsInput.model_json_schema(),
+        )
+        super().__init__(tool_spec, settings)
+
+    def execute(self, **kwargs):
+        """List catalogs and return them as JSON text."""
+        model = ListCatalogsInput.model_validate(kwargs)
+        catalogs = WorkspaceClient().catalogs.list(max_results=model.max_results)
+        return self._as_text_content([catalog.as_dict() for catalog in catalogs])
+
+
+class ListSchemasTool(TableBaseTool):
+    """Tool that lists the schemas of one catalog."""
+
+    def __init__(self, settings: Optional[CliSettings] = None):
+        tool_spec = ToolSpec(
+            name="list_schemas",
+            description="List schemas in a Unity Catalog catalog that the user has access to.",
+            inputSchema=ListSchemasInput.model_json_schema(),
+        )
+        super().__init__(tool_spec, settings)
+
+    def execute(self, **kwargs):
+        """List schemas and return them as JSON text.
+
+        Raises:
+            ValueError: If the arguments are invalid, or no catalog is given
+                and none is configured in the settings.
+        """
+        model = ListSchemasInput.model_validate(kwargs)
+
+        catalog_name = model.catalog_name or self._get_settings().get_catalog_name()
+        if not catalog_name:
+            raise ValueError("catalog_name must be provided or configured in settings")
+
+        schemas = WorkspaceClient().schemas.list(
+            catalog_name=catalog_name, max_results=model.max_results
+        )
+        return self._as_text_content([schema.as_dict() for schema in schemas])
+
+
+def list_table_tools(settings: CliSettings) -> list[TableBaseTool]:
+    """Return one instance of every table tool, all sharing ``settings``."""
+    return [
+        ListTablesTool(settings),
+        GetTableTool(settings),
+        ListTableSummariesTool(settings),
+        ListCatalogsTool(settings),
+        ListSchemasTool(settings),
+    ]
diff --git a/tests/test_table_tools.py b/tests/test_table_tools.py
new file mode 100644
index 0000000..40d56dd
--- /dev/null
+++ b/tests/test_table_tools.py
@@ -0,0 +1,233 @@
+import pytest
+from unittest.mock import Mock, patch
+from databricks.sdk.service.catalog import (
+    TableInfo,
+    TableSummary,
+    CatalogInfo,
+    SchemaInfo,
+    TableType,
+)
+from databricks.labs.mcp.servers.unity_catalog.tools.tables import (
+    ListTablesTool,
+    GetTableTool,
+    ListTableSummariesTool,
+    ListCatalogsTool,
+    ListSchemasTool,
+    list_table_tools,
+)
+
+
+class DummyTablesAPI:
+    def list(self, catalog_name=None, schema_name=None, max_results=None):
+        # Echo the requested catalog/schema so tests can check what the tool passed
+        table_info = TableInfo(
+            name="test_table",
+            full_name="catalog.schema.test_table",
+            catalog_name=catalog_name,
+            schema_name=schema_name,
+            table_type=TableType.MANAGED,
+            columns=[],
+            properties={},
+            comment="Test table",
+            created_by="test_user",
+            updated_by="test_user",
+            owner="test_user",
+        )
+        return [table_info]
+
+    def get(self, full_name):
+        # Create a real TableInfo object
+        return TableInfo(
+            name="test_table",
+            full_name=full_name,
+            catalog_name="catalog",
+            schema_name="schema",
+            table_type=TableType.MANAGED,
+            columns=[],
+            properties={},
+            comment="Test table",
+            created_by="test_user",
+            updated_by="test_user",
+            owner="test_user",
+        )
+
+    def list_summaries(
+        self,
+        catalog_name=None,
+        schema_name_pattern=None,
+        table_name_pattern=None,
+        max_results=None,
+    ):
+        # Create a real TableSummary object
+        summary = TableSummary(
+            full_name="catalog.schema.test_table",
+            table_type=TableType.MANAGED,
+        )
+        return [summary]
+
+
+class DummyCatalogsAPI:
+    def list(self, max_results=None):
+        # Create a real CatalogInfo object
+        catalog = CatalogInfo(
+            name="test_catalog",
+            comment="Test catalog",
+            properties={},
+            created_by="test_user",
+            updated_by="test_user",
+            owner="test_user",
+        )
+        return [catalog]
+
+
+class DummySchemasAPI:
+    def list(self, catalog_name=None, max_results=None):
+        # Create a real SchemaInfo object
+        schema = SchemaInfo(
+            name="test_schema",
+            full_name="catalog.test_schema",
+            catalog_name="catalog",
+            comment="Test schema",
+            properties={},
+            created_by="test_user",
+            updated_by="test_user",
+            owner="test_user",
+        )
+        return [schema]
+
+
+class DummyWorkspaceClient:
+    def __init__(self):
+        self.tables = DummyTablesAPI()
+        self.catalogs = DummyCatalogsAPI()
+        self.schemas = DummySchemasAPI()
+
+
+@pytest.fixture
+def mock_workspace_client():
+    with patch(
+        "databricks.labs.mcp.servers.unity_catalog.tools.tables.WorkspaceClient"
+    ) as mock_client:
+        mock_client.return_value = DummyWorkspaceClient()
+        yield mock_client
+
+
+@pytest.fixture
+def mock_settings():
+    with patch(
+        "databricks.labs.mcp.servers.unity_catalog.tools.tables.CliSettings"
+    ) as mock_settings:
+        mock_instance = Mock()
+        mock_instance.get_catalog_name.return_value = "test_catalog"
+        mock_instance.get_schema_name.return_value = "test_schema"
+        mock_settings.return_value = mock_instance
+        yield mock_settings
+
+
+def test_list_tables_tool(mock_workspace_client, mock_settings):
+    tool = ListTablesTool()
+    result = tool.execute()
+
+    assert len(result) == 1
+    assert result[0].type == "text"
+    data = result[0].text
+    assert "test_table" in data
+    assert "test_catalog" in data
+    assert "test_schema" in data
+
+
+def test_get_table_tool(mock_workspace_client):
+    tool = GetTableTool()
+    result = tool.execute(full_name="catalog.schema.test_table")
+
+    assert len(result) == 1
+    assert result[0].type == "text"
+    data = result[0].text
+    assert "catalog.schema.test_table" in data
+
+
+def test_list_table_summaries_tool(mock_workspace_client, mock_settings):
+    tool = ListTableSummariesTool()
+    result = tool.execute()
+
+    assert len(result) == 1
+    assert result[0].type == "text"
+    data = result[0].text
+    assert "test_table" in data
+
+
+def test_list_catalogs_tool(mock_workspace_client):
+    tool = ListCatalogsTool()
+    result = tool.execute()
+
+    assert len(result) == 1
+    assert result[0].type == "text"
+    data = result[0].text
+    assert "test_catalog" in data
+
+
+def test_list_schemas_tool(mock_workspace_client, mock_settings):
+    tool = ListSchemasTool()
+    result = tool.execute()
+
+    assert len(result) == 1
+    assert result[0].type == "text"
+    data = result[0].text
+    assert "test_schema" in data
+
+
+def test_list_table_tools():
+    settings = Mock()
+    tools = list_table_tools(settings)
+    assert len(tools) == 5
+    assert any(isinstance(tool, ListTablesTool) for tool in tools)
+    assert any(isinstance(tool, GetTableTool) for tool in tools)
+    assert any(isinstance(tool, ListTableSummariesTool) for tool in tools)
+    assert any(isinstance(tool, ListCatalogsTool) for tool in tools)
+    assert any(isinstance(tool, ListSchemasTool) for tool in tools)
+    # Every tool must reuse the settings handed to the factory
+    assert all(tool.settings is settings for tool in tools)
+
+
+def test_list_tables_tool_with_parameters(mock_workspace_client, mock_settings):
+    tool = ListTablesTool()
+    result = tool.execute(catalog_name="custom_catalog", schema_name="custom_schema")
+
+    assert len(result) == 1
+    assert result[0].type == "text"
+    data = result[0].text
+    assert "custom_catalog" in data
+    assert "custom_schema" in data
+
+
+def test_list_tables_tool_uses_injected_settings(mock_workspace_client):
+    settings = Mock()
+    settings.get_catalog_name.return_value = "injected_catalog"
+    settings.get_schema_name.return_value = "injected_schema"
+
+    tool = ListTablesTool(settings)
+    result = tool.execute()
+
+    data = result[0].text
+    assert "injected_catalog" in data
+    assert "injected_schema" in data
+
+
+def test_get_table_tool_missing_full_name(mock_workspace_client):
+    tool = GetTableTool()
+    with pytest.raises(ValueError):
+        tool.execute()
+
+
+def test_list_tables_tool_missing_catalog_schema(mock_workspace_client):
+    with patch(
+        "databricks.labs.mcp.servers.unity_catalog.tools.tables.CliSettings"
+    ) as mock_settings:
+        mock_instance = Mock()
+        mock_instance.get_catalog_name.return_value = None
+        mock_instance.get_schema_name.return_value = None
+        mock_settings.return_value = mock_instance
+
+        tool = ListTablesTool()
+        with pytest.raises(ValueError):
+            tool.execute()