44from urllib .parse import quote_plus # 用于对URL进行编码
55from typing import Any , Optional , Union
66
7- def get_db_schema (
8- db_type : str ,
9- host : str ,
10- port : int ,
11- database : str ,
12- username : str ,
13- password : str ,
14- table_names : str | None = None
15- ) -> dict [str , Any ] | None :
16- """
17- 获取数据库表结构信息
18- :param db_type: 数据库类型 (mysql/oracle/sqlserver/postgresql)
19- :param host: 主机地址
20- :param port: 端口号
21- :param database: 数据库名
22- :param username: 用户名
23- :param password: 密码
24- :param table_names: 要查询的表名,以逗号分隔的字符串,如果为None则查询所有表
25- :return: 包含所有表结构信息的字典
26- """
27- result : dict [str , Any ] = {}
28- db_type = 'mssql' if db_type == 'sqlserver' else db_type
29- # 构建连接URL
30- driver = {
31- 'mysql' : 'pymysql' ,
32- 'oracle' : 'cx_oracle' ,
33- 'mssql' : 'pyodbc' ,
34- 'postgresql' : 'psycopg2'
35- }.get (db_type .lower (), '' )
36-
37- encoded_username = quote_plus (username )
38- encoded_password = quote_plus (password )
39- # separator = ':' if db_type is 'mssql' else ','
40-
41- engine = create_engine (f'{ db_type .lower ()} +{ driver } ://{ encoded_username } :{ encoded_password } @{ host } { separator } { port } /{ database } ' )
42- inspector = inspect (engine )
43-
44- # 获取字段注释的SQL语句
45- column_comment_sql = {
46- 'mysql' : f"SELECT COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '{ database } ' AND TABLE_NAME = :table_name AND COLUMN_NAME = :column_name" ,
47- 'oracle' : "SELECT COMMENTS FROM ALL_COL_COMMENTS WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name" ,
48- 'mssql' : "SELECT CAST(ep.value AS NVARCHAR(MAX)) FROM sys.columns c LEFT JOIN sys.extended_properties ep ON ep.major_id = c.object_id AND ep.minor_id = c.column_id WHERE OBJECT_NAME(c.object_id) = :table_name AND c.name = :column_name" ,
49- 'postgresql' : """
50- SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::int)
51- FROM pg_catalog.pg_class c
52- JOIN information_schema.columns cols
53- ON c.relname = cols.table_name
54- WHERE c.relname = :table_name AND cols.column_name = :column_name
55- """
56- }.get (db_type .lower (), "" )
57-
58- try :
59- # 获取所有表名
60- all_tables = inspector .get_table_names ()
61-
62- # 如果指定了table_names,则过滤表名
63- target_tables = all_tables
64-
65- if table_names :
66- target_tables = [table .strip () for table in table_names .split (',' )]
67- # 过滤出实际存在的表
68- target_tables = [table for table in target_tables if table in all_tables ]
69- print (f"Retrieving table metadata for { len (target_tables )} tables..." )
70- for table_name in target_tables :
71- # 获取表注释
72- table_comment = ""
73- try :
74- table_comment = inspector .get_table_comment (table_name ).get ("text" ) or ""
75- except SQLAlchemyError as e :
76- raise ValueError (f"Failed to retrieve table comments: { str (e )} " )
77-
78- table_info = {
79- 'comment' : table_comment ,
80- 'columns' : []
81- }
82-
83- for column in inspector .get_columns (table_name ):
84- # 获取字段注释
85- column_comment = ""
86- try :
87- with engine .connect () as conn :
88- stmt = text (column_comment_sql )
89- column_comment = conn .execute (stmt , {
90- 'table_name' : table_name ,
91- 'column_name' : column ['name' ]
92- }).scalar () or ""
93- except SQLAlchemyError as e :
94- print (f"Warning: failed to get comment for { table_name } .{ column ['name' ]} - { e } " )
95- column_comment = ""
96-
97- table_info ['columns' ].append ({
98- 'name' : column ['name' ],
99- 'comment' : column_comment ,
100- 'type' : str (column ['type' ])
101- })
102-
103- result [table_name ] = table_info
104- return result
105- except SQLAlchemyError as e :
106- raise ValueError (f"Failed to retrieve database table metadata: { str (e )} " )
107- finally :
108- engine .dispose ()
7+ # def get_db_schema(
8+ # db_type: str,
9+ # host: str,
10+ # port: int,
11+ # database: str,
12+ # username: str,
13+ # password: str,
14+ # table_names: str | None = None
15+ # ) -> dict[str, Any] | None:
16+ # """
17+ # 获取数据库表结构信息
18+ # :param db_type: 数据库类型 (mysql/oracle/sqlserver/postgresql)
19+ # :param host: 主机地址
20+ # :param port: 端口号
21+ # :param database: 数据库名
22+ # :param username: 用户名
23+ # :param password: 密码
24+ # :param table_names: 要查询的表名,以逗号分隔的字符串,如果为None则查询所有表
25+ # :return: 包含所有表结构信息的字典
26+ # """
27+ # result: dict[str, Any] = {}
28+ # db_type = 'mssql' if db_type == 'sqlserver' else db_type
29+ # # 构建连接URL
30+ # driver = {
31+ # 'mysql': 'pymysql',
32+ # 'oracle': 'cx_oracle',
33+ # 'mssql': 'pyodbc',
34+ # 'postgresql': 'psycopg2'
35+ # }.get(db_type.lower(), '')
36+ #
37+ # encoded_username = quote_plus(username)
38+ # encoded_password = quote_plus(password)
39+ # # separator = ':' if db_type is 'mssql' else ','
40+ #
41+ # engine = create_engine(f'{db_type.lower()}+{driver}://{encoded_username}:{encoded_password}@{host}{separator}{port}/{database}')
42+ # inspector = inspect(engine)
43+ #
44+ # # 获取字段注释的SQL语句
45+ # column_comment_sql = {
46+ # 'mysql': f"SELECT COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '{database}' AND TABLE_NAME = :table_name AND COLUMN_NAME = :column_name",
47+ # 'oracle': "SELECT COMMENTS FROM ALL_COL_COMMENTS WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name",
48+ # 'mssql': "SELECT CAST(ep.value AS NVARCHAR(MAX)) FROM sys.columns c LEFT JOIN sys.extended_properties ep ON ep.major_id = c.object_id AND ep.minor_id = c.column_id WHERE OBJECT_NAME(c.object_id) = :table_name AND c.name = :column_name",
49+ # 'postgresql': """
50+ # SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::int)
51+ # FROM pg_catalog.pg_class c
52+ # JOIN information_schema.columns cols
53+ # ON c.relname = cols.table_name
54+ # WHERE c.relname = :table_name AND cols.column_name = :column_name
55+ # """
56+ # }.get(db_type.lower(), "")
57+ #
58+ # try:
59+ # # 获取所有表名
60+ # all_tables = inspector.get_table_names()
61+ #
62+ # # 如果指定了table_names,则过滤表名
63+ # target_tables = all_tables
64+ #
65+ # if table_names:
66+ # target_tables = [table.strip() for table in table_names.split(',')]
67+ # # 过滤出实际存在的表
68+ # target_tables = [table for table in target_tables if table in all_tables]
69+ # print(f"Retrieving table metadata for {len(target_tables)} tables...")
70+ # for table_name in target_tables:
71+ # # 获取表注释
72+ # table_comment = ""
73+ # try:
74+ # table_comment = inspector.get_table_comment(table_name).get("text") or ""
75+ # except SQLAlchemyError as e:
76+ # raise ValueError(f"Failed to retrieve table comments: {str(e)}")
77+ #
78+ # table_info = {
79+ # 'comment': table_comment,
80+ # 'columns': []
81+ # }
82+ #
83+ # for column in inspector.get_columns(table_name):
84+ # # 获取字段注释
85+ # column_comment = ""
86+ # try:
87+ # with engine.connect() as conn:
88+ # stmt = text(column_comment_sql)
89+ # column_comment = conn.execute(stmt, {
90+ # 'table_name': table_name,
91+ # 'column_name': column['name']
92+ # }).scalar() or ""
93+ # except SQLAlchemyError as e:
94+ # print(f"Warning: failed to get comment for {table_name}.{column['name']} - {e}")
95+ # column_comment = ""
96+ #
97+ # table_info['columns'].append({
98+ # 'name': column['name'],
99+ # 'comment': column_comment,
100+ # 'type': str(column['type'])
101+ # })
102+ #
103+ # result[table_name] = table_info
104+ # return result
105+ # except SQLAlchemyError as e:
106+ # raise ValueError(f"Failed to retrieve database table metadata: {str(e)}")
107+ # finally:
108+ # engine.dispose()
109109
110110def format_schema_dsl (schema : dict [str , Any ], with_type : bool = True , with_comment : bool = False ) -> str :
111111 """
@@ -168,18 +168,18 @@ def execute_sql(
168168 encoded_username = quote_plus (username )
169169 encoded_password = quote_plus (password )
170170 connect_args = {}
171- driver_extra_info = None
172171 # PostgreSQL 特殊处理
173172 if db_type .lower () == 'postgresql' and schema :
174173 connect_args ['options' ] = f"-c search_path={ schema } "
175174
176- if db_type .lower () == 'sqlserver' :
177- import os
178- driver_extra_info = 'ODBC+Driver+17+for+SQL+Server' if os .name == 'posix' else 'SQL Server'
175+ #if db_type.lower() == 'sqlserver':
176+ # import os
177+ # driver_extra_info = 'ODBC+Driver+17+for+SQL+Server' if os.name == 'posix' else 'SQL Server'
178+ # print(driver_extra_info)
179179 # 构建连接字符串
180180 connection_uri = _build_connection_uri (
181181 db_type , driver , encoded_username , encoded_password ,
182- host , port , database , driver_extra_info
182+ host , port , database
183183 )
184184
185185 try :
@@ -204,7 +204,7 @@ def _get_driver(db_type: str) -> str:
204204 drivers = {
205205 'mysql' : 'pymysql' ,
206206 'oracle' : 'cx_oracle' ,
207- 'sqlserver' : 'pyodbc ' ,
207+ 'sqlserver' : 'pymssql ' ,
208208 'postgresql' : 'psycopg2'
209209 }
210210 return drivers .get (db_type .lower (), '' )
@@ -217,17 +217,21 @@ def _build_connection_uri(
217217 host : str ,
218218 port : int ,
219219 database : str ,
220- driver_info : str
221220) -> str :
222221 """构建数据库连接字符串"""
222+ < << << << HEAD
223223 separator = ':' if db_type != 'sqlserver' else ','
224224 db_type = db_type if db_type != 'sqlserver' else 'mssql'
225225 extra_info = '' if driver_info == None else f'?driver={ driver_info } '
226226
227227 return f"{ db_type } +{ driver } ://{ username } :{ password } @{ host } { separator } { port } /{ database } { extra_info } "
228+ == == == =
229+ db_type = db_type if db_type != 'sqlserver' else 'mssql'
230+ return f"{ db_type } +{ driver } ://{ username } :{ password } @{ host } :{ port } /{ database } "
231+ >> >> >> > debug
228232
229233def _process_result (result_proxy ) -> Union [list [dict ], dict , None ]:
230234 """处理执行结果"""
231235 if result_proxy .returns_rows :
232236 return [dict (row ._mapping ) for row in result_proxy ]
233- return {"rowcount" : result_proxy .rowcount }
237+ return {"rowcount" : result_proxy .rowcount }
0 commit comments