Skip to content

Commit 20e91cd

Browse files
jaguarliu
authored and committed
feat: 修改engine执行sql部分逻辑,释放数据库链接
1 parent 3e5a4fb commit 20e91cd

File tree

4 files changed

+175
-128
lines changed

4 files changed

+175
-128
lines changed

.ide/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ FROM python:3.13.0
33
# RUN apt-get update && apt-get install -y git
44

55
# 下载插件并设置权限,重命名后复制到系统路径
6-
RUN wget https://github.com/langgenius/dify-plugin-daemon/releases/download/0.0.7/dify-plugin-linux-amd64 -O /tmp/dify-plugin-linux-amd64 && \
6+
RUN wget https://github.com/langgenius/dify-plugin-daemon/releases/download/0.2.0/dify-plugin-linux-amd64 -O /tmp/dify-plugin-linux-amd64 && \
77
chmod +x /tmp/dify-plugin-linux-amd64 && \
88
mv /tmp/dify-plugin-linux-amd64 /usr/local/bin/dify
99

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# 核心依赖
2-
dify_plugin~=0.0.1b72
2+
dify_plugin~=0.2.0
33
sqlalchemy>=2.0.0
44

55
# 数据库驱动

utils/alchemy_db_client.py

Lines changed: 173 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,24 @@
1-
from typing import Any
1+
from typing import Any, Dict, Optional, Union
22
from sqlalchemy import create_engine, inspect, text
33
from sqlalchemy.exc import SQLAlchemyError
44
from urllib.parse import quote_plus # 用于对URL进行编码
5-
from typing import Any, Optional, Union
6-
7-
#def get_db_schema(
8-
# db_type: str,
9-
# host: str,
10-
# port: int,
11-
# database: str,
12-
# username: str,
13-
# password: str,
14-
# table_names: str | None = None
15-
#) -> dict[str, Any] | None:
16-
# """
17-
# 获取数据库表结构信息
18-
# :param db_type: 数据库类型 (mysql/oracle/sqlserver/postgresql)
19-
# :param host: 主机地址
20-
# :param port: 端口号
21-
# :param database: 数据库名
22-
# :param username: 用户名
23-
# :param password: 密码
24-
# :param table_names: 要查询的表名,以逗号分隔的字符串,如果为None则查询所有表
25-
# :return: 包含所有表结构信息的字典
26-
# """
27-
# result: dict[str, Any] = {}
28-
# db_type = 'mssql' if db_type == 'sqlserver' else db_type
29-
# # 构建连接URL
30-
# driver = {
31-
# 'mysql': 'pymysql',
32-
# 'oracle': 'cx_oracle',
33-
# 'mssql': 'pyodbc',
34-
# 'postgresql': 'psycopg2'
35-
# }.get(db_type.lower(), '')
36-
#
37-
# encoded_username = quote_plus(username)
38-
# encoded_password = quote_plus(password)
39-
# # separator = ':' if db_type is 'mssql' else ','
40-
#
41-
# engine = create_engine(f'{db_type.lower()}+{driver}://{encoded_username}:{encoded_password}@{host}{separator}{port}/{database}')
42-
# inspector = inspect(engine)
43-
#
44-
# # 获取字段注释的SQL语句
45-
# column_comment_sql = {
46-
# 'mysql': f"SELECT COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '{database}' AND TABLE_NAME = :table_name AND COLUMN_NAME = :column_name",
47-
# 'oracle': "SELECT COMMENTS FROM ALL_COL_COMMENTS WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name",
48-
# 'mssql': "SELECT CAST(ep.value AS NVARCHAR(MAX)) FROM sys.columns c LEFT JOIN sys.extended_properties ep ON ep.major_id = c.object_id AND ep.minor_id = c.column_id WHERE OBJECT_NAME(c.object_id) = :table_name AND c.name = :column_name",
49-
# 'postgresql': """
50-
# SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::int)
51-
# FROM pg_catalog.pg_class c
52-
# JOIN information_schema.columns cols
53-
# ON c.relname = cols.table_name
54-
# WHERE c.relname = :table_name AND cols.column_name = :column_name
55-
# """
56-
# }.get(db_type.lower(), "")
57-
#
58-
# try:
59-
# # 获取所有表名
60-
# all_tables = inspector.get_table_names()
61-
#
62-
# # 如果指定了table_names,则过滤表名
63-
# target_tables = all_tables
64-
#
65-
# if table_names:
66-
# target_tables = [table.strip() for table in table_names.split(',')]
67-
# # 过滤出实际存在的表
68-
# target_tables = [table for table in target_tables if table in all_tables]
69-
# print(f"Retrieving table metadata for {len(target_tables)} tables...")
70-
# for table_name in target_tables:
71-
# # 获取表注释
72-
# table_comment = ""
73-
# try:
74-
# table_comment = inspector.get_table_comment(table_name).get("text") or ""
75-
# except SQLAlchemyError as e:
76-
# raise ValueError(f"Failed to retrieve table comments: {str(e)}")
77-
#
78-
# table_info = {
79-
# 'comment': table_comment,
80-
# 'columns': []
81-
# }
82-
#
83-
# for column in inspector.get_columns(table_name):
84-
# # 获取字段注释
85-
# column_comment = ""
86-
# try:
87-
# with engine.connect() as conn:
88-
# stmt = text(column_comment_sql)
89-
# column_comment = conn.execute(stmt, {
90-
# 'table_name': table_name,
91-
# 'column_name': column['name']
92-
# }).scalar() or ""
93-
# except SQLAlchemyError as e:
94-
# print(f"Warning: failed to get comment for {table_name}.{column['name']} - {e}")
95-
# column_comment = ""
96-
#
97-
# table_info['columns'].append({
98-
# 'name': column['name'],
99-
# 'comment': column_comment,
100-
# 'type': str(column['type'])
101-
# })
102-
#
103-
# result[table_name] = table_info
104-
# return result
105-
# except SQLAlchemyError as e:
106-
# raise ValueError(f"Failed to retrieve database table metadata: {str(e)}")
107-
# finally:
108-
# engine.dispose()
5+
import atexit
6+
import logging
7+
import time
8+
9+
# 配置日志
10+
logging.basicConfig(
11+
level=logging.INFO,
12+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
13+
)
14+
logger = logging.getLogger('db_connection')
15+
16+
# 全局引擎缓存,用于存储和复用数据库连接
17+
_ENGINE_CACHE: Dict[str, Any] = {}
18+
# 连接计数器,用于跟踪活跃连接
19+
_CONNECTION_COUNTERS: Dict[str, int] = {}
20+
# 引擎创建时间,用于跟踪引擎的生命周期
21+
_ENGINE_CREATION_TIME: Dict[str, float] = {}
10922

11023
def format_schema_dsl(schema: dict[str, Any], with_type: bool = True, with_comment: bool = False) -> str:
11124
"""
@@ -144,60 +57,194 @@ def format_schema_dsl(schema: dict[str, Any], with_type: bool = True, with_comme
14457

14558
return "\n".join(lines)
14659

147-
def execute_sql(
60+
def get_engine_key(
14861
db_type: str,
14962
host: str,
15063
port: int,
15164
database: str,
15265
username: str,
153-
password: str,
154-
sql: str,
155-
params: Optional[dict[str, Any]] = None,
15666
schema: Optional[str] = None
157-
) -> Union[list[dict[str, Any]], dict[str, Any], None]:
67+
) -> str:
15868
"""
159-
增强版 SQL 执行函数,支持 PostgreSQL schema
160-
161-
参数新增:
162-
schema: 指定目标schema(主要用于PostgreSQL)
69+
生成用于缓存引擎的唯一键
16370
"""
71+
schema_part = f"/{schema}" if schema else ""
72+
return f"{db_type}://{username}@{host}:{port}/{database}{schema_part}"
16473

74+
def get_or_create_engine(
75+
db_type: str,
76+
host: str,
77+
port: int,
78+
database: str,
79+
username: str,
80+
password: str,
81+
schema: Optional[str] = None
82+
) -> Any:
83+
"""
84+
获取或创建数据库引擎实例
85+
"""
86+
# 生成引擎缓存键
87+
engine_key = get_engine_key(db_type, host, port, database, username, schema)
88+
89+
# 检查缓存中是否已存在引擎实例
90+
if engine_key in _ENGINE_CACHE:
91+
logger.info(f"复用已有引擎: {engine_key} (创建于 {time.time() - _ENGINE_CREATION_TIME[engine_key]:.2f} 秒前)")
92+
return _ENGINE_CACHE[engine_key]
93+
16594
# 参数预处理
166-
params = params or {}
16795
driver = _get_driver(db_type)
16896
encoded_username = quote_plus(username)
16997
encoded_password = quote_plus(password)
17098
connect_args = {}
99+
171100
# PostgreSQL 特殊处理
172101
if db_type.lower() == 'postgresql' and schema:
173102
connect_args['options'] = f"-c search_path={schema}"
174-
175-
#if db_type.lower() == 'sqlserver':
176-
# import os
177-
# driver_extra_info = 'ODBC+Driver+17+for+SQL+Server' if os.name == 'posix' else 'SQL Server'
178-
# print(driver_extra_info)
103+
179104
# 构建连接字符串
180105
connection_uri = _build_connection_uri(
181106
db_type, driver, encoded_username, encoded_password,
182107
host, port, database
183108
)
109+
110+
# 创建数据库引擎
111+
logger.info(f"创建新引擎: {engine_key}")
112+
engine = create_engine(
113+
connection_uri,
114+
connect_args=connect_args,
115+
# 添加连接池配置,便于监控
116+
pool_pre_ping=True, # 在使用连接前检查其有效性
117+
pool_recycle=3600, # 一小时后回收连接
118+
echo_pool=True # 输出连接池事件日志
119+
)
120+
121+
# 将引擎实例存入缓存
122+
_ENGINE_CACHE[engine_key] = engine
123+
_CONNECTION_COUNTERS[engine_key] = 0
124+
_ENGINE_CREATION_TIME[engine_key] = time.time()
125+
126+
# 记录连接池配置信息
127+
pool_info = {
128+
"size": engine.pool.size(),
129+
"checkedin": engine.pool.checkedin(),
130+
"overflow": engine.pool.overflow(),
131+
"checkedout": engine.pool.checkedout()
132+
}
133+
logger.info(f"引擎 {engine_key} 连接池初始状态: {pool_info}")
134+
135+
return engine
136+
137+
# 添加连接跟踪函数
138+
def log_connection_status(engine_key: str, action: str):
139+
"""
140+
记录连接状态变化
141+
"""
142+
if engine_key not in _ENGINE_CACHE:
143+
return
144+
145+
engine = _ENGINE_CACHE[engine_key]
146+
pool_info = {
147+
"size": engine.pool.size(),
148+
"checkedin": engine.pool.checkedin(),
149+
"overflow": engine.pool.overflow(),
150+
"checkedout": engine.pool.checkedout()
151+
}
152+
153+
if action == "acquire":
154+
_CONNECTION_COUNTERS[engine_key] += 1
155+
elif action == "release":
156+
_CONNECTION_COUNTERS[engine_key] = max(0, _CONNECTION_COUNTERS[engine_key] - 1)
157+
158+
logger.info(f"{action} 连接 - 引擎: {engine_key}, 活跃连接: {_CONNECTION_COUNTERS[engine_key]}, 连接池状态: {pool_info}")
159+
160+
# 在程序退出时清理所有引擎连接
161+
@atexit.register
162+
def dispose_all_engines():
163+
"""
164+
在程序退出时关闭所有数据库连接
165+
"""
166+
logger.info(f"程序退出,开始清理 {len(_ENGINE_CACHE)} 个数据库引擎...")
167+
168+
for key, engine in _ENGINE_CACHE.items():
169+
# 记录引擎使用情况
170+
uptime = time.time() - _ENGINE_CREATION_TIME.get(key, time.time())
171+
active_connections = _CONNECTION_COUNTERS.get(key, 0)
172+
173+
pool_info = {
174+
"size": engine.pool.size(),
175+
"checkedin": engine.pool.checkedin(),
176+
"overflow": engine.pool.overflow(),
177+
"checkedout": engine.pool.checkedout()
178+
}
179+
180+
logger.info(f"释放引擎: {key}, 运行时间: {uptime:.2f}秒, 活跃连接: {active_connections}, 连接池状态: {pool_info}")
181+
engine.dispose()
182+
logger.info(f"引擎 {key} 已释放")
183+
184+
_ENGINE_CACHE.clear()
185+
_CONNECTION_COUNTERS.clear()
186+
_ENGINE_CREATION_TIME.clear()
187+
logger.info("所有数据库引擎已清理完毕")
188+
189+
def execute_sql(
190+
db_type: str,
191+
host: str,
192+
port: int,
193+
database: str,
194+
username: str,
195+
password: str,
196+
sql: str,
197+
params: Optional[dict[str, Any]] = None,
198+
schema: Optional[str] = None
199+
) -> Union[list[dict[str, Any]], dict[str, Any], None]:
200+
"""
201+
增强版 SQL 执行函数,支持 PostgreSQL schema
202+
203+
参数新增:
204+
schema: 指定目标schema(主要用于PostgreSQL)
205+
"""
206+
start_time = time.time()
207+
# 参数预处理
208+
params = params or {}
209+
210+
# 获取引擎键,用于日志记录
211+
engine_key = get_engine_key(db_type, host, port, database, username, schema)
212+
213+
# 获取或创建数据库引擎
214+
engine = get_or_create_engine(
215+
db_type, host, port, database, username, password, schema
216+
)
184217

218+
# 记录SQL执行开始
219+
truncated_sql = sql[:100] + "..." if len(sql) > 100 else sql
220+
logger.info(f"开始执行SQL - 引擎: {engine_key}, SQL: {truncated_sql}")
221+
185222
try:
186-
engine = create_engine(connection_uri, connect_args=connect_args)
223+
# 记录连接获取
224+
log_connection_status(engine_key, "acquire")
225+
187226
with engine.begin() as conn:
188227
# 显式设置schema(部分数据库需要)
189228
if db_type.lower() == 'postgresql' and schema:
190229
conn.execute(text(f"SET search_path TO {schema}"))
191230

192231
result_proxy = conn.execute(text(sql), params)
232+
result = _process_result(result_proxy)
233+
234+
# 记录SQL执行结束
235+
execution_time = time.time() - start_time
236+
result_size = len(result) if isinstance(result, list) else 1 if result else 0
237+
logger.info(f"SQL执行完成 - 引擎: {engine_key}, 耗时: {execution_time:.3f}秒, 结果行数: {result_size}")
193238

194-
return _process_result(result_proxy)
239+
return result
195240

196241
except SQLAlchemyError as e:
197-
raise ValueError(f"数据库操作失败:{str(e)}")
242+
error_msg = str(e)
243+
logger.error(f"SQL执行失败 - 引擎: {engine_key}, 错误: {error_msg}")
244+
raise ValueError(f"数据库操作失败:{error_msg}")
198245
finally:
199-
if 'engine' in locals():
200-
engine.dispose()
246+
# 记录连接释放
247+
log_connection_status(engine_key, "release")
201248

202249
def _get_driver(db_type: str) -> str:
203250
"""获取数据库驱动"""
@@ -228,4 +275,4 @@ def _process_result(result_proxy) -> Union[list[dict], dict, None]:
228275
"""处理执行结果"""
229276
if result_proxy.returns_rows:
230277
return [dict(row._mapping) for row in result_proxy]
231-
return {"rowcount": result_proxy.rowcount}
278+
return {"rowcount": result_proxy.rowcount}

workspace.difypkg

-598 KB
Binary file not shown.

0 commit comments

Comments (0)