diff --git a/README.md b/README.md index 3ba73ae..94d84ce 100644 --- a/README.md +++ b/README.md @@ -229,6 +229,7 @@ The following table lists the main tools currently available for invocation via | `get_table_indexes` | Get index information for specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | | `get_recent_audit_logs` | Get audit log records for recent period. | `days` (integer, Optional), `limit` (integer, Optional) | | `get_catalog_list` | Get list of all catalog names. | `random_string` (string, Required) | +| `get_table_summary` | Get table summary information. | `table_name` (string, Required), `db_name` (string, Optional), `include_sample` (boolean, Optional), `sample_size` (integer, Optional) | | `get_sql_explain` | Get SQL execution plan with configurable content truncation and file export for LLM analysis. | `sql` (string, Required), `verbose` (boolean, Optional), `db_name` (string, Optional), `catalog_name` (string, Optional) | | `get_sql_profile` | Get SQL execution profile with content management and file export for LLM optimization workflows. | `sql` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional), `timeout` (integer, Optional) | | `get_table_data_size` | Get table data size information via FE HTTP API. | `db_name` (string, Optional), `table_name` (string, Optional), `single_replica` (boolean, Optional) | diff --git a/doris_mcp_server/tools/tools_manager.py b/doris_mcp_server/tools/tools_manager.py index 9f2aab9..6ecd5a0 100644 --- a/doris_mcp_server/tools/tools_manager.py +++ b/doris_mcp_server/tools/tools_manager.py @@ -280,6 +280,31 @@ async def get_catalog_list_tool(random_string: str) -> str: "random_string": random_string }) + # Get table summary tool + @mcp.tool( + "get_table_summary", + description="""[Function Description]: Get table summary information. 
+ +[Parameter Content]: + +- table_name (string) [Required] - Target table name to analyze + +- db_name (string) [Optional] - Target database name, if not specified, defaults to being the same as the environment variable DB_DATABASE + +- include_sample (boolean) [Optional] - Whether to include sample data, default is true + +- sample_size (integer) [Optional] - sample data size, default value is 10 +""", + ) + async def get_table_summary_tool(table_name: str, db_name: str = None, include_sample: bool = True, sample_size: int = 10) -> str: + """Get table summary; db_name is optional and may be omitted""" + return await self.call_tool("get_table_summary", { + "table_name": table_name, + "db_name": db_name, + "include_sample": include_sample, + "sample_size": sample_size, + }) + # SQL Explain tool @mcp.tool( "get_sql_explain", @@ -958,6 +983,31 @@ async def list_tools(self) -> List[Tool]: "required": ["random_string"], }, ), + Tool( + name="get_table_summary", + description="""[Function Description]: Get table summary information. + +[Parameter Content]: + +- table_name (string) [Required] - Target table name to analyze + +- db_name (string) [Optional] - Target database name, if not specified, defaults to being the same as the environment variable DB_DATABASE + +- include_sample (boolean) [Optional] - Whether to include sample data, default is true + +- sample_size (integer) [Optional] - sample data size, default value is 10 +""", + inputSchema={ + "type": "object", + "properties": { + "table_name": {"type": "string", "description": "Target table name to analyze"}, + "db_name": {"type": "string", "description": "Target database name"}, + "include_sample": {"type": "boolean", "description": "Whether to include sample data", "default": True}, + "sample_size": {"type": "integer", "description": "sample data size", "default": 10}, + }, + "required": ["table_name"], + }, + ), Tool( name="get_sql_explain", description="""[Function Description]: Get SQL execution plan using EXPLAIN command based on Doris 
syntax. @@ -1387,6 +1437,8 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str: result = await self._get_recent_audit_logs_tool(arguments) elif name == "get_catalog_list": result = await self._get_catalog_list_tool(arguments) + elif name == "get_table_summary": + result = await self._get_table_summary_tool(arguments) elif name == "get_sql_explain": result = await self._get_sql_explain_tool(arguments) elif name == "get_sql_profile": @@ -1547,6 +1599,16 @@ async def _get_catalog_list_tool(self, arguments: Dict[str, Any]) -> Dict[str, A # Delegate to metadata extractor for processing return await self.metadata_extractor.get_catalog_list_for_mcp() + + async def _get_table_summary_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Table summary tool routing""" + table_name = arguments.get("table_name") + db_name = arguments.get("db_name") + include_sample = arguments.get("include_sample", True) + sample_size = arguments.get("sample_size", 10) + + # Delegate to table analyzer for processing + return await self.table_analyzer.get_table_summary(table_name, db_name, include_sample, sample_size) async def _get_sql_explain_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]: """SQL Explain tool routing""" @@ -1914,4 +1976,4 @@ async def _exec_adbc_query_tool(self, arguments: Dict[str, Any]) -> Dict[str, An async def _get_adbc_connection_info_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]: """ADBC connection information tool routing""" # Delegate to ADBC query tools for processing - return await self.adbc_query_tools.get_adbc_connection_info() \ No newline at end of file + return await self.adbc_query_tools.get_adbc_connection_info() diff --git a/doris_mcp_server/utils/analysis_tools.py b/doris_mcp_server/utils/analysis_tools.py index 84715e2..dcbea16 100644 --- a/doris_mcp_server/utils/analysis_tools.py +++ b/doris_mcp_server/utils/analysis_tools.py @@ -19,6 +19,7 @@ Provides data analysis functions including table analysis, 
column statistics, performance monitoring, etc. """ +import os import time from datetime import datetime from typing import Any, Dict, List @@ -42,11 +43,13 @@ def __init__(self, connection_manager: DorisConnectionManager): async def get_table_summary( self, table_name: str, + db_name: str = None, include_sample: bool = True, sample_size: int = 10 ) -> Dict[str, Any]: """Get table summary information""" - connection = await self.connection_manager.get_connection("query") + connection = await self.connection_manager.get_connection("system") + database = db_name or os.getenv("DB_DATABASE", "") # Get table basic information table_info_sql = f""" @@ -57,13 +60,13 @@ async def get_table_summary( create_time, engine FROM information_schema.tables - WHERE table_schema = DATABASE() + WHERE table_schema = '{database}' AND table_name = '{table_name}' """ table_info_result = await connection.execute(table_info_sql) if not table_info_result.data: - raise ValueError(f"Table {table_name} does not exist") + raise ValueError(f"Table {database}.{table_name} does not exist") table_info = table_info_result.data[0] @@ -75,7 +78,7 @@ async def get_table_summary( is_nullable, column_comment FROM information_schema.columns - WHERE table_schema = DATABASE() + WHERE table_schema = '{database}' AND table_name = '{table_name}' ORDER BY ordinal_position """ @@ -94,7 +97,7 @@ async def get_table_summary( # Get sample data if include_sample and sample_size > 0: - sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}" + sample_sql = f"SELECT * FROM {database}.{table_name} LIMIT {sample_size}" sample_result = await connection.execute(sample_sql) summary["sample_data"] = sample_result.data @@ -1237,4 +1240,4 @@ async def get_historical_memory_stats( "tracker_names": tracker_names, "time_range": time_range, "timestamp": datetime.now().isoformat() - } \ No newline at end of file + } diff --git a/doris_mcp_server/utils/db.py b/doris_mcp_server/utils/db.py index c71928a..e590928 100644 --- 
a/doris_mcp_server/utils/db.py +++ b/doris_mcp_server/utils/db.py @@ -79,6 +79,14 @@ async def execute(self, sql: str, params: tuple | None = None, auth_context=None """Execute SQL query""" start_time = time.time() + # In some cases, trailing whitespace characters affect the execution result. + # For example: + # sql = "\n SELECT \n table_rows\n FROM information_schema.tables \n WHERE table_schema = '__internal_schema'\n AND table_name = 'column_statistics'\n " + # result1 = await connection.execute(sql) # => result1.data=[] + # result2 = await connection.execute(sql.strip()) # => result2.data=[{'table_rows': 5}] + # To avoid this, strip leading and trailing whitespace before executing any SQL statement. + sql = sql.strip() + try: # If security manager exists, perform SQL security check security_result = None @@ -96,7 +104,7 @@ async def execute(self, sql: str, params: tuple | None = None, auth_context=None await cursor.execute(sql, params) # Check if it's a query statement (statement that returns result set) - sql_upper = sql.strip().upper() + sql_upper = sql.upper() if (sql_upper.startswith("SELECT") or sql_upper.startswith("SHOW") or sql_upper.startswith("DESCRIBE") or @@ -811,4 +819,4 @@ async def generate_health_report(self) -> dict[str, Any]: if pool_status["free_connections"] == 0: report["recommendations"].append("No free connections available, consider increasing pool size") - return report \ No newline at end of file + return report