diff --git a/.gitignore b/.gitignore
index 5460bc4..f1dd1e5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,5 @@ venv.bak/
+.coverage
+coverage.xml
diff --git a/doris_mcp_server/tools/tools_manager.py b/doris_mcp_server/tools/tools_manager.py
index cb5560c..57ec820 100644
--- a/doris_mcp_server/tools/tools_manager.py
+++ b/doris_mcp_server/tools/tools_manager.py
@@ -478,7 +478,100 @@ async def get_historical_memory_stats_tool(
                 "time_range": time_range
             })
 
-        logger.info("Successfully registered 16 tools to MCP server")
+        # Get table partition info tool
+        @mcp.tool(
+            "get_table_partition_info",
+            description="""[Function Description]: Get partition information for the specified table.
+
+[Parameter Content]:
+
+- table_name (string) [Required] - Name of the table to query
+
+- db_name (string) [Optional] - Target database name, defaults to the current database
+
+- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
+""",
+        )
+        async def get_table_partition_info_tool(
+            table_name: str, db_name: str = None, catalog_name: str = None
+        ) -> str:
+            """Get table partition information"""
+            return await self.call_tool("get_table_partition_info", {
+                "table_name": table_name,
+                "db_name": db_name,
+                "catalog_name": catalog_name
+            })
+
+        # Table sample data tool
+        @mcp.tool(
+            "table_sample_data",
+            description="""[Function Description]: Sample data from specified table with various sampling methods.
+
+[Parameter Content]:
+
+- table_name (string) [Required] - Name of the table to sample
+- db_name (string) [Optional] - Target database name, defaults to current database
+- catalog_name (string) [Optional] - Target catalog name for federation queries
+- sample_method (string) [Optional] - Sampling method (RANDOM), default RANDOM
+- sample_size (number) [Required] - Sample size or ratio (e.g. 100)
+- columns (string) [Optional] - Columns to return, comma separated
+- where_condition (string) [Optional] - Filter condition before sampling
+- cache_ttl (integer) [Optional] - Result cache TTL in seconds, default 300
+""",
+        )
+        async def table_sample_data_tool(
+            table_name: str,
+            sample_size: float,
+            db_name: str = None,
+            catalog_name: str = None,
+            sample_method: str = "RANDOM",
+            columns: str = None,
+            where_condition: str = None,
+            cache_ttl: int = 300
+        ) -> str:
+            """Table data sampling tool"""
+            return await self.call_tool("table_sample_data", {
+                "table_name": table_name,
+                "db_name": db_name,
+                "catalog_name": catalog_name,
+                "sample_method": sample_method,
+                "sample_size": sample_size,
+                "columns": columns,
+                "where_condition": where_condition,
+                "cache_ttl": cache_ttl
+            })
+
+        # Data lineage analysis tool
+        @mcp.tool(
+            "analyze_data_lineage",
+            description="""[Function Description]: Analyze data lineage relationships between tables.
+
+[Parameter Content]:
+
+- table_name (string) [Optional] - Target table name (if not provided analyzes all tables)
+- db_name (string) [Optional] - Target database name, defaults to current database
+- catalog_name (string) [Optional] - Target catalog name for federation queries
+- depth (integer) [Optional] - Analysis depth (default 1)
+- direction (string) [Optional] - Analysis direction - "upstream", "downstream" or "both" (default)
+""",
+        )
+        async def analyze_data_lineage_tool(
+            table_name: str = None,
+            db_name: str = None,
+            catalog_name: str = None,
+            depth: int = 1,
+            direction: str = "both"
+        ) -> str:
+            """Data lineage analysis tool"""
+            return await self.call_tool("analyze_data_lineage", {
+                "table_name": table_name,
+                "db_name": db_name,
+                "catalog_name": catalog_name,
+                "depth": depth,
+                "direction": direction
+            })
+
+        logger.info("Successfully registered 19 tools to MCP server")
 
     async def list_tools(self) -> List[Tool]:
         """List all available query tools (for stdio mode)"""
@@ -848,6 +941,81 @@ async def list_tools(self) -> List[Tool]:
                     },
                 },
             ),
+            Tool(
+                name="get_table_partition_info",
+                description="""[Function Description]: Get partition information for the specified table.
+
+[Parameter Content]:
+
+- table_name (string) [Required] - Name of the table to query
+
+- db_name (string) [Optional] - Target database name, defaults to the current database
+
+- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
+""",
+                inputSchema={
+                    "type": "object",
+                    "properties": {
+                        "table_name": {"type": "string", "description": "Table name"},
+                        "db_name": {"type": "string", "description": "Database name"},
+                        "catalog_name": {"type": "string", "description": "Catalog name"},
+                    },
+                    "required": ["table_name"],
+                },
+            ),
+            Tool(
+                name="table_sample_data",
+                description="""[Function Description]: Sample data from specified table with various sampling methods.
+
+[Parameter Content]:
+
+- table_name (string) [Required] - Name of the table to sample
+- db_name (string) [Optional] - Target database name, defaults to current database
+- catalog_name (string) [Optional] - Target catalog name for federation queries
+- sample_method (string) [Optional] - Sampling method (RANDOM), default RANDOM
+- sample_size (number) [Required] - Sample size or ratio (e.g. 100)
+- columns (string) [Optional] - Columns to return, comma separated
+- where_condition (string) [Optional] - Filter condition before sampling
+- cache_ttl (integer) [Optional] - Result cache TTL in seconds, default 300
+""",
+                inputSchema={
+                    "type": "object",
+                    "properties": {
+                        "table_name": {"type": "string", "description": "Table name to sample"},
+                        "db_name": {"type": "string", "description": "Database name"},
+                        "catalog_name": {"type": "string", "description": "Catalog name"},
+                        "sample_method": {"type": "string", "enum": ["RANDOM"], "default": "RANDOM"},
+                        "sample_size": {"type": "number", "description": "Sample size or ratio"},
+                        "columns": {"type": "string", "description": "Columns to return, comma separated"},
+                        "where_condition": {"type": "string", "description": "Filter condition before sampling"},
+                        "cache_ttl": {"type": "integer", "description": "Result cache TTL in seconds", "default": 300}
+                    },
+                    "required": ["table_name", "sample_size"],
+                },
+            ),
+            Tool(
+                name="analyze_data_lineage",
+                description="""[Function Description]: Analyze data lineage relationships between tables.
+
+[Parameter Content]:
+
+- table_name (string) [Optional] - Target table name (if not provided analyzes all tables)
+- db_name (string) [Optional] - Target database name, defaults to current database
+- catalog_name (string) [Optional] - Target catalog name for federation queries
+- depth (integer) [Optional] - Analysis depth (default 1)
+- direction (string) [Optional] - Analysis direction - "upstream", "downstream" or "both" (default)
+""",
+                inputSchema={
+                    "type": "object",
+                    "properties": {
+                        "table_name": {"type": "string", "description": "Table name to analyze"},
+                        "db_name": {"type": "string", "description": "Database name"},
+                        "catalog_name": {"type": "string", "description": "Catalog name"},
+                        "depth": {"type": "integer", "description": "Analysis depth", "default": 1},
+                        "direction": {"type": "string", "enum": ["upstream", "downstream", "both"], "description": "Analysis direction", "default": "both"}
+                    }
+                },
+            ),
         ]
 
         return tools
@@ -892,6 +1060,12 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
                 result = await self._get_realtime_memory_stats_tool(arguments)
             elif name == "get_historical_memory_stats":
                 result = await self._get_historical_memory_stats_tool(arguments)
+            elif name == "get_table_partition_info":
+                result = await self._get_table_partition_info_tool(arguments)
+            elif name == "table_sample_data":
+                result = await self._table_sample_data_tool(arguments)
+            elif name == "analyze_data_lineage":
+                result = await self._analyze_data_lineage_tool(arguments)
             else:
                 raise ValueError(f"Unknown tool: {name}")
 
@@ -905,8 +1079,9 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
                 "timestamp": datetime.now().isoformat(),
             }
 
-            return json.dumps(result, ensure_ascii=False, indent=2)
-
+            # Serialize datetime objects before JSON conversion
+            serialized_result = self._serialize_datetime_objects(result)
+            return json.dumps(serialized_result, ensure_ascii=False, indent=2)
         except Exception as e:
             logger.error(f"Tool call failed {name}: {str(e)}")
             error_result = {
@@ -916,6 +1091,21 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
                 "timestamp": datetime.now().isoformat(),
             }
             return json.dumps(error_result, ensure_ascii=False, indent=2)
+
+    def _serialize_datetime_objects(self, data):
+        """Serialize datetime objects and Decimal numbers to JSON compatible format"""
+        if isinstance(data, list):
+            return [self._serialize_datetime_objects(item) for item in data]
+        elif isinstance(data, dict):
+            return {key: self._serialize_datetime_objects(value) for key, value in data.items()}
+        elif hasattr(data, 'isoformat'):  # datetime, date, time objects
+            return data.isoformat()
+        elif hasattr(data, 'strftime'):  # pandas Timestamp objects
+            return data.strftime('%Y-%m-%d %H:%M:%S')
+        elif hasattr(data, 'as_tuple'):  # Decimal objects
+            return float(data)
+        else:
+            return data
 
     async def _exec_query_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
@@ -1082,4 +1272,61 @@ async def _get_historical_memory_stats_tool(self, arguments: Dict[str, Any]) ->
         # Delegate to memory tracker for processing
         return await self.memory_tracker.get_historical_memory_stats(
             tracker_names, time_range
-        )
\ No newline at end of file
+        )
+
+    async def _get_table_partition_info_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
+        """Get table partition information tool routing"""
+        table_name = arguments.get("table_name")
+        db_name = arguments.get("db_name")
+        catalog_name = arguments.get("catalog_name")
+
+        # Delegate to metadata extractor for processing
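+        # catalog_name is accepted for interface symmetry but not forwarded yet;
+        # the extractor's partition lookup below is db-scoped only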
+        result = await self.metadata_extractor.get_table_partition_info_for_mcp(
+            db_name, table_name
+        )
+
+        return result
+
+    async def _analyze_data_lineage_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
+        """Data lineage analysis tool routing"""
+        table_name = arguments.get("table_name")
+        db_name = arguments.get("db_name")
+        catalog_name = arguments.get("catalog_name")
+        depth = arguments.get("depth", 1)
+        direction = arguments.get("direction", "both")
+
+        # Delegate to metadata extractor for processing
+        return await self.metadata_extractor.analyze_data_lineage(
+            table_name=table_name,
+            db_name=db_name,
+            catalog_name=catalog_name,
+            depth=depth,
+            direction=direction
+        )
+
+    async def _table_sample_data_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
+        """Table sample data tool routing"""
+        table_name = arguments.get("table_name")
+        db_name = arguments.get("db_name")
+        catalog_name = arguments.get("catalog_name")
+        # Default matches the registered tool schema ("RANDOM"); any other
+        # method is rejected by get_table_sample_data_for_mcp
+        sample_method = arguments.get("sample_method", "RANDOM")
+        sample_size = arguments.get("sample_size")
+        columns = arguments.get("columns")
+        where_condition = arguments.get("where_condition")
+        cache_ttl = arguments.get("cache_ttl", 300)
+
+        # Delegate to metadata extractor for processing
+        return await self.metadata_extractor.get_table_sample_data_for_mcp(
+            table_name=table_name,
+            db_name=db_name,
+            catalog_name=catalog_name,
+            sample_method=sample_method,
+            sample_size=sample_size,
+            columns=columns,
+            where_condition=where_condition,
+            cache_ttl=cache_ttl
+        )
diff --git a/doris_mcp_server/utils/schema_extractor.py b/doris_mcp_server/utils/schema_extractor.py
index fd711c8..f17f289 100644
--- a/doris_mcp_server/utils/schema_extractor.py
+++ b/doris_mcp_server/utils/schema_extractor.py
@@ -1075,9 +1075,9 @@ def _extract_tables_from_sql(self, sql: str) -> List[str]:
 
-    def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, Any]:
+    async def get_table_partition_info_async(self, db_name: str, table_name: str) -> Dict[str, Any]:
         """
-        Get partition information for a table
+        Get partition information for a table (async version) using SHOW PARTITIONS syntax
 
         Args:
             db_name: Database name
@@ -1087,21 +1087,10 @@ def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, A
             Dict: Partition information
         """
         try:
-            # Get partition information
-            query = f"""
-            SELECT
-                PARTITION_NAME,
-                PARTITION_EXPRESSION,
-                PARTITION_DESCRIPTION,
-                TABLE_ROWS
-            FROM
-                information_schema.partitions
-            WHERE
-                TABLE_SCHEMA = '{db_name}'
-                AND TABLE_NAME = '{table_name}'
-            """
+            # Get partition information using SHOW PARTITIONS syntax
+            query = f"SHOW PARTITIONS FROM `{db_name}`.`{table_name}`"
 
-            partitions = self._execute_query(query)
+            partitions = await self._execute_query_async(query)
             if not partitions:
                 return {}
@@ -1113,10 +1102,10 @@ def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, A
 
             for part in partitions:
                 partition_info["partitions"].append({
-                    "name": part.get("PARTITION_NAME", ""),
-                    "expression": part.get("PARTITION_EXPRESSION", ""),
-                    "description": part.get("PARTITION_DESCRIPTION", ""),
-                    "rows": part.get("TABLE_ROWS", 0)
+                    "name": part.get("PartitionName", ""),
+                    "expression": part.get("PartitionExpr", ""),
+                    "description": part.get("PartitionDesc", ""),
+                    "rows": part.get("PartitionRows", 0)
                 })
 
             return partition_info
@@ -1124,6 +1113,24 @@ def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, A
             logger.error(f"Error getting partition information for table {db_name}.{table_name}: {str(e)}")
             return {}
 
+    def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, Any]:
+        """
+        Get partition information for a table (sync version)
+        """
+        import asyncio
+        import concurrent.futures
+        try:
+            return asyncio.run(self.get_table_partition_info_async(db_name, table_name))
+        except RuntimeError:
+            # asyncio.run() fails inside an already running event loop, and
+            # run_until_complete() on that loop would fail the same way, so
+            # run the coroutine in a fresh loop on a worker thread instead
+            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+                return pool.submit(
+                    asyncio.run,
+                    self.get_table_partition_info_async(db_name, table_name),
+                ).result()
+
     def _execute_query_with_catalog(self, query: str, db_name: str = None, catalog_name: str = None):
         """
         Execute query with catalog-aware metadata operations using three-part naming
@@ -1551,6 +1558,33 @@ async def get_table_indexes_for_mcp(
             logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True)
             return self._format_response(success=False, error=str(e), message="Error occurred while getting table indexes")
 
+    async def get_table_partition_info_for_mcp(
+        self,
+        database_name: str = None,
+        table_name: str = None,
+        db_name: str = None  # For backward compatibility
+    ) -> Dict[str, Any]:
+        """Get partition information for specified table - MCP interface"""
+        effective_db = database_name or db_name or self.db_name
+        logger.info(f"Getting table partition info: Table: {table_name}, DB: {effective_db}")
+
+        if not table_name:
+            return self._format_response(success=False, error="Missing table_name parameter")
+
+        if not effective_db:
+            return self._format_response(
+                success=False,
+                error="Database name not specified",
+                message="Please specify database name or set default database"
+            )
+
+        try:
+            partition_info = await self.get_table_partition_info_async(db_name=effective_db, table_name=table_name)
+            return self._format_response(success=True, result=partition_info)
+        except Exception as e:
+            logger.error(f"Failed to get table partition info: {str(e)}", exc_info=True)
+            return self._format_response(success=False, error=str(e), message="Error occurred while getting table partition info")
+
     def _serialize_datetime_objects(self, data):
         """Serialize datetime objects to JSON compatible format"""
         if isinstance(data, list):
@@ -1603,6 +1637,256 @@ async def get_catalog_list_for_mcp(self) -> Dict[str, Any]:
             logger.error(f"Failed to get catalog list: {str(e)}", exc_info=True)
             return self._format_response(success=False, error=str(e), message="Error occurred while getting catalog list")
 
+    async def analyze_data_lineage(
+        self,
+        table_name: str = None,
+        db_name: str = None,
+        catalog_name: str = None,
+        depth: int = 1,
+        direction: str = "both"
+    ) -> Dict[str, Any]:
+        """
+        Analyze data lineage relationships for specified table
+
+        Args:
+            table_name: Target table name (optional, if not provided analyzes all tables)
+            db_name: Database name (optional)
+            catalog_name: Catalog name (optional)
+            depth: Analysis depth (default 1)
+            direction: Analysis direction - "upstream", "downstream" or "both" (default)
+
+        Returns:
+            Dict containing lineage relationships
+        """
+        try:
+            effective_db = db_name or self.db_name
+            effective_catalog = catalog_name or self.catalog_name
+
+            # Get all tables if no specific table provided
+            if not table_name:
+                tables = await self.get_database_tables_async(db_name=effective_db, catalog_name=effective_catalog)
+                result = {}
+                for tbl in tables:
+                    lineage = await self._analyze_table_lineage(tbl, effective_db, effective_catalog, depth, direction)
+                    if lineage:
+                        result[tbl] = lineage
+                return self._format_response(success=True, result=result)
+            else:
+                lineage = await self._analyze_table_lineage(table_name, effective_db, effective_catalog, depth, direction)
+                return self._format_response(success=True, result=lineage)
+
+        except Exception as e:
+            logger.error(f"Failed to analyze data lineage: {str(e)}", exc_info=True)
+            return self._format_response(success=False, error=str(e))
+
+    async def _analyze_table_lineage(
+        self,
+        table_name: str,
+        db_name: str,
+        catalog_name: str,
+        depth: int,
+        direction: str
+    ) -> Dict[str, Any]:
+        """
+        Analyze lineage for a single table
+        """
+        lineage = {
+            "table": table_name,
+            "database": db_name,
+            "catalog": catalog_name,
+            "upstream": [],
+            "downstream": []
+        }
+
+        # Foreign key relationships describe upstream sources, so only collect
+        # them when upstream lineage was requested
+        if direction in ["both", "upstream"]:
+            fk_relations = await self._get_foreign_key_relations(table_name, db_name, catalog_name)
+            lineage["upstream"].extend(fk_relations)
+
+        # Get SQL dependencies from audit logs
+        sql_deps = await self._get_sql_dependencies(table_name, db_name, catalog_name, depth)
+        if direction in ["both", "upstream"]:
+            lineage["upstream"].extend(sql_deps.get("upstream", []))
+        if direction in ["both", "downstream"]:
+            lineage["downstream"].extend(sql_deps.get("downstream", []))
+
+        return lineage
+
+    async def _get_foreign_key_relations(
+        self,
+        table_name: str,
+        db_name: str,
+        catalog_name: str
+    ) -> List[Dict[str, Any]]:
+        """
+        Get foreign key relationships for a table
+        """
+        try:
+            # Get table schema to find potential foreign keys
+            schema = await self.get_table_schema_async(table_name, db_name, catalog_name)
+            if not schema:
+                return []
+
+            relations = []
+
+            # Check for columns ending with _id (common foreign key pattern)
+            for col in schema:
+                col_name = col.get("column_name", "")
+                if col_name.endswith("_id"):
+                    ref_table = col_name[:-3]  # Remove _id suffix
+
+                    # Check if referenced table exists
+                    ref_schema = await self.get_table_schema_async(ref_table, db_name, catalog_name)
+                    if ref_schema:
+                        relations.append({
+                            "type": "foreign_key",
+                            "source_table": ref_table,
+                            "source_column": "id",
+                            "target_table": table_name,
+                            "target_column": col_name,
+                            "confidence": "medium"
+                        })
+
+            return relations
+
+        except Exception as e:
+            logger.error(f"Error getting foreign key relations: {str(e)}")
+            return []
+
+    async def _get_sql_dependencies(
+        self,
+        table_name: str,
+        db_name: str,
+        catalog_name: str,
+        depth: int
+    ) -> Dict[str, List[Dict[str, Any]]]:
+        """
+        Get SQL dependencies from audit logs
+        """
+        try:
+            # Get recent SQL queries involving this table
+            logs = await self.get_recent_audit_logs_for_mcp(days=7, limit=100)
+            if not logs.get("success"):
+                return {"upstream": [], "downstream": []}
+
+            dependencies = {"upstream": [], "downstream": []}
+
+            for log in logs.get("result", []):
+                sql = log.get("stmt", "")
+                if not sql or table_name.lower() not in sql.lower():
+                    continue
+
+                # Parse SQL to find dependencies
+                tables = self._extract_tables_from_sql(sql)
+                if table_name in tables:
+                    # This table is referenced in the SQL
+                    if "insert" in sql.lower() or "update" in sql.lower():
+                        # Upstream dependencies (tables this table depends on)
+                        for tbl in tables:
+                            if tbl != table_name:
+                                dependencies["upstream"].append({
+                                    "type": "sql_dependency",
+                                    "source_table": tbl,
+                                    "target_table": table_name,
+                                    "sql": sql[:200] + "..." if len(sql) > 200 else sql,
+                                    "confidence": "low"
+                                })
+                    elif "select" in sql.lower():
+                        # Downstream dependencies (tables depending on this table)
+                        for tbl in tables:
+                            if tbl != table_name:
+                                dependencies["downstream"].append({
+                                    "type": "sql_dependency",
+                                    "source_table": table_name,
+                                    "target_table": tbl,
+                                    "sql": sql[:200] + "..." if len(sql) > 200 else sql,
+                                    "confidence": "low"
+                                })
+
+            return dependencies
+
+        except Exception as e:
+            logger.error(f"Error getting SQL dependencies: {str(e)}")
+            return {"upstream": [], "downstream": []}
+
+    async def get_table_sample_data_for_mcp(
+        self,
+        table_name: str,
+        db_name: str = None,
+        catalog_name: str = None,
+        sample_method: str = "RANDOM",
+        sample_size: float = None,
+        columns: str = None,
+        where_condition: str = None,
+        cache_ttl: int = 300
+    ) -> Dict[str, Any]:
+        """Get sample data from specified table - MCP interface"""
+        logger.info(f"Getting table sample data: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}, "
+                    f"Method: {sample_method}, Size: {sample_size}, Columns: {columns}, Where: {where_condition}")
+
+        if not table_name:
+            return self._format_response(success=False, error="Missing table_name parameter")
+
+        if not sample_size:
+            return self._format_response(success=False, error="Missing sample_size parameter")
+
+        try:
+            # Build base query
+            effective_db = db_name or self.db_name
+            effective_catalog = catalog_name or self.catalog_name
+
+            # Handle columns selection
+            columns_clause = columns if columns else "*"
+
+            # Handle where condition
+            where_clause = f"WHERE {where_condition}" if where_condition else ""
+
+            # Build sampling clause based on method
+            if sample_method.upper() == "RANDOM":
+                sample_clause = f"ORDER BY RAND() LIMIT {int(sample_size)}"
+            else:
+                return self._format_response(
+                    success=False,
+                    error=f"Invalid sample method: {sample_method}",
+                    message="Supported methods: RANDOM"
+                )
+
+            # Build full query (the WHERE filter must precede ORDER BY/LIMIT)
+            if effective_catalog and effective_catalog != "internal":
+                query = f"SELECT {columns_clause} FROM `{effective_catalog}`.`{effective_db}`.`{table_name}` {where_clause} {sample_clause}"
+            else:
+                query = f"SELECT {columns_clause} FROM `{effective_db}`.`{table_name}` {where_clause} {sample_clause}"
+
+            logger.info(f"Executing sample query: {query}")
+
+            # Execute query with caching
+            cache_key = f"sample_data_{effective_catalog or 'default'}_{effective_db}_{table_name}_{sample_method}_{sample_size}_{columns or 'all'}_{where_condition or 'none'}"
+
+            if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < cache_ttl:
+                logger.info("Returning cached sample data")
+                return self._format_response(success=True, result=self.metadata_cache[cache_key])
+
+            # Execute query
+            result = await self._execute_query_async(query, effective_db)
+
+            if not result:
+                return self._format_response(
+                    success=False,
+                    error="No data returned",
+                    message=f"No sample data found for table {effective_catalog or 'default'}.{effective_db}.{table_name}"
+                )
+
+            # Update cache
+            self.metadata_cache[cache_key] = result
+            self.metadata_cache_time[cache_key] = datetime.now()
+
+            return self._format_response(success=True, result=result)
+
+        except Exception as e:
+            logger.error(f"Failed to get table sample data: {str(e)}", exc_info=True)
+            return self._format_response(success=False, error=str(e), message="Error occurred while getting table sample data")
+
     # ==================== Compatibility aliases ====================
@@ -1650,4 +1934,4 @@ async def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> Dict[s
 
     async def get_catalog_list(self) -> Dict[str, Any]:
         """Get Doris catalog list"""
-        return await self.extractor.get_catalog_list_for_mcp()
\ No newline at end of file
+        return await self.extractor.get_catalog_list_for_mcp()
diff --git a/test/tools/test_tools_manager.py b/test/tools/test_tools_manager.py
index dad4e2f..9d8cf5a 100644
--- a/test/tools/test_tools_manager.py
+++ b/test/tools/test_tools_manager.py
@@ -230,42 +230,311 @@ async def test_get_catalog_list_tool(self, tools_manager):
         elif "result" in result_data:
             assert len(result_data["result"]) >= 0  # May be empty if no catalogs
 
+    @pytest.mark.asyncio
+    async def test_get_table_partition_info_with_database_name(self, tools_manager):
+        """Test get_table_partition_info with database_name parameter"""
+        with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute:
+            mock_execute.return_value = {
+                "success": True,
+                "result": {
+                    "partitions": [{"PartitionName": "p1"}],
+                    "partition_type": "RANGE"
+                }
+            }
+
+            arguments = {
+                "table_name": "sales",
+                "database_name": "retail"
+            }
+            result = await tools_manager.call_tool("get_table_partition_info", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert result_data["success"]
+            assert "partitions" in result_data["result"]
+            assert len(result_data["result"]["partitions"]) == 1
 
     @pytest.mark.asyncio
+    async def test_get_table_partition_info_with_db_name(self, tools_manager):
+        """Test get_table_partition_info with db_name parameter (backward compatibility)"""
+        with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute:
+            mock_execute.return_value = {
+                "success": True,
+                "result": {
+                    "partitions": [{"PartitionName": "p1"}],
+                    "partition_type": "RANGE"
+                }
+            }
+
+            arguments = {
+                "table_name": "sales",
+                "db_name": "retail"
+            }
+            result = await tools_manager.call_tool("get_table_partition_info", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert result_data["success"]
+            assert "partitions" in result_data["result"]
+            assert len(result_data["result"]["partitions"]) == 1
 
     @pytest.mark.asyncio
-    async def test_invalid_tool_name(self, tools_manager):
-        """Test calling invalid tool"""
-        result = await tools_manager.call_tool("invalid_tool", {})
-        result_data = json.loads(result) if isinstance(result, str) else result
-
-        assert "error" in result_data or "success" in result_data
-        if "error" in result_data:
-            assert "Unknown tool" in result_data["error"]
+    async def test_get_table_partition_info_with_default_db(self, tools_manager):
+        """Test get_table_partition_info with default database"""
+        with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute:
+            mock_execute.return_value = {
+                "success": True,
+                "result": {
+                    "partitions": [{"PartitionName": "p1"}],
+                    "partition_type": "RANGE"
+                }
+            }
+
+            arguments = {
+                "table_name": "sales"
+            }
+            result = await tools_manager.call_tool("get_table_partition_info", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert result_data["success"]
+            assert "partitions" in result_data["result"]
+            assert len(result_data["result"]["partitions"]) == 1
 
     @pytest.mark.asyncio
-    async def test_missing_required_arguments(self, tools_manager):
-        """Test calling tool with missing required arguments"""
-        # exec_query requires sql parameter
-        result = await tools_manager.call_tool("exec_query", {})
-        result_data = json.loads(result) if isinstance(result, str) else result
-
-        assert "error" in result_data or "success" in result_data
-        # The test may pass if the tool handles missing parameters gracefully
+    async def test_get_table_partition_info_error(self, tools_manager):
+        """Test get_table_partition_info with error"""
+        with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute:
+            mock_execute.side_effect = Exception("Table not found")
+
+            arguments = {"table_name": "nonexistent_table"}
+            result = await tools_manager.call_tool("get_table_partition_info", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert not result_data["success"]
+            assert "error" in result_data
+            assert "not found" in result_data["error"].lower()
 
     @pytest.mark.asyncio
-    async def test_tool_definitions_structure(self, tools_manager):
-        """Test tool definitions have correct structure"""
-        tools = await tools_manager.list_tools()
-
-        for tool in tools:
-            # Each tool should have required fields
-            assert hasattr(tool, 'name')
-            assert hasattr(tool, 'description')
-            assert hasattr(tool, 'inputSchema')
-
-            # Input schema should have properties
-            assert 'properties' in tool.inputSchema
-
-            # Required fields should be defined
-            if 'required' in tool.inputSchema:
-                assert isinstance(tool.inputSchema['required'], list)
\ No newline at end of file
+    async def test_table_sample_data_system(self, tools_manager):
+        """Test table_sample_data with SYSTEM sampling"""
+        with patch.object(tools_manager.metadata_extractor, 'get_table_sample_data_for_mcp') as mock_execute:
+            mock_execute.return_value = {
+                "success": True,
+                "result": [
+                    {"id": 1, "name": "Sample 1"},
+                    {"id": 2, "name": "Sample 2"}
+                ]
+            }
+
+            arguments = {
+                "table_name": "users",
+                "sample_method": "SYSTEM",
+                "sample_size": 10
+            }
+            result = await tools_manager.call_tool("table_sample_data", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert result_data["success"]
+            assert len(result_data["result"]) == 2
+
+    @pytest.mark.asyncio
+    async def test_table_sample_data_bernoulli(self, tools_manager):
+        """Test table_sample_data with BERNOULLI sampling"""
+        with patch.object(tools_manager.metadata_extractor, 'get_table_sample_data_for_mcp') as mock_execute:
+            mock_execute.return_value = {
+                "success": True,
+                "result": [
+                    {"id": 3, "name": "Sample 3"}
+                ]
+            }
+
+            arguments = {
+                "table_name": "users",
+                "sample_method": "BERNOULLI",
+                "sample_size": 5
+            }
+            result = await tools_manager.call_tool("table_sample_data", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert result_data["success"]
+            assert len(result_data["result"]) == 1
+
+    @pytest.mark.asyncio
+    async def test_table_sample_data_random(self, tools_manager):
+        """Test table_sample_data with RANDOM sampling"""
+        with patch.object(tools_manager.metadata_extractor, 'get_table_sample_data_for_mcp') as mock_execute:
+            mock_execute.return_value = {
+                "success": True,
+                "result": [
+                    {"id": 4, "name": "Sample 4"},
+                    {"id": 5, "name": "Sample 5"},
+                    {"id": 6, "name": "Sample 6"}
+                ]
+            }
+
+            arguments = {
+                "table_name": "users",
+                "sample_method": "RANDOM",
+                "sample_size": 3
+            }
+            result = await tools_manager.call_tool("table_sample_data", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert result_data["success"]
+            assert len(result_data["result"]) == 3
+
+    @pytest.mark.asyncio
+    async def test_table_sample_data_with_columns(self, tools_manager):
+        """Test table_sample_data with column selection"""
+        with patch.object(tools_manager.metadata_extractor, 'get_table_sample_data_for_mcp') as mock_execute:
+            mock_execute.return_value = {
+                "success": True,
+                "result": [
+                    {"id": 1},
+                    {"id": 2}
+                ]
+            }
+
+            arguments = {
+                "table_name": "users",
+                "sample_method": "SYSTEM",
+                "sample_size": 10,
+                "columns": "id"
+            }
+            result = await tools_manager.call_tool("table_sample_data", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert result_data["success"]
+            assert len(result_data["result"]) == 2
+            assert "name" not in result_data["result"][0]
+
+    @pytest.mark.asyncio
+    async def test_analyze_data_lineage_basic(self, tools_manager):
+        """Test basic data lineage analysis"""
+        with patch.object(tools_manager.metadata_extractor, 'analyze_data_lineage') as mock_execute:
+            mock_execute.return_value = {
+                "success": True,
+                "result": {
+                    "table": "orders",
+                    "database": "test_db",
+                    "upstream": [
+                        {
+                            "type": "foreign_key",
+                            "source_table": "customers",
+                            "source_column": "id",
+                            "target_table": "orders",
+                            "target_column": "customer_id",
+                            "confidence": "medium"
+                        }
+                    ],
+                    "downstream": []
+                }
+            }
+
+            arguments = {
+                "table_name": "orders"
+            }
+            result = await tools_manager.call_tool("analyze_data_lineage", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert result_data["success"]
+            assert result_data["result"]["table"] == "orders"
+            assert len(result_data["result"]["upstream"]) == 1
+            assert result_data["result"]["upstream"][0]["source_table"] == "customers"
+
+    @pytest.mark.asyncio
+    async def test_analyze_data_lineage_with_params(self, tools_manager):
+        """Test data lineage analysis with parameters"""
+        with patch.object(tools_manager.metadata_extractor, 'analyze_data_lineage') as mock_execute:
+            mock_execute.return_value = {
+                "success": True,
+                "result": {
+                    "table": "orders",
+                    "database": "test_db",
+                    "upstream": [],
+                    "downstream": [
+                        {
+                            "type": "sql_dependency",
+                            "source_table": "orders",
+                            "target_table": "order_items",
+                            "sql": "SELECT * FROM order_items WHERE order_id IN (SELECT id FROM orders)",
+                            "confidence": "low"
+                        }
+                    ]
+                }
+            }
+
+            arguments = {
+                "table_name": "orders",
+                "depth": 2,
+                "direction": "downstream"
+            }
+            result = await tools_manager.call_tool("analyze_data_lineage", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert result_data["success"]
+            assert len(result_data["result"]["downstream"]) == 1
+            assert result_data["result"]["downstream"][0]["target_table"] == "order_items"
+
+    @pytest.mark.asyncio
+    async def test_analyze_data_lineage_error(self, tools_manager):
+        """Test data lineage analysis with error"""
+        with patch.object(tools_manager.metadata_extractor, 'analyze_data_lineage') as mock_execute:
+            mock_execute.side_effect = Exception("Table not found")
+
+            arguments = {
+                "table_name": "nonexistent_table"
+            }
+            result = await tools_manager.call_tool("analyze_data_lineage", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert not result_data["success"]
+            assert "error" in result_data
+            assert "not found" in result_data["error"].lower()
+
+    @pytest.mark.asyncio
+    async def test_analyze_data_lineage_all_tables(self, tools_manager):
+        """Test data lineage analysis for all tables"""
+        with patch.object(tools_manager.metadata_extractor, 'analyze_data_lineage') as mock_execute:
+            mock_execute.return_value = {
+                "success": True,
+                "result": {
+                    "customers": {
+                        "upstream": [],
+                        "downstream": [
+                            {
+                                "type": "foreign_key",
+                                "source_table": "customers",
+                                "target_table": "orders",
+                                "source_column": "id",
+                                "target_column": "customer_id",
+                                "confidence": "medium"
+                            }
+                        ]
+                    },
+                    "orders": {
+                        "upstream": [
+                            {
+                                "type": "foreign_key",
+                                "source_table": "customers",
+                                "target_table": "orders",
+                                "source_column": "id",
+                                "target_column": "customer_id",
+                                "confidence": "medium"
+                            }
+                        ],
+                        "downstream": []
+                    }
+                }
+            }
+
+            arguments = {
+                "depth": 1,
+                "direction": "both"
+            }
+            result = await tools_manager.call_tool("analyze_data_lineage", arguments)
+            result_data = json.loads(result) if isinstance(result, str) else result
+
+            assert result_data["success"]
+            assert "customers" in result_data["result"]
+            assert "orders" in result_data["result"]
+            assert len(result_data["result"]["customers"]["downstream"]) == 1
+            assert len(result_data["result"]["orders"]["upstream"]) == 1
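
Usage sketch (reviewer note, not part of the patch): a minimal example of exercising the three new tools through the same call_tool path the tests use, assuming an already configured DorisToolsManager instance (the `manager` name and its setup are hypothetical). call_tool returns a JSON string, so results are decoded with json.loads; tool names and argument keys match the registered schemas above.

    import asyncio
    import json

    async def demo(manager):
        # Partition metadata for retail.sales
        parts = json.loads(await manager.call_tool(
            "get_table_partition_info",
            {"table_name": "sales", "db_name": "retail"},
        ))

        # 100 random rows from two columns, filtered before sampling
        sample = json.loads(await manager.call_tool(
            "table_sample_data",
            {"table_name": "sales", "sample_size": 100,
             "columns": "id,amount", "where_condition": "amount > 0"},
        ))

        # One level of upstream lineage only
        lineage = json.loads(await manager.call_tool(
            "analyze_data_lineage",
            {"table_name": "sales", "direction": "upstream"},
        ))
        return parts, sample, lineage

    # asyncio.run(demo(manager))  # requires a live connection and default db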