Skip to content

add get_table_partition_info mcp tool #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions doris_mcp_server/tools/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,31 @@ async def get_historical_memory_stats_tool(
"time_range": time_range
})

logger.info("Successfully registered 16 tools to MCP server")
# Get table partition info tool
@mcp.tool(
"get_table_partition_info",
description="""[Function Description]: Get partition information for the specified table.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
)
async def get_table_partition_info_tool(
table_name: str, db_name: str = None, catalog_name: str = None
) -> str:
"""Get table partition information"""
return await self.call_tool("get_table_partition_info", {
"table_name": table_name,
"db_name": db_name,
"catalog_name": catalog_name
})

logger.info("Successfully registered 17 tools to MCP server")

async def list_tools(self) -> List[Tool]:
"""List all available query tools (for stdio mode)"""
Expand Down Expand Up @@ -848,6 +872,28 @@ async def list_tools(self) -> List[Tool]:
},
},
),
Tool(
name="get_table_partition_info",
description="""[Function Description]: Get partition information for the specified table.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
inputSchema={
"type": "object",
"properties": {
"table_name": {"type": "string", "description": "Table name"},
"db_name": {"type": "string", "description": "Database name"},
"catalog_name": {"type": "string", "description": "Catalog name"},
},
"required": ["table_name"],
},
),
]

return tools
Expand Down Expand Up @@ -892,6 +938,8 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
result = await self._get_realtime_memory_stats_tool(arguments)
elif name == "get_historical_memory_stats":
result = await self._get_historical_memory_stats_tool(arguments)
elif name == "get_table_partition_info":
result = await self._get_table_partition_info_tool(arguments)
else:
raise ValueError(f"Unknown tool: {name}")

Expand Down Expand Up @@ -1082,4 +1130,17 @@ async def _get_historical_memory_stats_tool(self, arguments: Dict[str, Any]) ->
# Delegate to memory tracker for processing
return await self.memory_tracker.get_historical_memory_stats(
tracker_names, time_range
)
)

async def _get_table_partition_info_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get table partition information tool routing"""
table_name = arguments.get("table_name")
db_name = arguments.get("db_name")
catalog_name = arguments.get("catalog_name")

# Delegate to metadata extractor for processing
result = await self.metadata_extractor.get_table_partition_info_for_mcp(
db_name, table_name
)

return result
70 changes: 49 additions & 21 deletions doris_mcp_server/utils/schema_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,9 +1075,9 @@ def _extract_tables_from_sql(self, sql: str) -> List[str]:



def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, Any]:
async def get_table_partition_info_async(self, db_name: str, table_name: str) -> Dict[str, Any]:
"""
Get partition information for a table
Get partition information for a table (async version) using SHOW PARTITION syntax

Args:
db_name: Database name
Expand All @@ -1087,21 +1087,10 @@ def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, A
Dict: Partition information
"""
try:
# Get partition information
query = f"""
SELECT
PARTITION_NAME,
PARTITION_EXPRESSION,
PARTITION_DESCRIPTION,
TABLE_ROWS
FROM
information_schema.partitions
WHERE
TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{table_name}'
"""
# Get partition information using SHOW PARTITION syntax
query = f"SHOW PARTITIONS FROM `{db_name}`.`{table_name}`"

partitions = self._execute_query(query)
partitions = await self._execute_query_async(query)

if not partitions:
return {}
Expand All @@ -1113,17 +1102,29 @@ def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, A

for part in partitions:
partition_info["partitions"].append({
"name": part.get("PARTITION_NAME", ""),
"expression": part.get("PARTITION_EXPRESSION", ""),
"description": part.get("PARTITION_DESCRIPTION", ""),
"rows": part.get("TABLE_ROWS", 0)
"name": part.get("PartitionName", ""),
"expression": part.get("PartitionExpr", ""),
"description": part.get("PartitionDesc", ""),
"rows": part.get("PartitionRows", 0)
})

return partition_info
except Exception as e:
logger.error(f"Error getting partition information for table {db_name}.{table_name}: {str(e)}")
return {}

def get_table_partition_info(self, db_name: str, table_name: str) -> Dict[str, Any]:
"""
Get partition information for a table (sync version)
"""
import asyncio
try:
return asyncio.run(self.get_table_partition_info_async(db_name, table_name))
except RuntimeError:
# If there's already a running event loop
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.get_table_partition_info_async(db_name, table_name))

def _execute_query_with_catalog(self, query: str, db_name: str = None, catalog_name: str = None):
"""
Execute query with catalog-aware metadata operations using three-part naming
Expand Down Expand Up @@ -1551,6 +1552,33 @@ async def get_table_indexes_for_mcp(
logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while getting table indexes")

async def get_table_partition_info_for_mcp(
self,
database_name: str = None,
table_name: str = None,
db_name: str = None # For backward compatibility
) -> Dict[str, Any]:
"""Get partition information for specified table - MCP interface"""
effective_db = database_name or db_name or self.db_name
logger.info(f"Getting table partition info: Table: {table_name}, DB: {effective_db}")

if not table_name:
return self._format_response(success=False, error="Missing table_name parameter")

if not effective_db:
return self._format_response(
success=False,
error="Database name not specified",
message="Please specify database name or set default database"
)

try:
partition_info = await self.get_table_partition_info_async(db_name=effective_db, table_name=table_name)
return self._format_response(success=True, result=partition_info)
except Exception as e:
logger.error(f"Failed to get table partition info: {str(e)}", exc_info=True)
return self._format_response(success=False, error=str(e), message="Error occurred while getting table partition info")

def _serialize_datetime_objects(self, data):
"""Serialize datetime objects to JSON compatible format"""
if isinstance(data, list):
Expand Down Expand Up @@ -1650,4 +1678,4 @@ async def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> Dict[s

async def get_catalog_list(self) -> Dict[str, Any]:
"""Get Doris catalog list"""
return await self.extractor.get_catalog_list_for_mcp()
return await self.extractor.get_catalog_list_for_mcp()
104 changes: 73 additions & 31 deletions test/tools/test_tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,42 +230,84 @@ async def test_get_catalog_list_tool(self, tools_manager):
elif "result" in result_data:
assert len(result_data["result"]) >= 0 # May be empty if no catalogs


@pytest.mark.asyncio
async def test_get_table_partition_info_with_database_name(self, tools_manager):
"""Test get_table_partition_info with database_name parameter"""
with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute:
mock_execute.return_value = {
"success": True,
"result": {
"partitions": [{"PartitionName": "p1"}],
"partition_type": "RANGE"
}
}

arguments = {
"table_name": "sales",
"database_name": "retail"
}
result = await tools_manager.call_tool("get_table_partition_info", arguments)
result_data = json.loads(result) if isinstance(result, str) else result

assert result_data["success"]
assert "partitions" in result_data["result"]
assert len(result_data["result"]["partitions"]) == 1

@pytest.mark.asyncio
async def test_invalid_tool_name(self, tools_manager):
"""Test calling invalid tool"""
result = await tools_manager.call_tool("invalid_tool", {})
result_data = json.loads(result) if isinstance(result, str) else result

assert "error" in result_data or "success" in result_data
if "error" in result_data:
assert "Unknown tool" in result_data["error"]
async def test_get_table_partition_info_with_db_name(self, tools_manager):
"""Test get_table_partition_info with db_name parameter (backward compatibility)"""
with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute:
mock_execute.return_value = {
"success": True,
"result": {
"partitions": [{"PartitionName": "p1"}],
"partition_type": "RANGE"
}
}

arguments = {
"table_name": "sales",
"db_name": "retail"
}
result = await tools_manager.call_tool("get_table_partition_info", arguments)
result_data = json.loads(result) if isinstance(result, str) else result

assert result_data["success"]
assert "partitions" in result_data["result"]
assert len(result_data["result"]["partitions"]) == 1

@pytest.mark.asyncio
async def test_missing_required_arguments(self, tools_manager):
"""Test calling tool with missing required arguments"""
# exec_query requires sql parameter
result = await tools_manager.call_tool("exec_query", {})
result_data = json.loads(result) if isinstance(result, str) else result

assert "error" in result_data or "success" in result_data
# The test may pass if the tool handles missing parameters gracefully
async def test_get_table_partition_info_with_default_db(self, tools_manager):
"""Test get_table_partition_info with default database"""
with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute:
mock_execute.return_value = {
"success": True,
"result": {
"partitions": [{"PartitionName": "p1"}],
"partition_type": "RANGE"
}
}

arguments = {
"table_name": "sales"
}
result = await tools_manager.call_tool("get_table_partition_info", arguments)
result_data = json.loads(result) if isinstance(result, str) else result

assert result_data["success"]
assert "partitions" in result_data["result"]
assert len(result_data["result"]["partitions"]) == 1

@pytest.mark.asyncio
async def test_tool_definitions_structure(self, tools_manager):
"""Test tool definitions have correct structure"""
tools = await tools_manager.list_tools()

for tool in tools:
# Each tool should have required fields
assert hasattr(tool, 'name')
assert hasattr(tool, 'description')
assert hasattr(tool, 'inputSchema')
async def test_get_table_partition_info_error(self, tools_manager):
"""Test get_table_partition_info with error"""
with patch.object(tools_manager.metadata_extractor, 'get_table_partition_info_for_mcp') as mock_execute:
mock_execute.side_effect = Exception("Table not found")

# Input schema should have properties
assert 'properties' in tool.inputSchema
arguments = {"table_name": "nonexistent_table"}
result = await tools_manager.call_tool("get_table_partition_info", arguments)
result_data = json.loads(result) if isinstance(result, str) else result

# Required fields should be defined
if 'required' in tool.inputSchema:
assert isinstance(tool.inputSchema['required'], list)
assert not result_data["success"]
assert "error" in result_data
assert "not found" in result_data["error"].lower()