Skip to content

add table_sample_data tool #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ venv.bak/



.coverage
coverage.xml
251 changes: 247 additions & 4 deletions doris_mcp_server/tools/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,100 @@ async def get_historical_memory_stats_tool(
"time_range": time_range
})

logger.info("Successfully registered 16 tools to MCP server")
# Get table partition info tool
# Register the table-partition inspection tool on the MCP server.
@mcp.tool(
    "get_table_partition_info",
    description="""[Function Description]: Get partition information for the specified table.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
)
async def get_table_partition_info_tool(
    table_name: str, db_name: str = None, catalog_name: str = None
) -> str:
    """Forward a partition-info request through the shared tool dispatcher."""
    request_args = {
        "table_name": table_name,
        "db_name": db_name,
        "catalog_name": catalog_name,
    }
    return await self.call_tool("get_table_partition_info", request_args)

# Table sample data tool
# Register the table data-sampling tool on the MCP server.
@mcp.tool(
    "table_sample_data",
    description="""[Function Description]: Sample data from specified table with various sampling methods.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to sample
- db_name (string) [Optional] - Target database name, defaults to current database
- catalog_name (string) [Optional] - Target catalog name for federation queries
- sample_method (string) [Optional] - Sampling method (RANDOM), default RANDOM
- sample_size (number) [Required] - Sample size or ratio (e.g. 100)
- columns (string) [Optional] - Columns to return, comma separated
- where_condition (string) [Optional] - Filter condition before sampling
- cache_ttl (integer) [Optional] - Result cache TTL in seconds, default 300
""",
)
async def table_sample_data_tool(
    table_name: str,
    sample_size: float,
    db_name: str = None,
    catalog_name: str = None,
    sample_method: str = "RANDOM",
    columns: str = None,
    where_condition: str = None,
    cache_ttl: int = 300
) -> str:
    """Forward a sampling request through the shared tool dispatcher."""
    request_args = {
        "table_name": table_name,
        "db_name": db_name,
        "catalog_name": catalog_name,
        "sample_method": sample_method,
        "sample_size": sample_size,
        "columns": columns,
        "where_condition": where_condition,
        "cache_ttl": cache_ttl,
    }
    return await self.call_tool("table_sample_data", request_args)

# Data lineage analysis tool
# Register the data-lineage analysis tool on the MCP server.
@mcp.tool(
    "analyze_data_lineage",
    description="""[Function Description]: Analyze data lineage relationships between tables.

[Parameter Content]:

- table_name (string) [Optional] - Target table name (if not provided analyzes all tables)
- db_name (string) [Optional] - Target database name, defaults to current database
- catalog_name (string) [Optional] - Target catalog name for federation queries
- depth (integer) [Optional] - Analysis depth (default 1)
- direction (string) [Optional] - Analysis direction - "upstream", "downstream" or "both" (default)
""",
)
async def analyze_data_lineage_tool(
    table_name: str = None,
    db_name: str = None,
    catalog_name: str = None,
    depth: int = 1,
    direction: str = "both"
) -> str:
    """Forward a lineage-analysis request through the shared tool dispatcher."""
    request_args = {
        "table_name": table_name,
        "db_name": db_name,
        "catalog_name": catalog_name,
        "depth": depth,
        "direction": direction,
    }
    return await self.call_tool("analyze_data_lineage", request_args)

logger.info("Successfully registered 19 tools to MCP server")

async def list_tools(self) -> List[Tool]:
"""List all available query tools (for stdio mode)"""
Expand Down Expand Up @@ -848,6 +941,81 @@ async def list_tools(self) -> List[Tool]:
},
},
),
Tool(
name="get_table_partition_info",
description="""[Function Description]: Get partition information for the specified table.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
inputSchema={
"type": "object",
"properties": {
"table_name": {"type": "string", "description": "Table name"},
"db_name": {"type": "string", "description": "Database name"},
"catalog_name": {"type": "string", "description": "Catalog name"},
},
"required": ["table_name"],
},
),
Tool(
name="table_sample_data",
description="""[Function Description]: Sample data from specified table with various sampling methods.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to sample
- db_name (string) [Optional] - Target database name, defaults to current database
- catalog_name (string) [Optional] - Target catalog name for federation queries
- sample_method (string) [Optional] - Sampling method (RANDOM), default RANDOM
- sample_size (number) [Required] - Sample size or ratio (e.g. 100 )
- columns (string) [Optional] - Columns to return, comma separated
- where_condition (string) [Optional] - Filter condition before sampling
- cache_ttl (integer) [Optional] - Result cache TTL in seconds, default 300
""",
inputSchema={
"type": "object",
"properties": {
"table_name": {"type": "string", "description": "Table name to sample"},
"db_name": {"type": "string", "description": "Database name"},
"catalog_name": {"type": "string", "description": "Catalog name"},
"sample_method": {"type": "string", "enum": [ "RANDOM"], "default": "RANDOM"},
"sample_size": {"type": "number", "description": "Sample size or ratio"},
"columns": {"type": "string", "description": "Columns to return, comma separated"},
"where_condition": {"type": "string", "description": "Filter condition before sampling"},
"cache_ttl": {"type": "integer", "description": "Result cache TTL in seconds", "default": 300}
},
"required": ["table_name", "sample_size"],
},
),
Tool(
name="analyze_data_lineage",
description="""[Function Description]: Analyze data lineage relationships between tables.

[Parameter Content]:

- table_name (string) [Optional] - Target table name (if not provided analyzes all tables)
- db_name (string) [Optional] - Target database name, defaults to current database
- catalog_name (string) [Optional] - Target catalog name for federation queries
- depth (integer) [Optional] - Analysis depth (default 1)
- direction (string) [Optional] - Analysis direction - "upstream", "downstream" or "both" (default)
""",
inputSchema={
"type": "object",
"properties": {
"table_name": {"type": "string", "description": "Table name to analyze"},
"db_name": {"type": "string", "description": "Database name"},
"catalog_name": {"type": "string", "description": "Catalog name"},
"depth": {"type": "integer", "description": "Analysis depth", "default": 1},
"direction": {"type": "string", "enum": ["upstream", "downstream", "both"], "description": "Analysis direction", "default": "both"}
}
},
),
]

return tools
Expand Down Expand Up @@ -892,6 +1060,12 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
result = await self._get_realtime_memory_stats_tool(arguments)
elif name == "get_historical_memory_stats":
result = await self._get_historical_memory_stats_tool(arguments)
elif name == "get_table_partition_info":
result = await self._get_table_partition_info_tool(arguments)
elif name == "table_sample_data":
result = await self._table_sample_data_tool(arguments)
elif name == "analyze_data_lineage":
result = await self._analyze_data_lineage_tool(arguments)
else:
raise ValueError(f"Unknown tool: {name}")

Expand All @@ -905,8 +1079,9 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
"timestamp": datetime.now().isoformat(),
}

return json.dumps(result, ensure_ascii=False, indent=2)

# Serialize datetime objects before JSON conversion
serialized_result = self._serialize_datetime_objects(result)
return json.dumps(serialized_result, ensure_ascii=False, indent=2)
except Exception as e:
logger.error(f"Tool call failed {name}: {str(e)}")
error_result = {
Expand All @@ -916,6 +1091,21 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
"timestamp": datetime.now().isoformat(),
}
return json.dumps(error_result, ensure_ascii=False, indent=2)

def _serialize_datetime_objects(self, data):
"""Serialize datetime objects and Decimal numbers to JSON compatible format"""
if isinstance(data, list):
return [self._serialize_datetime_objects(item) for item in data]
elif isinstance(data, dict):
return {key: self._serialize_datetime_objects(value) for key, value in data.items()}
elif hasattr(data, 'isoformat'): # datetime, date, time objects
return data.isoformat()
elif hasattr(data, 'strftime'): # pandas Timestamp objects
return data.strftime('%Y-%m-%d %H:%M:%S')
elif hasattr(data, 'as_tuple'): # Decimal objects
return float(data)
else:
return data


async def _exec_query_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -1082,4 +1272,57 @@ async def _get_historical_memory_stats_tool(self, arguments: Dict[str, Any]) ->
# Delegate to memory tracker for processing
return await self.memory_tracker.get_historical_memory_stats(
tracker_names, time_range
)
)

async def _get_table_partition_info_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Resolve partition-info arguments and dispatch to the metadata extractor.

    Args:
        arguments: Raw tool-call arguments ("table_name" required,
            "db_name" optional).

    Returns:
        The partition information produced by the metadata extractor.

    NOTE(review): "catalog_name" is accepted by the tool schema but is not
    forwarded to the extractor — confirm whether catalog support is intended.
    """
    return await self.metadata_extractor.get_table_partition_info_for_mcp(
        arguments.get("db_name"), arguments.get("table_name")
    )

async def _analyze_data_lineage_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Unpack lineage-analysis arguments and hand off to the metadata extractor.

    Args:
        arguments: Raw tool-call arguments; all keys are optional and the
            fallbacks here ("depth" 1, "direction" "both") mirror the
            defaults declared in the registered tool schema.

    Returns:
        The lineage analysis result produced by the metadata extractor.
    """
    return await self.metadata_extractor.analyze_data_lineage(
        table_name=arguments.get("table_name"),
        db_name=arguments.get("db_name"),
        catalog_name=arguments.get("catalog_name"),
        depth=arguments.get("depth", 1),
        direction=arguments.get("direction", "both"),
    )

async def _table_sample_data_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Table sample data tool routing.

    Extracts sampling arguments and delegates to the metadata extractor.

    Args:
        arguments: Raw tool-call arguments; per the registered
            "table_sample_data" schema, "table_name" and "sample_size"
            are required and the rest are optional.

    Returns:
        The sampling result produced by the metadata extractor.
    """
    table_name = arguments.get("table_name")
    db_name = arguments.get("db_name")
    catalog_name = arguments.get("catalog_name")
    # Default must match the registered tool schema, whose enum only
    # allows "RANDOM" and declares "RANDOM" as the default (the previous
    # "SYSTEM" fallback here was inconsistent with that schema).
    sample_method = arguments.get("sample_method", "RANDOM")
    sample_size = arguments.get("sample_size")
    columns = arguments.get("columns")
    where_condition = arguments.get("where_condition")
    cache_ttl = arguments.get("cache_ttl", 300)

    # Delegate to metadata extractor for processing
    return await self.metadata_extractor.get_table_sample_data_for_mcp(
        table_name=table_name,
        db_name=db_name,
        catalog_name=catalog_name,
        sample_method=sample_method,
        sample_size=sample_size,
        columns=columns,
        where_condition=where_condition,
        cache_ttl=cache_ttl
    )
Loading