1- from typing import Any
1+ from typing import Any , Dict , Optional , Union
22from sqlalchemy import create_engine , inspect , text
33from sqlalchemy .exc import SQLAlchemyError
44from urllib .parse import quote_plus # 用于对URL进行编码
5- from typing import Any , Optional , Union
6-
7- #def get_db_schema(
8- # db_type: str,
9- # host: str,
10- # port: int,
11- # database: str,
12- # username: str,
13- # password: str,
14- # table_names: str | None = None
15- #) -> dict[str, Any] | None:
16- # """
17- # 获取数据库表结构信息
18- # :param db_type: 数据库类型 (mysql/oracle/sqlserver/postgresql)
19- # :param host: 主机地址
20- # :param port: 端口号
21- # :param database: 数据库名
22- # :param username: 用户名
23- # :param password: 密码
24- # :param table_names: 要查询的表名,以逗号分隔的字符串,如果为None则查询所有表
25- # :return: 包含所有表结构信息的字典
26- # """
27- # result: dict[str, Any] = {}
28- # db_type = 'mssql' if db_type == 'sqlserver' else db_type
29- # # 构建连接URL
30- # driver = {
31- # 'mysql': 'pymysql',
32- # 'oracle': 'cx_oracle',
33- # 'mssql': 'pyodbc',
34- # 'postgresql': 'psycopg2'
35- # }.get(db_type.lower(), '')
36- #
37- # encoded_username = quote_plus(username)
38- # encoded_password = quote_plus(password)
39- # # separator = ':' if db_type is 'mssql' else ','
40- #
41- # engine = create_engine(f'{db_type.lower()}+{driver}://{encoded_username}:{encoded_password}@{host}{separator}{port}/{database}')
42- # inspector = inspect(engine)
43- #
44- # # 获取字段注释的SQL语句
45- # column_comment_sql = {
46- # 'mysql': f"SELECT COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '{database}' AND TABLE_NAME = :table_name AND COLUMN_NAME = :column_name",
47- # 'oracle': "SELECT COMMENTS FROM ALL_COL_COMMENTS WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name",
48- # 'mssql': "SELECT CAST(ep.value AS NVARCHAR(MAX)) FROM sys.columns c LEFT JOIN sys.extended_properties ep ON ep.major_id = c.object_id AND ep.minor_id = c.column_id WHERE OBJECT_NAME(c.object_id) = :table_name AND c.name = :column_name",
49- # 'postgresql': """
50- # SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::int)
51- # FROM pg_catalog.pg_class c
52- # JOIN information_schema.columns cols
53- # ON c.relname = cols.table_name
54- # WHERE c.relname = :table_name AND cols.column_name = :column_name
55- # """
56- # }.get(db_type.lower(), "")
57- #
58- # try:
59- # # 获取所有表名
60- # all_tables = inspector.get_table_names()
61- #
62- # # 如果指定了table_names,则过滤表名
63- # target_tables = all_tables
64- #
65- # if table_names:
66- # target_tables = [table.strip() for table in table_names.split(',')]
67- # # 过滤出实际存在的表
68- # target_tables = [table for table in target_tables if table in all_tables]
69- # print(f"Retrieving table metadata for {len(target_tables)} tables...")
70- # for table_name in target_tables:
71- # # 获取表注释
72- # table_comment = ""
73- # try:
74- # table_comment = inspector.get_table_comment(table_name).get("text") or ""
75- # except SQLAlchemyError as e:
76- # raise ValueError(f"Failed to retrieve table comments: {str(e)}")
77- #
78- # table_info = {
79- # 'comment': table_comment,
80- # 'columns': []
81- # }
82- #
83- # for column in inspector.get_columns(table_name):
84- # # 获取字段注释
85- # column_comment = ""
86- # try:
87- # with engine.connect() as conn:
88- # stmt = text(column_comment_sql)
89- # column_comment = conn.execute(stmt, {
90- # 'table_name': table_name,
91- # 'column_name': column['name']
92- # }).scalar() or ""
93- # except SQLAlchemyError as e:
94- # print(f"Warning: failed to get comment for {table_name}.{column['name']} - {e}")
95- # column_comment = ""
96- #
97- # table_info['columns'].append({
98- # 'name': column['name'],
99- # 'comment': column_comment,
100- # 'type': str(column['type'])
101- # })
102- #
103- # result[table_name] = table_info
104- # return result
105- # except SQLAlchemyError as e:
106- # raise ValueError(f"Failed to retrieve database table metadata: {str(e)}")
107- # finally:
108- # engine.dispose()
5+ import atexit
6+ import logging
7+ import time
8+
9+ # 配置日志
10+ logging .basicConfig (
11+ level = logging .INFO ,
12+ format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
13+ )
14+ logger = logging .getLogger ('db_connection' )
15+
16+ # 全局引擎缓存,用于存储和复用数据库连接
17+ _ENGINE_CACHE : Dict [str , Any ] = {}
18+ # 连接计数器,用于跟踪活跃连接
19+ _CONNECTION_COUNTERS : Dict [str , int ] = {}
20+ # 引擎创建时间,用于跟踪引擎的生命周期
21+ _ENGINE_CREATION_TIME : Dict [str , float ] = {}
10922
11023def format_schema_dsl (schema : dict [str , Any ], with_type : bool = True , with_comment : bool = False ) -> str :
11124 """
@@ -144,60 +57,194 @@ def format_schema_dsl(schema: dict[str, Any], with_type: bool = True, with_comme
14457
14558 return "\n " .join (lines )
14659
def get_engine_key(
    db_type: str,
    host: str,
    port: int,
    database: str,
    username: str,
    schema: Optional[str] = None
) -> str:
    """
    Build the unique key under which an engine configuration is cached.

    NOTE(review): the key deliberately excludes the password, so two calls
    with the same user but different passwords share one cached engine —
    confirm this is acceptable.
    """
    suffix = f"/{schema}" if schema else ""
    return f"{db_type}://{username}@{host}:{port}/{database}{suffix}"
16473
def get_or_create_engine(
    db_type: str,
    host: str,
    port: int,
    database: str,
    username: str,
    password: str,
    schema: Optional[str] = None
) -> Any:
    """
    Return the cached SQLAlchemy engine for this configuration, creating and
    caching it on first use.

    NOTE(review): the cache key omits the password (see get_engine_key), so a
    later call with a different password for the same user reuses the old
    engine — confirm this is intended.
    """
    engine_key = get_engine_key(db_type, host, port, database, username, schema)

    cached = _ENGINE_CACHE.get(engine_key)
    if cached is not None:
        age = time.time() - _ENGINE_CREATION_TIME[engine_key]
        logger.info(f"复用已有引擎: {engine_key} (创建于 {age:.2f} 秒前)")
        return cached

    # Assemble the connection URI pieces.
    driver = _get_driver(db_type)
    user_enc = quote_plus(username)
    pwd_enc = quote_plus(password)

    connect_args = {}
    # PostgreSQL: pin search_path for every connection of this engine.
    if db_type.lower() == 'postgresql' and schema:
        connect_args['options'] = f"-c search_path={schema}"

    connection_uri = _build_connection_uri(
        db_type, driver, user_enc, pwd_enc, host, port, database
    )

    logger.info(f"创建新引擎: {engine_key}")
    engine = create_engine(
        connection_uri,
        connect_args=connect_args,
        pool_pre_ping=True,   # validate connections before handing them out
        pool_recycle=3600,    # recycle pooled connections after one hour
        echo_pool=True        # log pool checkout/checkin events
    )

    # Register the new engine in the module-level caches.
    _ENGINE_CACHE[engine_key] = engine
    _CONNECTION_COUNTERS[engine_key] = 0
    _ENGINE_CREATION_TIME[engine_key] = time.time()

    # Snapshot the initial pool state for the log.
    pool_info = {
        "size": engine.pool.size(),
        "checkedin": engine.pool.checkedin(),
        "overflow": engine.pool.overflow(),
        "checkedout": engine.pool.checkedout()
    }
    logger.info(f"引擎 {engine_key} 连接池初始状态: {pool_info}")

    return engine
136+
137+ # 添加连接跟踪函数
138+ def log_connection_status (engine_key : str , action : str ):
139+ """
140+ 记录连接状态变化
141+ """
142+ if engine_key not in _ENGINE_CACHE :
143+ return
144+
145+ engine = _ENGINE_CACHE [engine_key ]
146+ pool_info = {
147+ "size" : engine .pool .size (),
148+ "checkedin" : engine .pool .checkedin (),
149+ "overflow" : engine .pool .overflow (),
150+ "checkedout" : engine .pool .checkedout ()
151+ }
152+
153+ if action == "acquire" :
154+ _CONNECTION_COUNTERS [engine_key ] += 1
155+ elif action == "release" :
156+ _CONNECTION_COUNTERS [engine_key ] = max (0 , _CONNECTION_COUNTERS [engine_key ] - 1 )
157+
158+ logger .info (f"{ action } 连接 - 引擎: { engine_key } , 活跃连接: { _CONNECTION_COUNTERS [engine_key ]} , 连接池状态: { pool_info } " )
159+
160+ # 在程序退出时清理所有引擎连接
@atexit.register
def dispose_all_engines():
    """
    Dispose every cached engine at interpreter shutdown, logging per-engine
    uptime, active-connection count and pool state, then clear the caches.
    """
    logger.info(f"程序退出,开始清理 {len(_ENGINE_CACHE)} 个数据库引擎...")

    for key, engine in _ENGINE_CACHE.items():
        uptime = time.time() - _ENGINE_CREATION_TIME.get(key, time.time())
        active = _CONNECTION_COUNTERS.get(key, 0)

        pool_info = {
            "size": engine.pool.size(),
            "checkedin": engine.pool.checkedin(),
            "overflow": engine.pool.overflow(),
            "checkedout": engine.pool.checkedout()
        }

        logger.info(f"释放引擎: {key}, 运行时间: {uptime:.2f} 秒, 活跃连接: {active}, 连接池状态: {pool_info}")
        engine.dispose()
        logger.info(f"引擎 {key} 已释放")

    _ENGINE_CACHE.clear()
    _CONNECTION_COUNTERS.clear()
    _ENGINE_CREATION_TIME.clear()
    logger.info("所有数据库引擎已清理完毕")
188+
def execute_sql(
    db_type: str,
    host: str,
    port: int,
    database: str,
    username: str,
    password: str,
    sql: str,
    params: Optional[dict[str, Any]] = None,
    schema: Optional[str] = None
) -> Union[list[dict[str, Any]], dict[str, Any], None]:
    """
    Execute one SQL statement on a cached engine, with PostgreSQL schema
    support.

    :param sql: statement text (bound via sqlalchemy.text)
    :param params: bind parameters, defaults to {}
    :param schema: target schema (mainly for PostgreSQL)
    :return: list of row dicts for row-returning statements, otherwise a
             {"rowcount": n} dict (see _process_result)
    :raises ValueError: wrapping any SQLAlchemyError
    """
    start_time = time.time()
    params = params or {}

    # Key is computed up front purely for logging/tracking.
    engine_key = get_engine_key(db_type, host, port, database, username, schema)
    engine = get_or_create_engine(
        db_type, host, port, database, username, password, schema
    )

    # Log a truncated preview of the statement.
    truncated_sql = sql[:100] + "..." if len(sql) > 100 else sql
    logger.info(f"开始执行SQL - 引擎: {engine_key}, SQL: {truncated_sql}")

    try:
        log_connection_status(engine_key, "acquire")

        # engine.begin() commits on clean exit, rolls back on exception.
        with engine.begin() as conn:
            if db_type.lower() == 'postgresql' and schema:
                # NOTE(review): schema is interpolated into SQL here — ensure
                # it never comes from untrusted input.
                conn.execute(text(f"SET search_path TO {schema}"))

            result_proxy = conn.execute(text(sql), params)
            result = _process_result(result_proxy)

            execution_time = time.time() - start_time
            result_size = len(result) if isinstance(result, list) else 1 if result else 0
            logger.info(f"SQL执行完成 - 引擎: {engine_key}, 耗时: {execution_time:.3f} 秒, 结果行数: {result_size}")

            return result

    except SQLAlchemyError as e:
        error_msg = str(e)
        logger.error(f"SQL执行失败 - 引擎: {engine_key}, 错误: {error_msg}")
        raise ValueError(f"数据库操作失败:{error_msg}")
    finally:
        # Mirror the "acquire" above even on failure paths.
        log_connection_status(engine_key, "release")
201248
202249def _get_driver (db_type : str ) -> str :
203250 """获取数据库驱动"""
@@ -228,4 +275,4 @@ def _process_result(result_proxy) -> Union[list[dict], dict, None]:
228275 """处理执行结果"""
229276 if result_proxy .returns_rows :
230277 return [dict (row ._mapping ) for row in result_proxy ]
231- return {"rowcount" : result_proxy .rowcount }
278+ return {"rowcount" : result_proxy .rowcount }
0 commit comments