diff --git a/doris_mcp_server/tools/tools_manager.py b/doris_mcp_server/tools/tools_manager.py index cb5560c..68a13f8 100644 --- a/doris_mcp_server/tools/tools_manager.py +++ b/doris_mcp_server/tools/tools_manager.py @@ -478,7 +478,31 @@ async def get_historical_memory_stats_tool( "time_range": time_range }) - logger.info("Successfully registered 16 tools to MCP server") + # Get table partition info tool + @mcp.tool( + "get_table_partition_info", + description="""[Function Description]: Get partition information for the specified table. + +[Parameter Content]: + +- table_name (string) [Required] - Name of the table to query + +- db_name (string) [Optional] - Target database name, defaults to the current database + +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog +""", + ) + async def get_table_partition_info_tool( + table_name: str, db_name: str = None, catalog_name: str = None + ) -> str: + """Get table partition information""" + return await self.call_tool("get_table_partition_info", { + "table_name": table_name, + "db_name": db_name, + "catalog_name": catalog_name + }) + + logger.info("Successfully registered 17 tools to MCP server") async def list_tools(self) -> List[Tool]: """List all available query tools (for stdio mode)""" @@ -848,6 +872,28 @@ async def list_tools(self) -> List[Tool]: }, }, ), + Tool( + name="get_table_partition_info", + description="""[Function Description]: Get partition information for the specified table. + +[Parameter Content]: + +- table_name (string) [Required] - Name of the table to query + +- db_name (string) [Optional] - Target database name, defaults to the current database + +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog +""", + inputSchema={ + "type": "object", + "properties": { + "table_name": {"type": "string", "description": "Table name"}, + "db_name": {"type": "string", "description": "Database name"}, + "catalog_name": {"type": "string", "description": "Catalog name"}, + }, + "required": ["table_name"], + }, + ), ] return tools @@ -892,6 +938,8 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str: result = await self._get_realtime_memory_stats_tool(arguments) elif name == "get_historical_memory_stats": result = await self._get_historical_memory_stats_tool(arguments) + elif name == "get_table_partition_info": + result = await self._get_table_partition_info_tool(arguments) else: raise ValueError(f"Unknown tool: {name}") @@ -1082,4 +1130,17 @@ async def _get_historical_memory_stats_tool(self, arguments: Dict[str, Any]) -> # Delegate to memory tracker for processing return await self.memory_tracker.get_historical_memory_stats( tracker_names, time_range - ) \ No newline at end of file + ) + + async def _get_table_partition_info_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Get table partition information tool routing""" + table_name = arguments.get("table_name") + db_name = arguments.get("db_name") + catalog_name = arguments.get("catalog_name") + + # Delegate to metadata extractor for processing + result = await self.metadata_extractor.get_table_partition_info_for_mcp( + db_name, table_name + ) + + return result diff --git a/doris_mcp_server/utils/schema_extractor.py b/doris_mcp_server/utils/schema_extractor.py index fd711c8..06da41c 100644 --- a/doris_mcp_server/utils/schema_extractor.py +++ b/doris_mcp_server/utils/schema_extractor.py @@ -1075,9 +1075,9 @@ def _extract_tables_from_sql(self, sql: str) -> List[str]: - def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, Any]: + async def get_table_partition_info_async(self, db_name: str, table_name: str) -> Dict[str, Any]: """ - Get partition information for a table + Get partition information for a table (async version) using SHOW PARTITION syntax Args: db_name: Database name @@ -1087,21 +1087,10 @@ def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, A Dict: Partition information """ try: - # Get partition information - query = f""" - SELECT - PARTITION_NAME, - PARTITION_EXPRESSION, - PARTITION_DESCRIPTION, - TABLE_ROWS - FROM - information_schema.partitions - WHERE - TABLE_SCHEMA = '{db_name}' - AND TABLE_NAME = '{table_name}' - """ + # Get partition information using SHOW PARTITION syntax + query = f"SHOW PARTITIONS FROM `{db_name}`.`{table_name}`" - partitions = self._execute_query(query) + partitions = await self._execute_query_async(query) if not partitions: return {} @@ -1113,10 +1102,10 @@ def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, A for part in partitions: partition_info["partitions"].append({ - "name": part.get("PARTITION_NAME", ""), - "expression": part.get("PARTITION_EXPRESSION", ""), - "description": part.get("PARTITION_DESCRIPTION", ""), - "rows": part.get("TABLE_ROWS", 0) + "name": part.get("PartitionName", ""), + "expression": part.get("PartitionExpr", ""), + "description": part.get("PartitionDesc", ""), + "rows": part.get("PartitionRows", 0) }) return partition_info @@ -1124,6 +1113,18 @@ def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, A logger.error(f"Error getting partition information for table {db_name}.{table_name}: {str(e)}") return {} + def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, Any]: + """ + Get partition information for a table (sync version) + """ + import asyncio + try: + return asyncio.run(self.get_table_partition_info_async(db_name, table_name)) + except RuntimeError: + # If there's already a running event loop + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.get_table_partition_info_async(db_name, table_name)) + def _execute_query_with_catalog(self, query: str, db_name: str = None, catalog_name: str = None): """ Execute query with catalog-aware metadata operations using three-part naming @@ -1551,6 +1552,33 @@ async def get_table_indexes_for_mcp( logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True) return self._format_response(success=False, error=str(e), message="Error occurred while getting table indexes") + async def get_table_partition_info_for_mcp( + self, + database_name: str = None, + table_name: str = None, + db_name: str = None # For backward compatibility + ) -> Dict[str, Any]: + """Get partition information for specified table - MCP interface""" + effective_db = database_name or db_name or self.db_name + logger.info(f"Getting table partition info: Table: {table_name}, DB: {effective_db}") + + if not table_name: + return self._format_response(success=False, error="Missing table_name parameter") + + if not effective_db: + return self._format_response( + success=False, + error="Database name not specified", + message="Please specify database name or set default database" + ) + + try: + partition_info = await self.get_table_partition_info_async(db_name=effective_db, table_name=table_name) + return self._format_response(success=True, result=partition_info) + except Exception as e: + logger.error(f"Failed to get table partition info: {str(e)}", exc_info=True) + return self._format_response(success=False, error=str(e), message="Error occurred while getting table partition info") + def _serialize_datetime_objects(self, data): """Serialize datetime objects to JSON compatible format""" if isinstance(data, list): @@ -1650,4 +1678,4 @@ async def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> Dict[s async def get_catalog_list(self) -> Dict[str, Any]: """Get Doris catalog list""" - return await self.extractor.get_catalog_list_for_mcp() \ No newline at end of file + return await self.extractor.get_catalog_list_for_mcp() diff --git a/test/tools/test_tools_manager.py b/test/tools/test_tools_manager.py index dad4e2f..ce46032 100644 --- a/test/tools/test_tools_manager.py +++ b/test/tools/test_tools_manager.py @@ -230,42 +230,84 @@ async def test_get_catalog_list_tool(self, tools_manager): elif "result" in result_data: assert len(result_data["result"]) >= 0 # May be empty if no catalogs - + @pytest.mark.asyncio + async def test_get_table_partition_info_with_database_name(self, tools_manager): + """Test get_table_partition_info with database_name parameter""" + with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute: + mock_execute.return_value = { + "success": True, + "result": { + "partitions": [{"PartitionName": "p1"}], + "partition_type": "RANGE" + } + } + + arguments = { + "table_name": "sales", + "database_name": "retail" + } + result = await tools_manager.call_tool("get_table_partition_info", arguments) + result_data = json.loads(result) if isinstance(result, str) else result + + assert result_data["success"] + assert "partitions" in result_data["result"] + assert len(result_data["result"]["partitions"]) == 1 @pytest.mark.asyncio - async def test_invalid_tool_name(self, tools_manager): - """Test calling invalid tool""" - result = await tools_manager.call_tool("invalid_tool", {}) - result_data = json.loads(result) if isinstance(result, str) else result - - assert "error" in result_data or "success" in result_data - if "error" in result_data: - assert "Unknown tool" in result_data["error"] + async def test_get_table_partition_info_with_db_name(self, tools_manager): + """Test get_table_partition_info with db_name parameter (backward compatibility)""" + with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute: + mock_execute.return_value = { + "success": True, + "result": { + "partitions": [{"PartitionName": "p1"}], + "partition_type": "RANGE" + } + } + + arguments = { + "table_name": "sales", + "db_name": "retail" + } + result = await tools_manager.call_tool("get_table_partition_info", arguments) + result_data = json.loads(result) if isinstance(result, str) else result + + assert result_data["success"] + assert "partitions" in result_data["result"] + assert len(result_data["result"]["partitions"]) == 1 @pytest.mark.asyncio - async def test_missing_required_arguments(self, tools_manager): - """Test calling tool with missing required arguments""" - # exec_query requires sql parameter - result = await tools_manager.call_tool("exec_query", {}) - result_data = json.loads(result) if isinstance(result, str) else result - - assert "error" in result_data or "success" in result_data - # The test may pass if the tool handles missing parameters gracefully + async def test_get_table_partition_info_with_default_db(self, tools_manager): + """Test get_table_partition_info with default database""" + with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute: + mock_execute.return_value = { + "success": True, + "result": { + "partitions": [{"PartitionName": "p1"}], + "partition_type": "RANGE" + } + } + + arguments = { + "table_name": "sales" + } + result = await tools_manager.call_tool("get_table_partition_info", arguments) + result_data = json.loads(result) if isinstance(result, str) else result + + assert result_data["success"] + assert "partitions" in result_data["result"] + assert len(result_data["result"]["partitions"]) == 1 @pytest.mark.asyncio - async def test_tool_definitions_structure(self, tools_manager): - """Test tool definitions have correct structure""" - tools = await tools_manager.list_tools() - - for tool in tools: - # Each tool should have required fields - assert hasattr(tool, 'name') - assert hasattr(tool, 'description') - assert hasattr(tool, 'inputSchema') + async def test_get_table_partition_info_error(self, tools_manager): + """Test get_table_partition_info with error""" + with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute: + mock_execute.side_effect = Exception("Table not found") - # Input schema should have properties - assert 'properties' in tool.inputSchema + arguments = {"table_name": "nonexistent_table"} + result = await tools_manager.call_tool("get_table_partition_info", arguments) + result_data = json.loads(result) if isinstance(result, str) else result - # Required fields should be defined - if 'required' in tool.inputSchema: - assert isinstance(tool.inputSchema['required'], list) \ No newline at end of file + assert not result_data["success"] + assert "error" in result_data + assert "not found" in result_data["error"].lower()