diff --git a/mcp-server/README.md b/mcp-server/README.md new file mode 100644 index 00000000000..76c8776e722 --- /dev/null +++ b/mcp-server/README.md @@ -0,0 +1,399 @@ + + +# Apache Cloudberry MCP Server + +A Model Communication Protocol (MCP) server for Apache Cloudberry database interaction, providing secure and efficient database management capabilities through AI-ready interfaces. + +## Features + +- **Database Metadata Resources**: Access schemas, tables, views, indexes, and column information +- **Safe Query Tools**: Execute parameterized SQL queries with security validation +- **Administrative Tools**: Table statistics, large table analysis, and query optimization +- **Context-Aware Prompts**: Predefined prompts for common database tasks +- **Security-First Design**: SQL injection prevention, read-only constraints, and connection pooling +- **Async Performance**: Built with asyncpg for high-performance database operations + +## Prerequisites + +- Python 3.8+ +- uv (for dependency management) + +## Installation + +### Install uv + +```bash +curl -sSfL https://astral.sh/uv/install.sh | sh +``` + +### Install Dependencies + +```bash +cd mcp-server +uv venv +source .venv/bin/activate +uv sync +``` + +### Install Project + +```bash +uv pip install -e . +``` + +### Build Project + +```bash +uv build +``` + +## Configuration + +Create a `.env` file in the project root: + +```env +# Database Configuration +DB_HOST=localhost +DB_PORT=5432 +DB_NAME=postgres +DB_USER=postgres +DB_PASSWORD=your_password + +# Server Configuration +MCP_HOST=localhost +MCP_PORT=8000 +MCP_DEBUG=false +``` + +## Usage + +### Running the Server + +```bash +# Run the MCP server +python -m cbmcp.server + +# Or run with cloudberry-mcp-server +cloudberry-mcp-server + +# Or run with custom configuration +MCP_HOST=0.0.0.0 MCP_PORT=8080 python -m cbmcp.server +``` + +### Testing the Client + +```bash +# Run the test client +python -m cbmcp.client +``` + +## API Reference + +### Resources + +- `postgres://schemas` - List all database schemas +- `postgres://database/info` - Get general database info +- `postgres://database/summary` - Get detailed database summary + +### Tools + +#### Query Tools +- `execute_query(query, params, readonly)` - Execute a SQL query +- `explain_query(query, params)` - Get query execution plan +- `get_table_stats(schema, table)` - Get table statistics +- `list_large_tables(limit)` - List largest tables + +#### User & Permission Management +- `list_users()` - List all database users +- `list_user_permissions(username)` - List permissions for a specific user +- `list_table_privileges(schema, table)` - List privileges for a specific table + +#### Schema & Structure +- `list_constraints(schema, table)` - List constraints for a table +- `list_foreign_keys(schema, table)` - List foreign keys for a table +- `list_referenced_tables(schema, table)` - List tables that reference this table +- `get_table_ddl(schema, table)` - Get DDL statement for a table + +#### Performance & Monitoring +- `get_slow_queries(limit)` - List slow queries +- `get_index_usage()` - Analyze index usage statistics +- `get_table_bloat_info()` - Analyze table bloat information +- `get_database_activity()` - Show current database activity +- `get_vacuum_info()` - Get vacuum and analyze statistics + +#### Database Objects +- `list_functions(schema)` - List functions in a schema +- `get_function_definition(schema, function)` - Get function definition +- `list_triggers(schema, table)` - List triggers for a table +- `list_materialized_views(schema)` - List materialized views in a schema +- `list_active_connections()` - List active database connections + +### Prompts + +- `analyze_query_performance` - Query optimization assistance +- `suggest_indexes` - Index recommendation guidance +- `database_health_check` - Database health assessment + +## Security Features + +- **SQL Injection Prevention**: Comprehensive query validation +- **Read-Only Constraints**: Configurable write protection +- **Parameterized Queries**: Safe parameter handling +- **Connection Pooling**: Secure connection management +- **Sensitive Table Protection**: Blocks access to system tables + + +## Quick Start with Cloudberry Demo Cluster + +This section shows how to quickly set up and test the Cloudberry MCP Server using a local Cloudberry demo cluster. This is ideal for development and testing purposes. + +Assume you already have a running [Cloudberry demo cluster](https://cloudberry.apache.org/docs/deployment/set-demo-cluster) and install & build MCP server as described above. + +1. Configure local connections in `pg_hba.conf` + +**Note**: This configuration is for demo purposes only. Do not use `trust` authentication in production environments. + +```bash +[gpadmin@cdw]$ vi ~/cloudberry/gpAux/gpdemo/datadirs/qddir/demoDataDir-1/pg_hba.conf +``` + +Add the following lines to the end of the pg_hba.conf: + +``` +# IPv4 local connections +host all all 127.0.0.1/32 trust +# IPv6 local connections +host all all ::1/128 trust +``` + +After modifying `pg_hba.conf`, reload the configuration parameters: +```bash +[gpadmin@cdw]$ gpstop -u +``` + +2. Create environment configuration + +Create a `.env` in the project root directory: + +``` +# Database Configuration (Demo cluster defaults) +DB_HOST=localhost +DB_PORT=7000 +DB_NAME=postgres +DB_USER=gpadmin +# No password required for demo cluster + +# Server Configuration +MCP_HOST=localhost +MCP_PORT=8000 +MCP_DEBUG=false +``` + +3. Start the MCP server + +```bash +MCP_HOST=0.0.0.0 MCP_PORT=8000 python -m cbmcp.server +``` + +You should see output indicating the server is running: +``` +[09/17/25 14:07:50] INFO Starting MCP server 'Apache Cloudberry MCP Server' with transport server.py:1572 + 'streamable-http' on http://0.0.0.0:8000/mcp/ +``` + +4. Configure your MCP client. + +Add the following server configuration to your MCP client: + +- Server Type: Streamable-HTTP +- URL: http://[YOUR_HOST_IP]:8000/mcp + +Replace `[YOUR_HOST_IP]` with your actual host IP address. + + +## LLM Client Integration + +### Claude Desktop Configuration + +Add the following configuration to your Claude Desktop configuration file: + +#### Stdio Transport (Recommended) + +```json +{ + "mcpServers": { + "cloudberry-mcp-server": { + "command": "uvx", + "args": [ + "--with", + "PATH/TO/cbmcp-0.1.0-py3-none-any.whl", + "python", + "-m", + "cbmcp.server", + "--mode", + "stdio" + ], + "env": { + "DB_HOST": "localhost", + "DB_PORT": "5432", + "DB_NAME": "dvdrental", + "DB_USER": "yangshengwen", + "DB_PASSWORD": "" + } + } + } +} +``` + +#### HTTP Transport + +```json +{ + "mcpServers": { + "cloudberry-mcp-server": { + "type": "streamable-http", + "url": "https://localhost:8000/mcp/", + "headers": { + "Authorization": "" + } + } + } +} +``` + +### Cursor Configuration + +For Cursor IDE, add the configuration to your `.cursor/mcp.json` file: + +```json +{ + "mcpServers": { + "cloudberry-mcp": { + "command": "uvx", + "args": ["--with", "cbmcp", "python", "-m", "cbmcp.server", "--mode", "stdio"], + "env": { + "DB_HOST": "localhost", + "DB_PORT": "5432", + "DB_NAME": "dvdrental", + "DB_USER": "postgres", + "DB_PASSWORD": "your_password" + } + } + } +} +``` + +### Windsurf Configuration + +For Windsurf IDE, configure in your settings: + +```json +{ + "mcp": { + "servers": { + "cloudberry-mcp": { + "type": "stdio", + "command": "uvx", + "args": ["--with", "cbmcp", "python", "-m", "cbmcp.server", "--mode", "stdio"], + "env": { + "DB_HOST": "localhost", + "DB_PORT": "5432", + "DB_NAME": "dvdrental", + "DB_USER": "postgres", + "DB_PASSWORD": "your_password" + } + } + } + } +} +``` + +### VS Code with Cline + +For VS Code with the Cline extension, add to your settings: + +```json +{ + "cline.mcpServers": { + "cloudberry-mcp": { + "command": "uvx", + "args": ["--with", "cbmcp", "python", "-m", "cbmcp.server", "--mode", "stdio"], + "env": { + "DB_HOST": "localhost", + "DB_PORT": "5432", + "DB_NAME": "dvdrental", + "DB_USER": "postgres", + "DB_PASSWORD": "your_password" + } + } + } +} +``` + +### Installation via pip + +If you prefer to install the package globally instead of using uvx: + +```bash +# Install the package +pip install cbmcp-0.1.0-py3-none-any.whl + +# Or using pip install from source +pip install -e . + +# Then use in configuration +{ + "command": "python", + "args": ["-m", "cbmcp.server", "--mode", "stdio"] +} +``` + +### Environment Variables + +All configurations support the following environment variables: + +- `DB_HOST`: Database host (default: localhost) +- `DB_PORT`: Database port (default: 5432) +- `DB_NAME`: Database name (default: postgres) +- `DB_USER`: Database username +- `DB_PASSWORD`: Database password +- `MCP_HOST`: Server host for HTTP mode (default: localhost) +- `MCP_PORT`: Server port for HTTP mode (default: 8000) +- `MCP_DEBUG`: Enable debug logging (default: false) + +### Troubleshooting + +#### Common Issues + +1. **Connection refused**: Ensure Apache Cloudberry is running and accessible +2. **Authentication failed**: Check database credentials in environment variables +3. **Module not found**: Ensure the package is installed correctly +4. **Permission denied**: Check file permissions for the package + +#### Debug Mode + +Enable debug logging by setting: +```bash +export MCP_DEBUG=true +``` + +## License + +Apache License 2.0 \ No newline at end of file diff --git a/mcp-server/dotenv.example b/mcp-server/dotenv.example new file mode 100644 index 00000000000..896cb33a539 --- /dev/null +++ b/mcp-server/dotenv.example @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Example environment configuration +# Copy this file to .env and update with your actual values + +# Database Configuration +DB_HOST=localhost +DB_PORT=5432 +DB_NAME=postgres +DB_USER=postgres +DB_PASSWORD=your_password_here + +# Server Configuration +MCP_HOST=localhost +MCP_PORT=8000 +MCP_DEBUG=false diff --git a/mcp-server/pyproject.toml b/mcp-server/pyproject.toml new file mode 100644 index 00000000000..984cb5e12a1 --- /dev/null +++ b/mcp-server/pyproject.toml @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[project] +name = "cloudberry-mcp-server" +version = "0.1.0" +description = "MCP server for Apache Cloudberry database interaction" +readme = "README.md" +requires-python = ">=3.10" +authors = [ + {name = "Shengwen Yang", email = "yangshengwen@gmail.com"}, +] +maintainers = [ + {name = "Shengwen Yang", email = "yangshengwen@gmail.com"}, +] +license = {text = "Apache License 2.0"} +keywords = ["mcp", "cloudberry", "postgresql", "database", "server", "ai"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License, Version 2.0", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Database", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Database :: Front-Ends", +] +dependencies = [ + "fastmcp>=2.10.6", + "psycopg2-binary==2.9.10", + "asyncpg>=0.29.0", + "pydantic>=2.0.0", + "python-dotenv>=1.0.0", + "aiohttp>=3.12.15", + "starlette>=0.27.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "pytest-cov>=4.0.0", +] + +[project.urls] +Homepage = "https://github.com/apache/cloudberry//tree/main/mcp-server" +Repository = "https://github.com/apache/cloudberry" +Documentation = "https://github.com/apache/cloudberry/mcp-server/tree/main/mcp-server/README.md" + +[project.scripts] +cloudberry-mcp-server = "cbmcp.server:main" diff --git a/mcp-server/pytest.ini b/mcp-server/pytest.ini new file mode 100644 index 00000000000..519115d6a19 --- /dev/null +++ b/mcp-server/pytest.ini @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[tool:pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = + -v + --tb=short + --strict-markers + --disable-warnings +asyncio_mode = auto +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + integration: marks tests as integration tests + unit: marks tests as unit tests \ No newline at end of file diff --git a/mcp-server/run_tests.sh b/mcp-server/run_tests.sh new file mode 100755 index 00000000000..a47184cab9f --- /dev/null +++ b/mcp-server/run_tests.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# Test script for Apache Cloudberry MCP Server + +echo "=== Install test dependencies ===" +uv pip install -e ".[dev]" + +echo "=== Run all tests ===" +uv run pytest tests/ -v + +echo "=== Run specific test patterns ===" +echo "Run stdio mode test:" +uv run pytest tests/test_cbmcp.py::TestCloudberryMCPClient::test_list_capabilities -v + +echo "Run http mode test:" +uv run pytest tests/test_cbmcp.py::TestCloudberryMCPClient::test_list_capabilities -v + +echo "=== Run coverage tests ===" +uv run pytest tests/ --cov=cbmcp --cov-report=html --cov-report=term + +echo "=== Test completed ===" diff --git a/mcp-server/src/cbmcp/__init__.py b/mcp-server/src/cbmcp/__init__.py new file mode 100644 index 00000000000..a584c415ee6 --- /dev/null +++ b/mcp-server/src/cbmcp/__init__.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python + # -*- coding: utf-8 -*- + # Licensed to the Apache Software Foundation (ASF) under one + # or more contributor license agreements. See the NOTICE file + # distributed with this work for additional information + # regarding copyright ownership. The ASF licenses this file + # to you under the Apache License, Version 2.0 (the + # "License"); you may not use this file except in compliance + # with the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + +""" +Apache Cloudberry MCP Server Package +""" + +from .server import CloudberryMCPServer +from .client import CloudberryMCPClient +from .config import DatabaseConfig, ServerConfig +from .database import DatabaseManager +from .security import SQLValidator + +__version__ = "0.1.0" +__all__ = [ + "CloudberryMCPServer", + "CloudberryMCPClient", + "DatabaseConfig", + "ServerConfig", + "DatabaseManager", + "SQLValidator", +] \ No newline at end of file diff --git a/mcp-server/src/cbmcp/__main__.py b/mcp-server/src/cbmcp/__main__.py new file mode 100644 index 00000000000..f467fb2ac53 --- /dev/null +++ b/mcp-server/src/cbmcp/__main__.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Main entry point for the cbmcp package +""" + +from .server import main + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mcp-server/src/cbmcp/client.py b/mcp-server/src/cbmcp/client.py new file mode 100644 index 00000000000..b88f9993bd1 --- /dev/null +++ b/mcp-server/src/cbmcp/client.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +MCP Client for testing the Apache Cloudberry MCP Server + +A client using the fastmcp SDK to interact with the Apache Cloudberry MCP server implementation. +""" + +from typing import Any, Dict, Optional +from fastmcp import Client + +from .config import DatabaseConfig, ServerConfig +from .server import CloudberryMCPServer + +class CloudberryMCPClient: + """MCP client for testing the Apache Cloudberry server using fastmcp SDK + + Usage: + # Method 1: Using async context manager + async with CloudberryMCPClient() as client: + tools = await client.list_tools() + resources = await client.list_resources() + + # Method 2: Using create class method + client = await CloudberryMCPClient.create() + tools = await client.list_tools() + await client.close() + + # Method 3: Manual initialization + client = CloudberryMCPClient() + await client.initialize() + tools = await client.list_tools() + await client.close() + """ + + def __init__(self, mode: str = "stdio", server_url: str = "http://localhost:8000/mcp/"): + self.mode = mode + self.server_url = server_url + self.client: Optional[Client] = None + + @classmethod + async def create(cls, mode: str = "stdio", server_url: str = "http://localhost:8000/mcp/") -> "CloudberryMCPClient": + """Asynchronously create and initialize the client""" + instance = cls(mode, server_url) + await instance.initialize() + return instance + + async def initialize(self): + """Initialize the client connection""" + if self.mode == "stdio": + server_config = ServerConfig.from_env() + db_config = DatabaseConfig.from_env() + server = CloudberryMCPServer(server_config, db_config) + self.client = Client(server.mcp) + else: + self.client = Client(self.server_url) + + await self.client.__aenter__() + + async def close(self): + """Close the client connection""" + if self.client: + await self.client.__aexit__(None, None, None) + self.client = None + + async def __aenter__(self): + if self.mode == "stdio": + server_config = ServerConfig.from_env() + db_config = DatabaseConfig.from_env() + server = CloudberryMCPServer(server_config, db_config) + self.client = Client(server.mcp) + else: + self.client = Client(self.server_url) + + await self.client.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.client: + await self.client.__aexit__(exc_type, exc_val, exc_tb) + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: + """Call a tool on the MCP server""" + if not self.client: + raise RuntimeError("Client not initialized. Use async with statement.") + + return await self.client.call_tool(tool_name, arguments) + + async def get_resource(self, resource_uri: str): + """Get a resource from the MCP server""" + if not self.client: + raise RuntimeError("Client not initialized. Use async with statement.") + + return await self.client.read_resource(resource_uri) + + async def get_prompt(self, prompt_name: str, params: Dict[str, Any]=None): + """Get a prompt from the MCP server""" + if not self.client: + raise RuntimeError("Client not initialized. Use async with statement.") + + return await self.client.get_prompt(prompt_name, params) + + async def list_tools(self) -> list: + """List available tools on the server""" + if not self.client: + raise RuntimeError("Client not initialized. Use async with statement.") + + return await self.client.list_tools() + + async def list_resources(self) -> list: + """List available resources on the server""" + if not self.client: + raise RuntimeError("Client not initialized. Use async with statement.") + + return await self.client.list_resources() + + async def list_prompts(self) -> list: + """List available prompts on the server""" + if not self.client: + raise RuntimeError("Client not initialized. Use async with statement.") + + return await self.client.list_prompts() + + +if __name__ == "__main__": + import asyncio + + async def main(): + async with CloudberryMCPClient(mode="http") as client: + results = await client.call_tool("execute_query", { + "query": "SELECT * FROM film LIMIT 5" + }) + print("Results:", results) + + results = await client.call_tool("list_columns", { + "table": "film", + "schema": "public" + }) + print("Columns:", results) + + asyncio.run(main()) \ No newline at end of file diff --git a/mcp-server/src/cbmcp/config.py b/mcp-server/src/cbmcp/config.py new file mode 100644 index 00000000000..a810e67ed8f --- /dev/null +++ b/mcp-server/src/cbmcp/config.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Configuration utilities for the Apache Cloudberry MCP server +""" + +import os +from dataclasses import dataclass +from dotenv import load_dotenv + + +@dataclass +class DatabaseConfig: + """Database connection configuration""" + host: str + port: int + database: str + username: str + password: str + + @classmethod + def from_env(cls) -> "DatabaseConfig": + """Create config from environment variables""" + load_dotenv() + return cls( + host=os.getenv("DB_HOST", "localhost"), + port=int(os.getenv("DB_PORT", "5432")), + database=os.getenv("DB_NAME", "postgres"), + username=os.getenv("DB_USER", "postgres"), + password=os.getenv("DB_PASSWORD", ""), + ) + + @property + def connection_string(self) -> str: + """Get Apache Cloudberry connection string""" + return f"postgresql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}" + + +@dataclass +class ServerConfig: + """MCP server configuration""" + host: str + port: int + path: str + debug: bool + + @classmethod + def from_env(cls) -> "ServerConfig": + """Create config from environment variables""" + load_dotenv() + return cls( + host=os.getenv("MCP_HOST", "localhost"), + port=int(os.getenv("MCP_PORT", "8000")), + path=os.getenv("MCP_PATH", "/mcp/"), + debug=os.getenv("MCP_DEBUG", "false").lower() == "true", + ) \ No newline at end of file diff --git a/mcp-server/src/cbmcp/database.py b/mcp-server/src/cbmcp/database.py new file mode 100644 index 00000000000..77cdd2b308b --- /dev/null +++ b/mcp-server/src/cbmcp/database.py @@ -0,0 +1,773 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Database utilities for the Apache Cloudberry MCP server +""" + +import logging +from typing import Any, Dict, Optional +from contextlib import asynccontextmanager +import asyncpg + +from .config import DatabaseConfig +from .security import SQLValidator + + +logger = logging.getLogger(__name__) + + +class DatabaseManager: + """Manages database connections and operations""" + + def __init__(self, config: DatabaseConfig): + self.config = config + self._connection_pool: Optional[asyncpg.Pool] = None + + @asynccontextmanager + async def get_connection(self): + """Get a database connection from the pool""" + if not self._connection_pool: + self._connection_pool = await asyncpg.create_pool( + host=self.config.host, + port=self.config.port, + database=self.config.database, + user=self.config.username, + password=self.config.password, + min_size=1, + max_size=10, + command_timeout=60.0, + ) + + try: + async with self._connection_pool.acquire() as conn: + yield conn + except Exception as e: + logger.error(f"Database connection error: {e}") + raise + + async def execute_query( + self, + query: str, + params: Optional[Dict[str, Any]] = None, + readonly: bool = True + ) -> Dict[str, Any]: + """Execute a SQL query with safety validation""" + # Validate query for security + is_valid, error_msg = SQLValidator.validate_query(query) + if not is_valid: + return {"error": f"Query validation failed: {error_msg}"} + + # Check readonly constraint + if readonly and not SQLValidator.is_readonly_query(query): + return {"error": "Only read-only queries are allowed"} + + try: + async with self.get_connection() as conn: + if params: + # Sanitize parameter names + sanitized_params = { + SQLValidator.sanitize_parameter_name(k): v + for k, v in params.items() + } + result = await conn.fetch(query, **sanitized_params) + else: + result = await conn.fetch(query) + + if not result: + return {"columns": [], "rows": [], "row_count": 0} + + columns = list(result[0].keys()) + rows = [list(row.values()) for row in result] + + return { + "columns": columns, + "rows": rows, + "row_count": len(rows) + } + + except Exception as e: + logger.error(f"Query execution error: {e}") + return {"error": f"Error executing query: {str(e)}"} + + + async def get_table_info(self, schema: str, table: str) -> Dict[str, Any]: + """Get detailed information about a table""" + try: + async with self.get_connection() as conn: + # Get column information + columns = await conn.fetch( + "SELECT column_name, data_type, is_nullable, column_default " + "FROM information_schema.columns " + "WHERE table_schema = $1 AND table_name = $2 " + "ORDER BY ordinal_position", + schema, table + ) + + # Get index information + indexes = await conn.fetch( + "SELECT indexname, indexdef FROM pg_indexes " + "WHERE schemaname = $1 AND tablename = $2 " + "ORDER BY indexname", + schema, table + ) + + # Get table statistics + stats = await conn.fetchrow( + "SELECT " + "pg_size_pretty(pg_total_relation_size($1)) as total_size, " + "pg_size_pretty(pg_relation_size($1)) as table_size, " + "pg_size_pretty(pg_total_relation_size($1) - pg_relation_size($1)) as indexes_size, " + "(SELECT COUNT(*) FROM $1) as row_count", + f"{schema}.{table}" + ) + + return { + "columns": [dict(col) for col in columns], + "indexes": [dict(idx) for idx in indexes], + "statistics": dict(stats) if stats else {} + } + + except Exception as e: + logger.error(f"Error getting table info: {e}") + return {"error": str(e)} + + async def close(self): + """Close the connection pool""" + if self._connection_pool: + await self._connection_pool.close() + self._connection_pool = None + + async def list_schemas(self) -> list[str]: + """List all database schemas""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT schema_name FROM information_schema.schemata " + "WHERE schema_name NOT LIKE 'pg_%' AND schema_name != 'information_schema' " + "ORDER BY schema_name" + ) + return [r["schema_name"] for r in records] + + async def get_database_info(self) -> dict[str, str]: + """Get general database information""" + async with self.get_connection() as conn: + version = await conn.fetchval("SELECT version()") + size = await conn.fetchval("SELECT pg_size_pretty(pg_database_size(current_database()))") + stats = await conn.fetchrow( + "SELECT COUNT(*) as total_tables FROM information_schema.tables " + "WHERE table_type = 'BASE TABLE' AND table_schema NOT LIKE 'pg_%'" + ) + + return { + "Version": version, + "Size": size, + "Total Tables": str(stats['total_tables']) + } + + async def get_database_summary(self) -> dict[str, dict]: + """Get comprehensive database summary""" + summary = {} + + async with self.get_connection() as conn: + # Get schemas + schemas = await conn.fetch( + "SELECT schema_name FROM information_schema.schemata " + "WHERE schema_name NOT LIKE 'pg_%' AND schema_name != 'information_schema' " + "ORDER BY schema_name" + ) + + for schema_row in schemas: + schema = schema_row["schema_name"] + summary[schema] = {} + + # Get tables + tables = await conn.fetch( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = $1 AND table_type = 'BASE TABLE' " + "ORDER BY table_name", + schema + ) + summary[schema]["tables"] = [t["table_name"] for t in tables] + + # Get views + views = await conn.fetch( + "SELECT table_name FROM information_schema.views " + "WHERE table_schema = $1 " + "ORDER BY table_name", + schema + ) + summary[schema]["views"] = [v["table_name"] for v in views] + + return summary + + async def list_tables(self, schema: str) -> list[str]: + """List tables in a specific schema""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = $1 AND table_type = 'BASE TABLE' " + "ORDER BY table_name", + schema + ) + return [r["table_name"] for r in records] + + async def list_views(self, schema: str) -> list[str]: + """List views in a specific schema""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT table_name FROM information_schema.views " + "WHERE table_schema = $1 " + "ORDER BY table_name", + schema + ) + return [r["table_name"] for r in records] + + async def list_indexes(self, schema: str, table: str) -> list[dict]: + """List indexes for a specific table""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT indexname, indexdef FROM pg_indexes " + "WHERE schemaname = $1 AND tablename = $2 " + "ORDER BY indexname", + schema, table + ) + return [{"indexname": r["indexname"], "indexdef": r["indexdef"]} for r in records] + + async def list_columns(self, schema: str, table: str) -> list[dict]: + """List columns for a specific table""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT column_name, data_type, is_nullable, column_default " + "FROM information_schema.columns " + "WHERE table_schema = $1 AND table_name = $2 " + "ORDER BY ordinal_position", + schema, table + ) + return [ + { + "column_name": r["column_name"], + "data_type": r["data_type"], + "is_nullable": r["is_nullable"], + "column_default": r["column_default"] + } + for r in records + ] + + async def get_table_stats(self, schema: str, table: str) -> dict[str, Any]: + """Get statistics for a table""" + try: + # Validate schema and table names to prevent SQL injection + if not schema.replace('_', '').replace('-', '').isalnum(): + return {"error": "Invalid schema name"} + if not table.replace('_', '').replace('-', '').isalnum(): + return {"error": "Invalid table name"} + + async with self.get_connection() as conn: + # Use format() with proper identifier quoting + qualified_name = f"{schema}.{table}" + sql = ( + f"SELECT " + f"pg_size_pretty(pg_total_relation_size('{qualified_name}')) as total_size, " + f"pg_size_pretty(pg_relation_size('{qualified_name}')) as table_size, " + f"pg_size_pretty(pg_total_relation_size('{qualified_name}') - pg_relation_size('{qualified_name}')) as indexes_size, " + f"(SELECT COUNT(*) FROM {qualified_name}) as row_count" + ) + result = await conn.fetchrow(sql) + + if not result: + return {"error": f"Table {schema}.{table} not found"} + + return { + "total_size": result["total_size"], + "table_size": result["table_size"], + "indexes_size": result["indexes_size"], + "row_count": result["row_count"] + } + except Exception as e: + return {"error": f"Error getting table stats: {str(e)}"} + + async def list_large_tables(self, limit: int = 10) -> list[dict]: + """List the largest tables in the database""" + async with self.get_connection() as conn: + result = await conn.fetch( + "SELECT " + "schemaname, tablename, " + "pg_size_pretty(pg_total_relation_size(schemaname||'.'||tablename)) as size, " + "pg_total_relation_size(schemaname||'.'||tablename) as size_bytes " + "FROM pg_tables " + "WHERE schemaname NOT LIKE 'pg_%' " + "ORDER BY pg_total_relation_size(schemaname||'.'||tablename) DESC " + "LIMIT $1", + limit + ) + + return [ + { + "schema": row["schemaname"], + "table": row["tablename"], + "size": row["size"], + "size_bytes": row["size_bytes"] + } + for row in result + ] + + async def list_users(self) -> list[str]: + """List all database users""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT usename FROM pg_user WHERE usename != 'cloudberry' " + "ORDER BY usename" + ) + return [r["usename"] for r in records] + + async def list_user_permissions(self, username: str) -> list[dict]: + """List permissions for a specific user""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "n.nspname as schema_name, " + "c.relname as object_name, " + "c.relkind as object_type, " + "p.perm as permission " + "FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "CROSS JOIN LATERAL aclexplode(c.relacl) p " + "WHERE p.grantee = (SELECT oid FROM pg_user WHERE usename = $1) " + "AND n.nspname NOT LIKE 'pg_%' " + "ORDER BY n.nspname, c.relname", + username + ) + return [ + { + "schema": r["schema_name"], + "object": r["object_name"], + "type": r["object_type"], + "permission": r["permission"] + } + for r in records + ] + + async def list_table_privileges(self, schema: str, table: str) -> list[dict]: + """List privileges for a specific table""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "grantee, privilege_type " + "FROM information_schema.table_privileges " + "WHERE table_schema = $1 AND table_name = $2 " + "ORDER BY grantee, privilege_type", + schema, table + ) + return [ + { + "user": r["grantee"], + "privilege": r["privilege_type"] + } + for r in records + ] + + async def list_constraints(self, schema: str, table: str) -> list[dict]: + """List constraints for a specific table""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "c.conname as constraint_name, " + "c.contype as constraint_type, " + "pg_get_constraintdef(c.oid) as constraint_definition, " + "f.relname as foreign_table_name, " + "nf.nspname as foreign_schema_name " + "FROM pg_constraint c " + "JOIN pg_class t ON t.oid = c.conrelid " + "JOIN pg_namespace n ON n.oid = t.relnamespace " + "LEFT JOIN pg_class f ON f.oid = c.confrelid " + "LEFT JOIN pg_namespace nf ON nf.oid = f.relnamespace " + "WHERE n.nspname = $1 AND t.relname = $2 " + "ORDER BY c.conname", + schema, table + ) + constraints = [] + for r in records: + constraint_type = { + 'p': 'PRIMARY KEY', + 'f': 'FOREIGN KEY', + 'u': 'UNIQUE', + 'c': 'CHECK', + 'x': 'EXCLUSION' + }.get(r["constraint_type"], r["constraint_type"]) + + constraints.append({ + "name": r["constraint_name"], + "type": constraint_type, + "definition": r["constraint_definition"], + "foreign_table": r["foreign_table_name"], + "foreign_schema": r["foreign_schema_name"] + }) + return constraints + + async def list_foreign_keys(self, schema: str, table: str) -> list[dict]: + """List foreign keys for a specific table""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "tc.constraint_name, " + "tc.table_name, " + "kcu.column_name, " + "ccu.table_name AS foreign_table_name, " + "ccu.column_name AS foreign_column_name, " + "ccu.table_schema AS foreign_schema_name " + "FROM information_schema.table_constraints AS tc " + "JOIN information_schema.key_column_usage AS kcu " + "ON tc.constraint_name = kcu.constraint_name " + "JOIN information_schema.constraint_column_usage AS ccu " + "ON ccu.constraint_name = tc.constraint_name " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema = $1 AND tc.table_name = $2 " + "ORDER BY tc.constraint_name", + schema, table + ) + return [ + { + "constraint_name": r["constraint_name"], + "column": r["column_name"], + "foreign_schema": r["foreign_schema_name"], + "foreign_table": r["foreign_table_name"], + "foreign_column": r["foreign_column_name"] + } + for r in records + ] + + async def list_referenced_tables(self, schema: str, table: str) -> list[dict]: + """List tables that reference this table""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "tc.table_schema, " + "tc.table_name, " + "tc.constraint_name " + "FROM information_schema.table_constraints AS tc " + "JOIN information_schema.constraint_column_usage AS ccu " + "ON ccu.constraint_name = tc.constraint_name " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND ccu.table_schema = $1 AND ccu.table_name = $2 " + "ORDER BY tc.table_schema, tc.table_name", + schema, table + ) + return [ + { + "schema": r["table_schema"], + "table": r["table_name"], + "constraint": r["constraint_name"] + } + for r in records + ] + + async def explain_query(self, query: str, params: Optional[dict] = None) -> str: + """Get the execution plan for a query""" + try: + async with self.get_connection() as conn: + if params: + result = await conn.fetch(f"EXPLAIN (ANALYZE, BUFFERS) {query}", **params) + else: + result = await conn.fetch(f"EXPLAIN (ANALYZE, BUFFERS) {query}") + + return "\n".join([row["QUERY PLAN"] for row in result]) + except Exception as e: + return f"Error explaining query: {str(e)}" + + async def get_slow_queries(self, limit: int = 10) -> list[dict]: + """Get slow queries from pg_stat_statements""" + async with self.get_connection() as conn: + try: + records = await conn.fetch( + "SELECT " + "query, " + "calls, " + "total_time, " + "mean_time, " + "rows " + "FROM pg_stat_statements " + "ORDER BY mean_time DESC " + "LIMIT $1", + limit + ) + return [ + { + "query": r["query"], + "calls": r["calls"], + "total_time": r["total_time"], + "mean_time": r["mean_time"], + "rows": r["rows"] + } + for r in records + ] + except Exception: + # pg_stat_statements might not be available + return [] + + async def get_index_usage(self) -> list[dict]: + """Get index usage statistics""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "schemaname, " + "relname as tablename, " + "indexrelname as indexname, " + "idx_scan, " + "idx_tup_read, " + "idx_tup_fetch " + "FROM pg_stat_user_indexes " + "WHERE schemaname NOT LIKE 'pg_%' " + "ORDER BY idx_scan DESC" + ) + return [ + { + "schema": r["schemaname"], + "table": r["tablename"], + "index": r["indexname"], + "scans": r["idx_scan"], + "tup_read": r["idx_tup_read"], + "tup_fetch": r["idx_tup_fetch"] + } + for r in records + ] + + async def get_table_bloat_info(self) -> list[dict]: + """Get table bloat information""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "schemaname, " + "relname as tablename, " + "pg_size_pretty(pg_total_relation_size(schemaname||'.'||relname)) as total_size, " + "round(100 * (relpages - (relpages * fillfactor / 100)) / relpages, 2) as bloat_ratio " + "FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "JOIN pg_stat_user_tables s ON s.relid = c.oid " + "WHERE c.relkind = 'r' AND n.nspname NOT LIKE 'pg_%' " + "ORDER BY bloat_ratio DESC " + "LIMIT 20" + ) + return [ + { + "schema": r["schemaname"], + "table": r["tablename"], + "total_size": r["total_size"], + "bloat_ratio": r["bloat_ratio"] + } + for r in records + ] + + async def get_database_activity(self) -> list[dict]: + """Get current database activity""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "pid, " + "usename, " + "application_name, " + "client_addr, " + "state, " + "query_start, " + "query " + "FROM pg_stat_activity " + "WHERE state != 'idle' AND usename != 'cloudberry' " + "ORDER BY query_start" + ) + return [ + { + "pid": r["pid"], + "username": r["usename"], + "application": r["application_name"], + "client_addr": str(r["client_addr"]) if r["client_addr"] else None, + "state": r["state"], + "query_start": str(r["query_start"]), + "query": r["query"] + } + for r in records + ] + + async def list_functions(self, schema: str) -> list[dict]: + """List functions in a specific schema""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "proname as function_name, " + "pg_get_function_identity_arguments(p.oid) as arguments, " + "pg_get_function_result(p.oid) as return_type, " + "prokind as function_type " + "FROM pg_proc p " + "JOIN pg_namespace n ON n.oid = p.pronamespace " + "WHERE n.nspname = $1 AND p.prokind IN ('f', 'p') " + "ORDER BY proname", + schema + ) + return [ + { + "name": r["function_name"], + "arguments": r["arguments"], + "return_type": r["return_type"], + "type": "function" if r["function_type"] == "f" else "procedure" + } + for r in records + ] + + async def get_function_definition(self, schema: str, function_name: str) -> str: + """Get function definition""" + try: + async with self.get_connection() as conn: + definition = await conn.fetchval( + "SELECT pg_get_functiondef(p.oid) " + "FROM pg_proc p " + "JOIN pg_namespace n ON n.oid = p.pronamespace " + "WHERE n.nspname = $1 AND p.proname = $2 " + "LIMIT 1", + schema, function_name + ) + return definition or "Function definition not found" + except Exception as e: + return f"Error getting function definition: {str(e)}" + + async def list_triggers(self, schema: str, table: str) -> list[dict]: + """List triggers for a specific table""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "trigger_name, " + "event_manipulation, " + "action_timing, " + "action_statement " + "FROM information_schema.triggers " + "WHERE event_object_schema = $1 AND event_object_table = $2 " + "ORDER BY trigger_name", + schema, table + ) + return [ + { + "name": r["trigger_name"], + "event": r["event_manipulation"], + "timing": r["action_timing"], + "action": r["action_statement"] + } + for r in records + ] + + async def get_table_ddl(self, schema: str, table: str) -> str: + """Get DDL statement for a table""" + try: + async with self.get_connection() as conn: + # Try the newer method first + try: + ddl = await conn.fetchval( + "SELECT pg_get_tabledef($1, $2, true)", + schema, table + ) + if ddl: + return ddl + except Exception: + pass + + # Fallback to a more compatible approach + ddl_query = """ + SELECT 'CREATE TABLE ' || $1 || '.' || $2 || ' (' || E'\n' || + string_agg( + ' ' || column_name || ' ' || + data_type || + CASE WHEN character_maximum_length IS NOT NULL + THEN '(' || character_maximum_length || ')' + ELSE '' END || + CASE WHEN is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END || + CASE WHEN column_default IS NOT NULL + THEN ' DEFAULT ' || column_default + ELSE '' END, + E',\n' ORDER BY ordinal_position + ) || E'\n);' as ddl + FROM information_schema.columns + WHERE table_schema = $1 AND table_name = $2 + """ + ddl = await conn.fetchval(ddl_query, schema, table) + return ddl or "Table DDL not found" + except Exception as e: + return f"Error getting table DDL: {str(e)}" + + async def list_materialized_views(self, schema: str) -> list[str]: + """List materialized views in a specific schema""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT matviewname " + "FROM pg_matviews " + "WHERE schemaname = $1 " + "ORDER BY matviewname", + schema + ) + return [r["matviewname"] for r in records] + + async def get_vacuum_info(self, schema: str, table: str) -> dict: + """Get vacuum information for a table""" + async with self.get_connection() as conn: + record = await conn.fetchrow( + "SELECT " + "last_vacuum, " + "last_autovacuum, " + "n_dead_tup, " + "n_live_tup, " + "vacuum_count, " + "autovacuum_count " + "FROM pg_stat_user_tables " + "WHERE schemaname = $1 AND relname = $2", + schema, table + ) + if record: + return { + "last_vacuum": str(record["last_vacuum"]) if record["last_vacuum"] else None, + "last_autovacuum": str(record["last_autovacuum"]) if record["last_autovacuum"] else None, + "dead_tuples": record["n_dead_tup"], + "live_tuples": record["n_live_tup"], + "vacuum_count": record["vacuum_count"], + "autovacuum_count": record["autovacuum_count"] + } + return {"error": "Table not found"} + + async def list_active_connections(self) -> list[dict]: + """List active database connections""" + async with self.get_connection() as conn: + records = await conn.fetch( + "SELECT " + "pid, " + "usename, " + "application_name, " + "client_addr, " + "state, " + "backend_start, " + "query_start " + "FROM pg_stat_activity " + "WHERE usename != 'cloudberry' " + "ORDER BY backend_start" + ) + return [ + { + "pid": r["pid"], + "username": r["usename"], + "application": r["application_name"], + "client_addr": str(r["client_addr"]) if r["client_addr"] else None, + "state": r["state"], + "backend_start": str(r["backend_start"]), + "query_start": str(r["query_start"]) if r["query_start"] else None + } + for r in records + ] \ No newline at end of file diff --git a/mcp-server/src/cbmcp/prompt.py b/mcp-server/src/cbmcp/prompt.py new file mode 100644 index 00000000000..774aabac6ba --- /dev/null +++ b/mcp-server/src/cbmcp/prompt.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Prompt templates for Apache Cloudberry database analysis.""" + +ANALYZE_QUERY_PERFORMANCE_PROMPT = """Please help me analyze and optimize a PostgreSQL query. + +I'll provide you with: +1. The SQL query: {sql} +2. The EXPLAIN ANALYZE output: {explain} +3. Table schema information: {table_info} + +Please analyze: +- Query execution plan +- Potential performance bottlenecks +- Index usage +- Suggested optimizations +- Alternative query approaches +""" + +SUGGEST_INDEXES_PROMPT = """Please help me suggest optimal indexes for your PostgreSQL tables. + +I'll provide you with: +1. The table schema(s): {table_info} +2. Common query patterns: {query} +3. Current indexes: {table_info} +4. Table size and row count: {table_stats} + +Please analyze: +- Missing indexes based on query patterns +- Index type recommendations (B-tree, GIN, GiST, etc.) +- Composite index suggestions +- Index maintenance considerations +""" + +DATABASE_HEALTH_CHECK_PROMPT = """Let's perform a comprehensive health check of your PostgreSQL database. + +Please analyze: +- Database size and growth trends +- Large tables and indexes +- Query performance metrics +- Connection pool usage +- Vacuum and analyze statistics +- Index fragmentation +- Table bloat +""" \ No newline at end of file diff --git a/mcp-server/src/cbmcp/security.py b/mcp-server/src/cbmcp/security.py new file mode 100644 index 00000000000..d9cd44a7e17 --- /dev/null +++ b/mcp-server/src/cbmcp/security.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Security utilities for the Apache Cloudberry MCP server +""" + +from typing import Set +import re + + +class SQLValidator: + """Validates SQL queries for security""" + + # Allowed SQL operations for safety + ALLOWED_OPERATIONS: Set[str] = { + "SELECT", "WITH", "SHOW", "EXPLAIN", "DESCRIBE", "PRAGMA" + } + + # Blocked SQL operations + BLOCKED_OPERATIONS: Set[str] = { + "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", + "TRUNCATE", "GRANT", "REVOKE", "REPLACE" + } + + # Sensitive tables that should not be queried + SENSITIVE_TABLES: Set[str] = { + "pg_user", "pg_shadow", "pg_authid", "pg_passfile", + "information_schema.user_privileges" + } + + @classmethod + def validate_query(cls, query: str) -> tuple[bool, str]: + """Validate a SQL query for security + + Returns: + tuple: (is_valid, error_message) + """ + query_upper = query.upper().strip() + + # Check for blocked operations + for blocked in cls.BLOCKED_OPERATIONS: + if re.search(rf"\b{blocked}\b", query_upper): + return False, f"Blocked SQL operation: {blocked}" + + # Check if query starts with allowed operation + if not any(query_upper.startswith(op) for op in cls.ALLOWED_OPERATIONS): + return False, f"Query must start with one of: {', '.join(cls.ALLOWED_OPERATIONS)}" + + # Check for sensitive table access + for sensitive_table in cls.SENSITIVE_TABLES: + if re.search(rf"\b{sensitive_table}\b", query_upper): + return False, f"Access to sensitive table not allowed: {sensitive_table}" + + # Check for potential SQL injection patterns + injection_patterns = [ + r";.*--", # Comments after statements + r"/\*.*\*/", # Block comments + r"'OR'1'='1", # Basic SQL injection + r"'UNION.*SELECT", # Union attacks + r"EXEC\s*\(", # Dynamic SQL execution + ] + + for pattern in injection_patterns: + if re.search(pattern, query_upper): + return False, f"Potential SQL injection detected" + + return True, "Query is valid" + + @classmethod + def sanitize_parameter_name(cls, param_name: str) -> str: + """Sanitize parameter names to prevent injection""" + # Remove any non-alphanumeric characters except underscores + return re.sub(r"[^a-zA-Z0-9_]", "", param_name) + + @classmethod + def is_readonly_query(cls, query: str) -> bool: + """Check if a query is read-only""" + query_upper = query.upper().strip() + return query_upper.startswith(("SELECT", "WITH", "SHOW", "EXPLAIN")) \ No newline at end of file diff --git a/mcp-server/src/cbmcp/server.py b/mcp-server/src/cbmcp/server.py new file mode 100644 index 00000000000..edf0b539545 --- /dev/null +++ b/mcp-server/src/cbmcp/server.py @@ -0,0 +1,551 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Apache Cloudberry MCP Server Implementation + +A Model Communication Protocol server for Apache Cloudberry database interaction +providing resources, tools, and prompts for database management. +""" + +from typing import Annotated, Any, Dict, List, Optional +import logging +from fastmcp import FastMCP +from pydantic import Field + +from .config import DatabaseConfig, ServerConfig +from .database import DatabaseManager +from .prompt import ( + ANALYZE_QUERY_PERFORMANCE_PROMPT, + SUGGEST_INDEXES_PROMPT, + DATABASE_HEALTH_CHECK_PROMPT +) + +logger = logging.getLogger(__name__) + +class CloudberryMCPServer: + """Apache Cloudberry MCP Server implementation""" + + def __init__(self, server_config: ServerConfig, db_config: DatabaseConfig): + self.server_config = server_config + self.db_config = db_config + self.mcp = FastMCP("Apache Cloudberry MCP Server") + self.db_manager = DatabaseManager(db_config) + + self._setup_resources() + self._setup_tools() + self._setup_prompts() + + + def _setup_resources(self): + """Setup MCP resources for database metadata""" + + @self.mcp.resource("postgres://schemas", mime_type="application/json") + async def list_schemas() -> List[str]: + """List all database schemas""" + logger.info("Listing schemas") + return await self.db_manager.list_schemas() + + @self.mcp.resource("postgres://database/info", mime_type="application/json") + async def database_info() -> Dict[str, str]: + """Get general database information""" + logger.info("Getting database info") + return await self.db_manager.get_database_info() + + @self.mcp.resource("postgres://database/summary", mime_type="application/json") + async def database_summary() -> Dict[str, dict]: + """Get comprehensive database summary""" + logger.info("Getting database summary") + return await self.db_manager.get_database_summary() + + + def _setup_tools(self): + """Setup MCP tools for database operations""" + + @self.mcp.tool() + async def list_tables( + schema: Annotated[str, Field(description="The schema name to list tables from")] + ) -> List[str]: + """List tables in a specific schema""" + logger.info(f"Listing tables in schema: {schema}") + try: + return await self.db_manager.list_tables(schema) + except Exception as e: + return f"Error listing tables: {str(e)}" + + @self.mcp.tool() + async def list_views( + schema: Annotated[str, Field(description="The schema name to list views from")] + ) -> List[str]: + """List views in a specific schema""" + logger.info(f"Listing views in schema: {schema}") + try: + return await self.db_manager.list_views(schema) + except Exception as e: + return f"Error listing views: {str(e)}" + + @self.mcp.tool() + async def list_indexes( + schema: Annotated[str, Field(description="The schema name")], + table: Annotated[str, Field(description="The table name to list indexes for")] + ) -> List[str]: + """List indexes for a specific table""" + logger.info(f"Listing indexes for table: {schema}.{table}") + try: + indexes = await self.db_manager.list_indexes(schema, table) + return [f"{idx['indexname']}: {idx['indexdef']}" for idx in indexes] + except Exception as e: + return f"Error listing indexes: {str(e)}" + + @self.mcp.tool() + async def list_columns( + schema: Annotated[str, Field(description="The schema name")], + table: Annotated[str, Field(description="The table name to list columns for")] + ) -> List[Dict[str, Any]]: + """List columns for a specific table""" + logger.info(f"Listing columns for table: {schema}.{table}") + try: + return await self.db_manager.list_columns(schema, table) + except Exception as e: + return f"Error listing columns: {str(e)}" + + @self.mcp.tool() + async def execute_query( + query: Annotated[str, Field(description="The SQL query to execute")], + params: Annotated[Optional[Dict[str, Any]], Field(description="The parameters for the query")] = None, + readonly: Annotated[bool, Field(description="Whether the query is read-only")] = True + ) -> Dict[str, Any]: + """ + Execute a safe SQL query with parameters + """ + logger.info(f"Executing query: {query}") + try: + return await self.db_manager.execute_query(query, params, readonly) + except Exception as e: + return {"error": f"Error executing query: {str(e)}"} + + @self.mcp.tool() + async def explain_query( + query: Annotated[str, Field(description="The SQL query to explain")], + params: Annotated[Optional[Dict[str, Any]], Field(description="The parameters for the query")] = None + ) -> str: + """ + Get the execution plan for a query + """ + logger.info(f"Explaining query: {query}") + try: + return await self.db_manager.explain_query(query, params) + except Exception as e: + return f"Error explaining query: {str(e)}" + + @self.mcp.tool() + async def get_table_stats( + schema: Annotated[str, Field(description="The schema name")], + table: Annotated[str, Field(description="The table name")], + ) -> Dict[str, Any]: + """ + Get statistics for a table + """ + logger.info(f"Getting table stats for: {schema}.{table}") + try: + result = await self.db_manager.get_table_stats(schema, table) + if "error" in result: + return result["error"] + return result + except Exception as e: + return f"Error getting table stats: {str(e)}" + + @self.mcp.tool() + async def list_large_tables(limit: Annotated[int, Field(description="Number of tables to return")] = 10) -> List[Dict[str, Any]]: + """ + List the largest tables in the database + """ + logger.info(f"Listing large tables, limit: {limit}") + try: + return await self.db_manager.list_large_tables(limit) + except Exception as e: + return f"Error listing large tables: {str(e)}" + + @self.mcp.tool() + async def get_database_schemas() -> List[str]: + """Get database schemas""" + logger.info("Getting database schemas") + try: + return await self.db_manager.list_schemas() + except Exception as e: + return f"Error getting schemas: {str(e)}" + + @self.mcp.tool() + async def get_database_information() -> Dict[str, str]: + """Get general database information""" + logger.info("Getting database information") + try: + return await self.db_manager.get_database_info() + except Exception as e: + return f"Error getting database info: {str(e)}" + + @self.mcp.tool() + async def get_database_summary() -> Dict[str, dict]: + """Get detailed database summary""" + logger.info("Getting database summary") + try: + return await self.db_manager.get_database_summary() + except Exception as e: + return f"Error getting database summary: {str(e)}" + + @self.mcp.tool() + async def list_users() -> List[str]: + """List all database users""" + logger.info("Listing database users") + try: + return await self.db_manager.list_users() + except Exception as e: + return f"Error listing users: {str(e)}" + + @self.mcp.tool() + async def list_user_permissions( + username: Annotated[str, Field(description="The username to check permissions for")] + ) -> List[Dict[str, Any]]: + """List permissions for a specific user""" + logger.info(f"Listing permissions for user: {username}") + try: + return await self.db_manager.list_user_permissions(username) + except Exception as e: + return f"Error listing user permissions: {str(e)}" + + @self.mcp.tool() + async def list_table_privileges( + schema: Annotated[str, Field(description="The schema name")], + table: Annotated[str, Field(description="The table name")] + ) -> List[Dict[str, Any]]: + """List privileges for a specific table""" + logger.info(f"Listing table privileges for: {schema}.{table}") + try: + return await self.db_manager.list_table_privileges(schema, table) + except Exception as e: + return f"Error listing table privileges: {str(e)}" + + @self.mcp.tool() + async def list_constraints( + schema: Annotated[str, Field(description="The schema name")], + table: Annotated[str, Field(description="The table name")] + ) -> List[Dict[str, Any]]: + """List constraints for a specific table""" + logger.info(f"Listing constraints for table: {schema}.{table}") + try: + return await self.db_manager.list_constraints(schema, table) + except Exception as e: + return f"Error listing constraints: {str(e)}" + + @self.mcp.tool() + async def list_foreign_keys( + schema: Annotated[str, Field(description="The schema name")], + table: Annotated[str, Field(description="The table name")] + ) -> List[Dict[str, Any]]: + """List foreign keys for a specific table""" + logger.info(f"Listing foreign keys for table: {schema}.{table}") + try: + return await self.db_manager.list_foreign_keys(schema, table) + except Exception as e: + return f"Error listing foreign keys: {str(e)}" + + @self.mcp.tool() + async def list_referenced_tables( + schema: Annotated[str, Field(description="The schema name")], + table: Annotated[str, Field(description="The table name")] + ) -> List[Dict[str, Any]]: + """List tables that reference this table""" + logger.info(f"Listing referenced tables for: {schema}.{table}") + try: + return await self.db_manager.list_referenced_tables(schema, table) + except Exception as e: + return f"Error listing referenced tables: {str(e)}" + + @self.mcp.tool() + async def get_slow_queries( + limit: Annotated[int, Field(description="Number of slow queries to return")] = 10 + ) -> List[Dict[str, Any]]: + """Get slow queries from database statistics""" + logger.info(f"Getting slow queries, limit: {limit}") + try: + return await self.db_manager.get_slow_queries(limit) + except Exception as e: + return f"Error getting slow queries: {str(e)}" + + @self.mcp.tool() + async def get_index_usage() -> List[Dict[str, Any]]: + """Get index usage statistics""" + logger.info("Getting index usage statistics") + try: + return await self.db_manager.get_index_usage() + except Exception as e: + return f"Error getting index usage: {str(e)}" + + @self.mcp.tool() + async def get_table_bloat_info() -> List[Dict[str, Any]]: + """Get table bloat information""" + logger.info("Getting table bloat information") + try: + return await self.db_manager.get_table_bloat_info() + except Exception as e: + return f"Error getting table bloat info: {str(e)}" + + @self.mcp.tool() + async def get_database_activity() -> List[Dict[str, Any]]: + """Get current database activity""" + logger.info("Getting database activity") + try: + return await self.db_manager.get_database_activity() + except Exception as e: + return f"Error getting database activity: {str(e)}" + + @self.mcp.tool() + async def list_functions( + schema: Annotated[str, Field(description="The schema name")] + ) -> List[Dict[str, Any]]: + """List functions in a specific schema""" + logger.info(f"Listing functions in schema: {schema}") + try: + return await self.db_manager.list_functions(schema) + except Exception as e: + return f"Error listing functions: {str(e)}" + + @self.mcp.tool() + async def get_function_definition( + schema: Annotated[str, Field(description="The schema name")], + function_name: Annotated[str, Field(description="The function name")] + ) -> str: + """Get function definition""" + logger.info(f"Getting function definition: {schema}.{function_name}") + try: + return await self.db_manager.get_function_definition(schema, function_name) + except Exception as e: + return f"Error getting function definition: {str(e)}" + + @self.mcp.tool() + async def list_triggers( + schema: Annotated[str, Field(description="The schema name")], + table: Annotated[str, Field(description="The table name")] + ) -> List[Dict[str, Any]]: + """List triggers for a specific table""" + logger.info(f"Listing triggers for table: {schema}.{table}") + try: + return await self.db_manager.list_triggers(schema, table) + except Exception as e: + return f"Error listing triggers: {str(e)}" + + @self.mcp.tool() + async def get_table_ddl( + schema: Annotated[str, Field(description="The schema name")], + table: Annotated[str, Field(description="The table name")] + ) -> str: + """Get DDL statement for a table""" + logger.info(f"Getting table DDL: {schema}.{table}") + try: + return await self.db_manager.get_table_ddl(schema, table) + except Exception as e: + return f"Error getting table DDL: {str(e)}" + + @self.mcp.tool() + async def list_materialized_views( + schema: Annotated[str, Field(description="The schema name")] + ) -> List[str]: + """List materialized views in a specific schema""" + logger.info(f"Listing materialized views in schema: {schema}") + try: + return await self.db_manager.list_materialized_views(schema) + except Exception as e: + return f"Error listing materialized views: {str(e)}" + + @self.mcp.tool() + async def get_vacuum_info( + schema: Annotated[str, Field(description="The schema name")], + table: Annotated[str, Field(description="The table name")] + ) -> Dict[str, Any]: + """Get vacuum information for a table""" + logger.info(f"Getting vacuum info for table: {schema}.{table}") + try: + return await self.db_manager.get_vacuum_info(schema, table) + except Exception as e: + return f"Error getting vacuum info: {str(e)}" + + @self.mcp.tool() + async def list_active_connections() -> List[Dict[str, Any]]: + """List active database connections""" + logger.info("Listing active connections") + try: + return await self.db_manager.list_active_connections() + except Exception as e: + return f"Error listing active connections: {str(e)}" + + def _setup_prompts(self): + """Setup MCP prompts for common database tasks""" + + @self.mcp.prompt() + def analyze_query_performance( + sql: Annotated[str, Field(description="The SQL query to analyze")], + explain: Annotated[str, Field(description="The EXPLAIN ANALYZE output")], + table_info: Annotated[str, Field(description="The table schema information")], + ) -> str: + """Prompt for analyzing query performance""" + logger.info(f"Analyzing query performance for: {sql}") + return ANALYZE_QUERY_PERFORMANCE_PROMPT.format( + sql=sql, + explain=explain, + table_info=table_info + ) + @self.mcp.prompt() + def suggest_indexes( + query: Annotated[str, Field(description="The common query pattern")], + table_info: Annotated[str, Field(description="The table schema information")], + table_stats: Annotated[str, Field(description="The table statistics")], + ) -> str: + """Prompt for suggesting indexes""" + logger.info(f"Suggesting indexes for query: {query}") + return SUGGEST_INDEXES_PROMPT.format( + query=query, + table_info=table_info, + table_stats=table_stats + ) + + + @self.mcp.prompt() + def database_health_check() -> str: + """Prompt for database health check""" + logger.info(f"Checking database health") + return DATABASE_HEALTH_CHECK_PROMPT + + def run(self, mode: str="http"): + """Run the MCP server""" + if mode == "stdio": + return self.mcp.run( + transport="stdio", + ) + elif mode == "http": + return self.mcp.run( + transport="streamable-http", + host=self.server_config.host, + port=self.server_config.port, + path=self.server_config.path, + stateless_http=True + ) + + async def close(self): + """Close the server and cleanup resources""" + await self.db_manager.close() + + +def main(): + """Main entry point""" + import argparse + + parser = argparse.ArgumentParser( + description="Cloudberry MCP Server - Cloudberry database management tools", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s --mode stdio + %(prog)s --mode http --host 0.0.0.0 --port 8080 + %(prog)s --mode http --log-level INFO + %(prog)s --help + """ + ) + + parser.add_argument( + "--mode", + choices=["stdio", "http"], + default="http", + help="Server mode: stdio for stdin/stdout communication, http for HTTP server (default: http)" + ) + + parser.add_argument( + "--host", + default=None, + help="HTTP server host (default: from CLOUDBERRY_MCP_HOST env var or 127.0.0.1)" + ) + + parser.add_argument( + "--port", + type=int, + default=None, + help="HTTP server port (default: from CLOUDBERRY_MCP_PORT env var or 8080)" + ) + + parser.add_argument( + "--path", + default=None, + help="HTTP server path (default: from CLOUDBERRY_MCP_PATH env var or /mcp)" + ) + + parser.add_argument( + "--log-level", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + default="WARNING", + help="Logging level (default: WARNING)" + ) + + parser.add_argument( + "--version", + action="version", + version="Cloudberry MCP Server 1.0.0" + ) + + args = parser.parse_args() + + # Configure logging + log_level = getattr(logging, args.log_level.upper()) + logging.basicConfig( + level=log_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] + ) + + # Create configurations + server_config = ServerConfig.from_env() + db_config = DatabaseConfig.from_env() + + # Override with command line arguments + if args.host: + server_config.host = args.host + if args.port: + server_config.port = args.port + if args.path: + server_config.path = args.path + + server = CloudberryMCPServer(server_config, db_config) + + try: + logger.info(f"Starting server in {args.mode} mode...") + server.run(args.mode) + except KeyboardInterrupt: + logger.error("Server stopped by user") + except Exception as e: + logger.error(f"Server error: {e}") + sys.exit(1) + finally: + import asyncio + asyncio.run(server.close()) + + +if __name__ == "__main__": + import sys + + main() diff --git a/mcp-server/tests/README.md b/mcp-server/tests/README.md new file mode 100644 index 00000000000..9fa69976b26 --- /dev/null +++ b/mcp-server/tests/README.md @@ -0,0 +1,115 @@ + + +# Apache Cloudberry MCP Testing Guide + +## Test Structure + +This project uses the `pytest` framework for testing, supporting both asynchronous testing and parameterized testing. + +### Test Files +- `test_cbmcp.py` - Main test file containing all MCP client functionality tests + +### Test Categories +- **Unit Tests** - Test individual features independently +- **Integration Tests** - Test overall system functionality +- **Parameterized Tests** - Test both stdio and http modes simultaneously + +## Running Tests + +### Install Test Dependencies +```bash +pip install -e ".[dev]" +``` + +### Run All Tests +```bash +pytest tests/ +``` + +### Run Specific Tests +```bash +# Run specific test file +pytest tests/test_cbmcp.py + +# Run specific test class +pytest tests/test_cbmcp.py::TestCloudberryMCPClient + +# Run specific test method +pytest tests/test_cbmcp.py::TestCloudberryMCPClient::test_list_capabilities + +# Run tests for specific mode +pytest tests/test_cbmcp.py -k "stdio" +``` + +### Verbose Output +```bash +pytest tests/ -v +``` + +### Coverage Testing +```bash +pytest tests/ --cov=src.cbmcp --cov-report=html --cov-report=term +``` + +## Test Features + +### 1. Server Capabilities Tests +- `test_list_capabilities` - Test tool, resource, and prompt listings + +### 2. Resource Tests +- `test_get_schemas_resource` - Get database schemas +- `test_get_tables_resource` - Get table listings +- `test_get_database_info_resource` - Get database information +- `test_get_database_summary_resource` - Get database summary + +### 3. Tool Tests +- `test_tools` - Parameterized testing of all tool calls + - list_tables + - list_views + - list_columns + - list_indexes + - execute_query + - list_large_tables + - get_table_stats + - explain_query + +### 4. Prompt Tests +- `test_analyze_query_performance_prompt` - Query performance analysis prompts +- `test_suggest_indexes_prompt` - Index suggestion prompts +- `test_database_health_check_prompt` - Database health check prompts + +## Test Modes + +Tests support two modes: +- **stdio** - Standard input/output mode +- **http** - HTTP mode + +## Notes + +1. Tests will skip inaccessible features (e.g., when database is not connected) +2. Ensure Apache Cloudberry service is started and configured correctly +3. Check database connection configuration in .env file + +## Using Scripts to Run + +You can use the provided script to run tests: +```bash +./run_tests.sh +``` \ No newline at end of file diff --git a/mcp-server/tests/test_cbmcp.py b/mcp-server/tests/test_cbmcp.py new file mode 100644 index 00000000000..5560c1d84e6 --- /dev/null +++ b/mcp-server/tests/test_cbmcp.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import pytest_asyncio +import asyncio +import json +from typing import Any +from pydantic import AnyUrl + +from cbmcp.client import CloudberryMCPClient + + +class CustomJSONEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, AnyUrl): + return str(obj) + return super().default(obj) + + +@pytest.fixture +def event_loop(): + """Create event loop for async testing""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest_asyncio.fixture(params=["stdio", "http"]) +async def client(request): + """Create CloudberryMCPClient instance supporting stdio and http modes""" + client_instance = await CloudberryMCPClient.create(mode=request.param) + yield client_instance + await client_instance.close() + + +@pytest.mark.asyncio +class TestCloudberryMCPClient: + """Apache Cloudberry MCP client test class""" + + async def test_list_capabilities(self, client): + """Test server capabilities list""" + tools = await client.list_tools() + resources = await client.list_resources() + prompts = await client.list_prompts() + + assert tools is not None + assert resources is not None + assert prompts is not None + + assert isinstance(tools, list) + assert isinstance(resources, list) + assert isinstance(prompts, list) + + async def test_get_schemas_resource(self, client): + """Test getting database schemas resource""" + try: + schemas = await client.get_resource("postgres://schemas") + assert schemas is not None + assert isinstance(schemas, list) + except Exception as e: + pytest.skip(f"Skipping test - unable to get schemas: {e}") + + async def test_get_database_info_resource(self, client): + """Test getting database info resource""" + try: + db_infos = await client.get_resource("postgres://database/info") + assert db_infos is not None + assert isinstance(db_infos, list) + except Exception as e: + pytest.skip(f"Skipping test - unable to get database info: {e}") + + async def test_get_database_summary_resource(self, client): + """Test getting database summary resource""" + try: + db_summary = await client.get_resource("postgres://database/summary") + assert db_summary is not None + assert isinstance(db_summary, list) + except Exception as e: + pytest.skip(f"Skipping test - unable to get database summary: {e}") + + @pytest.mark.parametrize("tool_name,parameters", [ + ("list_tables", {"schema": "public"}), + ("list_views", {"schema": "public"}), + ("list_columns", {"schema": "public", "table": "test"}), + ("list_indexes", {"schema": "public", "table": "test"}), + ("execute_query", {"query": "SELECT version()", "readonly": True}), + ("list_large_tables", {"limit": 5}), + ("get_table_stats", {"schema": "public", "table": "film"}), + ("explain_query", {"query": "SELECT version()"}), + ]) + async def test_tools(self, client, tool_name, parameters): + """Test various tool calls""" + try: + result = await client.call_tool(tool_name, parameters) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call tool {tool_name}: {e}") + + async def test_analyze_query_performance_prompt(self, client): + """Test query performance analysis prompt""" + try: + prompt = await client.get_prompt( + "analyze_query_performance", + params={ + "sql": "SELECT * FROM public.test", + "explain": "public.test", + "table_info": "100 rows, 10 MB" + } + ) + assert prompt is not None + assert prompt.description is not None + assert isinstance(prompt.messages, list) + except Exception as e: + pytest.skip(f"Skipping test - unable to get analyze_query_performance prompt: {e}") + + async def test_suggest_indexes_prompt(self, client): + """Test index suggestion prompt""" + try: + prompt = await client.get_prompt( + "suggest_indexes", + params={ + "query": "public", + "table_info": "public.test", + "table_stats": "100 rows, 10 MB" + } + ) + assert prompt is not None + assert prompt.description is not None + assert isinstance(prompt.messages, list) + except Exception as e: + pytest.skip(f"Skipping test - unable to get suggest_indexes prompt: {e}") + + async def test_database_health_check_prompt(self, client): + """Test database health check prompt""" + try: + prompt = await client.get_prompt("database_health_check") + assert prompt is not None + assert prompt.description is not None + assert isinstance(prompt.messages, list) + except Exception as e: + pytest.skip(f"Skipping test - unable to get database_health_check prompt: {e}") + + # User and permission management tests + async def test_list_users(self, client): + """Test listing all users""" + try: + result = await client.call_tool("list_users", {}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call list_users tool: {e}") + + async def test_list_user_permissions(self, client): + """Test listing user permissions""" + try: + result = await client.call_tool("list_user_permissions", {"username": "postgres"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call list_user_permissions tool: {e}") + + async def test_list_table_privileges(self, client): + """Test listing table privileges""" + try: + result = await client.call_tool("list_table_privileges", {"schema": "public", "table": "film"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call list_table_privileges tool: {e}") + + # Constraint and relationship management tests + async def test_list_constraints(self, client): + """Test listing constraints""" + try: + result = await client.call_tool("list_constraints", {"schema": "public", "table": "film"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call list_constraints tool: {e}") + + async def test_list_foreign_keys(self, client): + """Test listing foreign keys""" + try: + result = await client.call_tool("list_foreign_keys", {"schema": "public", "table": "film"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call list_foreign_keys tool: {e}") + + async def test_list_referenced_tables(self, client): + """Test listing referenced tables""" + try: + result = await client.call_tool("list_referenced_tables", {"schema": "public", "table": "film"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call list_referenced_tables tool: {e}") + + # Performance monitoring and optimization tests + async def test_get_slow_queries(self, client): + """Test getting slow queries""" + try: + result = await client.call_tool("get_slow_queries", {"limit": 5}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call get_slow_queries tool: {e}") + + async def test_get_index_usage(self, client): + """Test getting index usage""" + try: + result = await client.call_tool("get_index_usage", {"schema": "public"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call get_index_usage tool: {e}") + + async def test_get_table_bloat_info(self, client): + """Test getting table bloat information""" + try: + result = await client.call_tool("get_table_bloat_info", {"schema": "public", "limit": 5}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call get_table_bloat_info tool: {e}") + + async def test_get_database_activity(self, client): + """Test getting database activity""" + try: + result = await client.call_tool("get_database_activity", {}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call get_database_activity tool: {e}") + + async def test_get_vacuum_info(self, client): + """Test getting vacuum information""" + try: + result = await client.call_tool("get_vacuum_info", {"schema": "public", "table": "film"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call get_vacuum_info tool: {e}") + + # Database object management tests + async def test_list_functions(self, client): + """Test listing functions""" + try: + result = await client.call_tool("list_functions", {"schema": "public"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call list_functions tool: {e}") + + async def test_get_function_definition(self, client): + """Test getting function definition""" + try: + result = await client.call_tool("get_function_definition", {"schema": "public", "function_name": "now"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call get_function_definition tool: {e}") + + async def test_list_triggers(self, client): + """Test listing triggers""" + try: + result = await client.call_tool("list_triggers", {"schema": "public", "table": "film"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call list_triggers tool: {e}") + + async def test_list_materialized_views(self, client): + """Test listing materialized views""" + try: + result = await client.call_tool("list_materialized_views", {"schema": "public"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call list_materialized_views tool: {e}") + + async def test_get_materialized_view_definition(self, client): + """Test getting materialized view definition""" + try: + result = await client.call_tool("get_materialized_view_definition", {"schema": "public", "view_name": "film"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call get_materialized_view_definition tool: {e}") + + async def test_get_table_ddl(self, client): + """Test getting table DDL""" + try: + result = await client.call_tool("get_table_ddl", {"schema": "public", "table": "film"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call get_table_ddl tool: {e}") + + async def test_list_active_connections(self, client): + """Test listing active connections""" + try: + result = await client.call_tool("list_active_connections", {}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping test - unable to call list_active_connections tool: {e}") + + +@pytest.mark.asyncio +async def test_client_modes(): + """Test basic client functionality in different modes""" + for mode in ["stdio", "http"]: + client = await CloudberryMCPClient.create(mode=mode) + try: + # Basic connection test + tools = await client.list_tools() + assert isinstance(tools, list) + finally: + await client.close() diff --git a/mcp-server/tests/test_database_tools.py b/mcp-server/tests/test_database_tools.py new file mode 100644 index 00000000000..e0fcc644869 --- /dev/null +++ b/mcp-server/tests/test_database_tools.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Database tools test module +Tests newly added database management tool functionality +""" +import pytest +import pytest_asyncio +import asyncio +import json +from typing import Any +from pydantic import AnyUrl + +from cbmcp.client import CloudberryMCPClient + + +class CustomJSONEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, AnyUrl): + return str(obj) + return super().default(obj) + + +@pytest.fixture +def event_loop(): + """Create event loop for async testing""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest_asyncio.fixture(params=["stdio", "http"]) +async def client(request): + """Create CloudberryMCPClient instance supporting stdio and http modes""" + client_instance = await CloudberryMCPClient.create(mode=request.param) + yield client_instance + await client_instance.close() + + +@pytest.mark.asyncio +class TestDatabaseTools: + """Database management tools test class""" + + # User and permission management tests + async def test_list_users_basic(self, client): + """Test basic user listing functionality""" + try: + result = await client.call_tool("list_users", {}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping user list test: {e}") + + async def test_list_user_permissions(self, client): + """Test user permissions query""" + try: + result = await client.call_tool("list_user_permissions", {"username": "postgres"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping user permissions test: {e}") + + async def test_list_table_privileges(self, client): + """Test table privileges query""" + try: + result = await client.call_tool("list_table_privileges", { + "schema": "public", + "table": "film" + }) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping table privileges test: {e}") + + # Constraint and relationship management tests + async def test_list_constraints(self, client): + """Test constraints query""" + try: + result = await client.call_tool("list_constraints", { + "schema": "public", + "table": "film" + }) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping constraints query test: {e}") + + async def test_list_foreign_keys(self, client): + """Test foreign keys query""" + try: + result = await client.call_tool("list_foreign_keys", { + "schema": "public", + "table": "film" + }) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping foreign keys query test: {e}") + + async def test_list_referenced_tables(self, client): + """Test referenced tables query""" + try: + result = await client.call_tool("list_referenced_tables", { + "schema": "public", + "table": "film" + }) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping referenced tables query test: {e}") + + # Performance monitoring and optimization tests + async def test_get_slow_queries(self, client): + """Test slow queries retrieval""" + result = await client.call_tool("get_slow_queries", {"limit": 5}) + assert result is not None + assert hasattr(result, 'structured_content') + + async def test_get_index_usage(self, client): + """Test index usage""" + try: + result = await client.call_tool("get_index_usage", {"schema": "public"}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping index usage test: {e}") + + async def test_get_table_bloat_info(self, client): + """Test table bloat information""" + try: + result = await client.call_tool("get_table_bloat_info", { + "schema": "public", + "limit": 5 + }) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping table bloat info test: {e}") + + async def test_get_database_activity(self, client): + """Test database activity monitoring""" + try: + result = await client.call_tool("get_database_activity", {}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping database activity test: {e}") + + async def test_get_vacuum_info(self, client): + """Test vacuum information""" + try: + result = await client.call_tool("get_vacuum_info", { + "schema": "public", + "table": "film" + }) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping vacuum info test: {e}") + + # Database object management tests + async def test_list_functions(self, client): + """Test functions list""" + result = await client.call_tool("list_functions", {"schema": "public"}) + assert result is not None + assert hasattr(result, 'structured_content') + + async def test_get_function_definition(self, client): + """Test function definition retrieval""" + try: + result = await client.call_tool("get_function_definition", { + "schema": "public", + "function_name": "now" + }) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping function definition test: {e}") + + async def test_list_triggers(self, client): + """Test triggers list""" + try: + result = await client.call_tool("list_triggers", { + "schema": "public", + "table": "film" + }) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping triggers list test: {e}") + + async def test_list_materialized_views(self, client): + """Test materialized views list""" + result = await client.call_tool("list_materialized_views", {"schema": "public"}) + assert result is not None + assert hasattr(result, 'structured_content') + + async def test_get_materialized_view_ddl(self, client): + """Test materialized view DDL""" + try: + result = await client.call_tool("get_materialized_view_ddl", { + "schema": "public", + "view_name": "film" + }) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping materialized view DDL test: {e}") + + async def test_get_table_ddl(self, client): + """Test table DDL retrieval""" + try: + result = await client.call_tool("get_table_ddl", { + "schema": "public", + "table": "film" + }) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping table DDL test: {e}") + + async def test_list_active_connections(self, client): + """Test active connections list""" + try: + result = await client.call_tool("list_active_connections", {}) + assert result is not None + assert hasattr(result, 'structured_content') + except Exception as e: + pytest.skip(f"Skipping active connections test: {e}") + + async def test_all_tools_availability(self, client): + """Test availability of all newly added tools""" + tools = await client.list_tools() + tool_names = [tool.name for tool in tools] + + new_tools = [ + "list_users", "list_user_permissions", "list_table_privileges", + "list_constraints", "list_foreign_keys", "list_referenced_tables", + "get_slow_queries", "get_index_usage", "get_table_bloat_info", + "get_database_activity", "get_vacuum_info", "list_functions", + "get_function_definition", "list_triggers", "list_materialized_views", + "get_materialized_view_ddl", "get_table_ddl", "list_active_connections" + ] + + available_tools = [tool for tool in new_tools if tool in tool_names] + print(f"Found {len(available_tools)} new tools: {available_tools}") + + async def test_tool_parameter_validation(self, client): + """Test tool parameter validation""" + try: + await client.call_tool("get_table_ddl", {"schema": "public"}) + except Exception: + pass + + try: + result = await client.call_tool("get_table_ddl", { + "schema": "public", + "table": "film" + }) + assert result is not None + except Exception as e: + pytest.skip(f"Skipping DDL test: {e}") + + +@pytest.mark.asyncio +async def test_database_tools_comprehensive(): + """Comprehensive test of all database tools""" + client = await CloudberryMCPClient.create() + try: + tools = await client.list_tools() + assert isinstance(tools, list) + + try: + users = await client.call_tool("list_users", {}) + assert users is not None + + activity = await client.call_tool("get_database_activity", {}) + assert activity is not None + + connections = await client.call_tool("list_active_connections", {}) + assert connections is not None + + except Exception as e: + pytest.skip(f"Skipping comprehensive test: {e}") + finally: + await client.close()