diff --git a/tools/sql_execute.py b/tools/sql_execute.py index 76ba61d..2d4a2c1 100644 --- a/tools/sql_execute.py +++ b/tools/sql_execute.py @@ -2,13 +2,12 @@ from typing import Any import re import json - +import pandas as pd import records from sqlalchemy import text from dify_plugin import Tool from dify_plugin.entities.tool import ToolInvokeMessage - class SQLExecuteTool(Tool): def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]: db_uri = tool_parameters.get("db_uri") or self.runtime.credentials.get("db_uri") @@ -25,38 +24,60 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag try: if re.match(r'^\s*(SELECT|WITH)\s+', query, re.IGNORECASE): + # 查询字段类型 + table_name = re.search(r'FROM\s+([^\s;]+)', query, re.IGNORECASE) + column_types = {} + if table_name: + type_query = f""" + SELECT COLUMN_NAME, DATA_TYPE, CHARACTER_MAXIMUM_LENGTH + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = '{table_name.group(1)}' + """ + types_rows = db.query(type_query) + column_types = {row['COLUMN_NAME']: { + 'DATA_TYPE': row['DATA_TYPE'], + 'CHARACTER_MAXIMUM_LENGTH': row['CHARACTER_MAXIMUM_LENGTH'] + } for row in types_rows.as_dict()} + + # 执行主查询 rows = db.query(query) + # 转换为 pandas DataFrame 并处理 NaN + df = pd.DataFrame(rows.as_dict()) + df = df.fillna('') # 替换 NaN 为空字符串 + if format == "json": - result = rows.as_dict() - yield self.create_json_message({"result": result}) + result = df.to_dict(orient='records') + yield self.create_json_message({"result": result, "column_types": column_types}) elif format == "md": result = str(rows.dataset) yield self.create_text_message(result) elif format == "csv": - result = rows.export("csv").encode() + result = df.to_csv(index=False).encode() yield self.create_blob_message( result, meta={"mime_type": "text/csv", "filename": "result.csv"} ) elif format == "yaml": - result = rows.export("yaml").encode() + result = df.to_dict(orient='records') + result = json.dumps(result, ensure_ascii=False).encode() yield self.create_blob_message( - result, - meta={"mime_type": "text/yaml", "filename": "result.yaml"}, + result, meta={"mime_type": "text/yaml", "filename": "result.yaml"} ) elif format == "xlsx": - result = rows.export("xlsx") + output_file = '/tmp/output.xlsx' + df.to_excel(output_file, index=False) + with open(output_file, 'rb') as f: + result = f.read() yield self.create_blob_message( result, meta={ "mime_type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - "filename": "result.xlsx", - }, + "filename": "result.xlsx" + } ) elif format == "html": - result = rows.export("html").encode() + result = df.to_html(index=False).encode() yield self.create_blob_message( - result, - meta={"mime_type": "text/html", "filename": "result.html"}, + result, meta={"mime_type": "text/html", "filename": "result.html"} ) else: raise ValueError(f"Unsupported format: {format}")