Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions tools/sql_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
from typing import Any
import re
import json

import pandas as pd
import records
from sqlalchemy import text
from dify_plugin import Tool
from dify_plugin.entities.tool import ToolInvokeMessage


class SQLExecuteTool(Tool):
def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
db_uri = tool_parameters.get("db_uri") or self.runtime.credentials.get("db_uri")
Expand All @@ -25,38 +24,60 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag

try:
if re.match(r'^\s*(SELECT|WITH)\s+', query, re.IGNORECASE):
# 查询字段类型
table_name = re.search(r'FROM\s+([^\s;]+)', query, re.IGNORECASE)
column_types = {}
if table_name:
type_query = f"""
SELECT COLUMN_NAME, DATA_TYPE, CHARACTER_MAXIMUM_LENGTH
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME = '{table_name.group(1)}'
"""
types_rows = db.query(type_query)
column_types = {row['COLUMN_NAME']: {
'DATA_TYPE': row['DATA_TYPE'],
'CHARACTER_MAXIMUM_LENGTH': row['CHARACTER_MAXIMUM_LENGTH']
} for row in types_rows.as_dict()}

# 执行主查询
rows = db.query(query)
# 转换为 pandas DataFrame 并处理 NaN
df = pd.DataFrame(rows.as_dict())
df = df.fillna('') # 替换 NaN 为空字符串

if format == "json":
result = rows.as_dict()
yield self.create_json_message({"result": result})
result = df.to_dict(orient='records')
yield self.create_json_message({"result": result, "column_types": column_types})
elif format == "md":
result = str(rows.dataset)
yield self.create_text_message(result)
elif format == "csv":
result = rows.export("csv").encode()
result = df.to_csv(index=False).encode()
yield self.create_blob_message(
result, meta={"mime_type": "text/csv", "filename": "result.csv"}
)
elif format == "yaml":
result = rows.export("yaml").encode()
result = df.to_dict(orient='records')
result = json.dumps(result, ensure_ascii=False).encode()
yield self.create_blob_message(
result,
meta={"mime_type": "text/yaml", "filename": "result.yaml"},
result, meta={"mime_type": "text/yaml", "filename": "result.yaml"}
)
elif format == "xlsx":
result = rows.export("xlsx")
output_file = '/tmp/output.xlsx'
df.to_excel(output_file, index=False)
with open(output_file, 'rb') as f:
result = f.read()
yield self.create_blob_message(
result,
meta={
"mime_type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"filename": "result.xlsx",
},
"filename": "result.xlsx"
}
)
elif format == "html":
result = rows.export("html").encode()
result = df.to_html(index=False).encode()
yield self.create_blob_message(
result,
meta={"mime_type": "text/html", "filename": "result.html"},
result, meta={"mime_type": "text/html", "filename": "result.html"}
)
else:
raise ValueError(f"Unsupported format: {format}")
Expand Down