Skip to content

Commit a3d2d66

Browse files
authored
Merge pull request #54 from Qiao-gengle/main
fix bug
2 parents bc7e08c + fb3ba4e commit a3d2d66

File tree

6 files changed

+129
-128
lines changed

6 files changed

+129
-128
lines changed

database_schema/connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,4 @@ def get_db_schema(
7979
return None
8080
finally:
8181
if engine:
82-
engine.dispose()
82+
engine.dispose()

database_schema/inspectors/sqlserver.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,21 @@
44
from urllib.parse import quote_plus
55

66
class SQLServerInspector(BaseInspector):
7-
"""SQLServer元数据获取实现"""
8-
9-
def __init__(self, host: str, port: int, database: str,
10-
username: str, password: str, schema_name: str = None, **kwargs):
11-
# 在SQL Server中,schema和database是不同的概念
12-
# 如果未指定schema,默认使用"dbo"
13-
schema_name = schema_name or "dbo"
14-
super().__init__(host, port, database, username, password, schema_name)
7+
"""SQL Server元数据获取实现"""
8+
9+
def __init__(self, host, port, database, username, password, schema_name = None, **kwargs):
10+
super().__init__(host, port, database, username, password, schema_name, **kwargs)
11+
self.schema_name = schema_name if schema_name != None else 'dbo'
1512

1613
def build_conn_str(self, host: str, port: int, database: str,
1714
username: str, password: str) -> str:
18-
import os
19-
driver = 'ODBC+Driver+17+for+SQL+Server' if os.name == 'posix' else 'SQL Server'
15+
16+
# import os
17+
# driver = 'ODBC+Driver+17+for+SQL+Server' if os.name == 'posix' else 'SQL Server'
2018
# driver = 'ODBC+Driver+17+for+SQL+Server'
2119
return (
22-
f"mssql+pyodbc://{quote_plus(username)}:{quote_plus(password)}"
23-
f"@{host},{port}/{database}?"
24-
f"driver={driver}"
20+
f"mssql+pymssql://{quote_plus(username)}:{quote_plus(password)}"
21+
f"@{host}:{port}/{database}"
2522
)
2623

2724
def get_table_names(self, inspector: reflection.Inspector) -> list[str]:

provider/rookie_text2data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33

44
class RookieText2dataProvider(ToolProvider):
55
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
6-
pass
6+
pass

tools/rookie_excute_sql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _validate_and_prepare_params(self, params: dict) -> dict:
4646

4747
if self._contains_risk_commands(params['sql']):
4848
raise ValueError("SQL语句包含危险操作")
49-
49+
params['schema'] = params.get('schema')if params.get('schema') != None else 'dbo' if params['db_type'] == 'sqlserver' else 'public'
5050
# 数据库执行参数
5151
execute_params = {
5252
'db_type': params['db_type'],
@@ -197,4 +197,4 @@ def _safe_serialize(self, data: Any) -> Any:
197197
"""安全的数据序列化"""
198198
return json.loads(
199199
json.dumps(data, default=self._custom_serializer, ensure_ascii=False)
200-
)
200+
)

tools/rookie_text2data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag
2323
username=tool_parameters['username'],
2424
password=tool_parameters['password'],
2525
table_names=tool_parameters['table_names'],
26-
schema_name=tool_parameters.get('schema_name', 'public')
26+
schema_name=tool_parameters.get('schema_name')
2727
)
2828
with_comment = tool_parameters.get('with_comment', False)
2929
dsl_text = format_schema_dsl(meta_data, with_type=True, with_comment=with_comment)

utils/alchemy_db_client.py

Lines changed: 114 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -4,108 +4,108 @@
44
from urllib.parse import quote_plus # 用于对URL进行编码
55
from typing import Any, Optional, Union
66

7-
def get_db_schema(
8-
db_type: str,
9-
host: str,
10-
port: int,
11-
database: str,
12-
username: str,
13-
password: str,
14-
table_names: str | None = None
15-
) -> dict[str, Any] | None:
16-
"""
17-
获取数据库表结构信息
18-
:param db_type: 数据库类型 (mysql/oracle/sqlserver/postgresql)
19-
:param host: 主机地址
20-
:param port: 端口号
21-
:param database: 数据库名
22-
:param username: 用户名
23-
:param password: 密码
24-
:param table_names: 要查询的表名,以逗号分隔的字符串,如果为None则查询所有表
25-
:return: 包含所有表结构信息的字典
26-
"""
27-
result: dict[str, Any] = {}
28-
db_type = 'mssql' if db_type == 'sqlserver' else db_type
29-
# 构建连接URL
30-
driver = {
31-
'mysql': 'pymysql',
32-
'oracle': 'cx_oracle',
33-
'mssql': 'pyodbc',
34-
'postgresql': 'psycopg2'
35-
}.get(db_type.lower(), '')
36-
37-
encoded_username = quote_plus(username)
38-
encoded_password = quote_plus(password)
39-
# separator = ':' if db_type is 'mssql' else ','
40-
41-
engine = create_engine(f'{db_type.lower()}+{driver}://{encoded_username}:{encoded_password}@{host}{separator}{port}/{database}')
42-
inspector = inspect(engine)
43-
44-
# 获取字段注释的SQL语句
45-
column_comment_sql = {
46-
'mysql': f"SELECT COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '{database}' AND TABLE_NAME = :table_name AND COLUMN_NAME = :column_name",
47-
'oracle': "SELECT COMMENTS FROM ALL_COL_COMMENTS WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name",
48-
'mssql': "SELECT CAST(ep.value AS NVARCHAR(MAX)) FROM sys.columns c LEFT JOIN sys.extended_properties ep ON ep.major_id = c.object_id AND ep.minor_id = c.column_id WHERE OBJECT_NAME(c.object_id) = :table_name AND c.name = :column_name",
49-
'postgresql': """
50-
SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::int)
51-
FROM pg_catalog.pg_class c
52-
JOIN information_schema.columns cols
53-
ON c.relname = cols.table_name
54-
WHERE c.relname = :table_name AND cols.column_name = :column_name
55-
"""
56-
}.get(db_type.lower(), "")
57-
58-
try:
59-
# 获取所有表名
60-
all_tables = inspector.get_table_names()
61-
62-
# 如果指定了table_names,则过滤表名
63-
target_tables = all_tables
64-
65-
if table_names:
66-
target_tables = [table.strip() for table in table_names.split(',')]
67-
# 过滤出实际存在的表
68-
target_tables = [table for table in target_tables if table in all_tables]
69-
print(f"Retrieving table metadata for {len(target_tables)} tables...")
70-
for table_name in target_tables:
71-
# 获取表注释
72-
table_comment = ""
73-
try:
74-
table_comment = inspector.get_table_comment(table_name).get("text") or ""
75-
except SQLAlchemyError as e:
76-
raise ValueError(f"Failed to retrieve table comments: {str(e)}")
77-
78-
table_info = {
79-
'comment': table_comment,
80-
'columns': []
81-
}
82-
83-
for column in inspector.get_columns(table_name):
84-
# 获取字段注释
85-
column_comment = ""
86-
try:
87-
with engine.connect() as conn:
88-
stmt = text(column_comment_sql)
89-
column_comment = conn.execute(stmt, {
90-
'table_name': table_name,
91-
'column_name': column['name']
92-
}).scalar() or ""
93-
except SQLAlchemyError as e:
94-
print(f"Warning: failed to get comment for {table_name}.{column['name']} - {e}")
95-
column_comment = ""
96-
97-
table_info['columns'].append({
98-
'name': column['name'],
99-
'comment': column_comment,
100-
'type': str(column['type'])
101-
})
102-
103-
result[table_name] = table_info
104-
return result
105-
except SQLAlchemyError as e:
106-
raise ValueError(f"Failed to retrieve database table metadata: {str(e)}")
107-
finally:
108-
engine.dispose()
7+
#def get_db_schema(
8+
# db_type: str,
9+
# host: str,
10+
# port: int,
11+
# database: str,
12+
# username: str,
13+
# password: str,
14+
# table_names: str | None = None
15+
#) -> dict[str, Any] | None:
16+
# """
17+
# 获取数据库表结构信息
18+
# :param db_type: 数据库类型 (mysql/oracle/sqlserver/postgresql)
19+
# :param host: 主机地址
20+
# :param port: 端口号
21+
# :param database: 数据库名
22+
# :param username: 用户名
23+
# :param password: 密码
24+
# :param table_names: 要查询的表名,以逗号分隔的字符串,如果为None则查询所有表
25+
# :return: 包含所有表结构信息的字典
26+
# """
27+
# result: dict[str, Any] = {}
28+
# db_type = 'mssql' if db_type == 'sqlserver' else db_type
29+
# # 构建连接URL
30+
# driver = {
31+
# 'mysql': 'pymysql',
32+
# 'oracle': 'cx_oracle',
33+
# 'mssql': 'pyodbc',
34+
# 'postgresql': 'psycopg2'
35+
# }.get(db_type.lower(), '')
36+
#
37+
# encoded_username = quote_plus(username)
38+
# encoded_password = quote_plus(password)
39+
# # separator = ':' if db_type is 'mssql' else ','
40+
#
41+
# engine = create_engine(f'{db_type.lower()}+{driver}://{encoded_username}:{encoded_password}@{host}{separator}{port}/{database}')
42+
# inspector = inspect(engine)
43+
#
44+
# # 获取字段注释的SQL语句
45+
# column_comment_sql = {
46+
# 'mysql': f"SELECT COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '{database}' AND TABLE_NAME = :table_name AND COLUMN_NAME = :column_name",
47+
# 'oracle': "SELECT COMMENTS FROM ALL_COL_COMMENTS WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name",
48+
# 'mssql': "SELECT CAST(ep.value AS NVARCHAR(MAX)) FROM sys.columns c LEFT JOIN sys.extended_properties ep ON ep.major_id = c.object_id AND ep.minor_id = c.column_id WHERE OBJECT_NAME(c.object_id) = :table_name AND c.name = :column_name",
49+
# 'postgresql': """
50+
# SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::int)
51+
# FROM pg_catalog.pg_class c
52+
# JOIN information_schema.columns cols
53+
# ON c.relname = cols.table_name
54+
# WHERE c.relname = :table_name AND cols.column_name = :column_name
55+
# """
56+
# }.get(db_type.lower(), "")
57+
#
58+
# try:
59+
# # 获取所有表名
60+
# all_tables = inspector.get_table_names()
61+
#
62+
# # 如果指定了table_names,则过滤表名
63+
# target_tables = all_tables
64+
#
65+
# if table_names:
66+
# target_tables = [table.strip() for table in table_names.split(',')]
67+
# # 过滤出实际存在的表
68+
# target_tables = [table for table in target_tables if table in all_tables]
69+
# print(f"Retrieving table metadata for {len(target_tables)} tables...")
70+
# for table_name in target_tables:
71+
# # 获取表注释
72+
# table_comment = ""
73+
# try:
74+
# table_comment = inspector.get_table_comment(table_name).get("text") or ""
75+
# except SQLAlchemyError as e:
76+
# raise ValueError(f"Failed to retrieve table comments: {str(e)}")
77+
#
78+
# table_info = {
79+
# 'comment': table_comment,
80+
# 'columns': []
81+
# }
82+
#
83+
# for column in inspector.get_columns(table_name):
84+
# # 获取字段注释
85+
# column_comment = ""
86+
# try:
87+
# with engine.connect() as conn:
88+
# stmt = text(column_comment_sql)
89+
# column_comment = conn.execute(stmt, {
90+
# 'table_name': table_name,
91+
# 'column_name': column['name']
92+
# }).scalar() or ""
93+
# except SQLAlchemyError as e:
94+
# print(f"Warning: failed to get comment for {table_name}.{column['name']} - {e}")
95+
# column_comment = ""
96+
#
97+
# table_info['columns'].append({
98+
# 'name': column['name'],
99+
# 'comment': column_comment,
100+
# 'type': str(column['type'])
101+
# })
102+
#
103+
# result[table_name] = table_info
104+
# return result
105+
# except SQLAlchemyError as e:
106+
# raise ValueError(f"Failed to retrieve database table metadata: {str(e)}")
107+
# finally:
108+
# engine.dispose()
109109

110110
def format_schema_dsl(schema: dict[str, Any], with_type: bool = True, with_comment: bool = False) -> str:
111111
"""
@@ -168,18 +168,18 @@ def execute_sql(
168168
encoded_username = quote_plus(username)
169169
encoded_password = quote_plus(password)
170170
connect_args = {}
171-
driver_extra_info = None
172171
# PostgreSQL 特殊处理
173172
if db_type.lower() == 'postgresql' and schema:
174173
connect_args['options'] = f"-c search_path={schema}"
175174

176-
if db_type.lower() == 'sqlserver':
177-
import os
178-
driver_extra_info = 'ODBC+Driver+17+for+SQL+Server' if os.name == 'posix' else 'SQL Server'
175+
#if db_type.lower() == 'sqlserver':
176+
# import os
177+
# driver_extra_info = 'ODBC+Driver+17+for+SQL+Server' if os.name == 'posix' else 'SQL Server'
178+
# print(driver_extra_info)
179179
# 构建连接字符串
180180
connection_uri = _build_connection_uri(
181181
db_type, driver, encoded_username, encoded_password,
182-
host, port, database, driver_extra_info
182+
host, port, database
183183
)
184184

185185
try:
@@ -204,7 +204,7 @@ def _get_driver(db_type: str) -> str:
204204
drivers = {
205205
'mysql': 'pymysql',
206206
'oracle': 'cx_oracle',
207-
'sqlserver': 'pyodbc',
207+
'sqlserver': 'pymssql',
208208
'postgresql': 'psycopg2'
209209
}
210210
return drivers.get(db_type.lower(), '')
@@ -217,17 +217,21 @@ def _build_connection_uri(
217217
host: str,
218218
port: int,
219219
database: str,
220-
driver_info: str
221220
) -> str:
222221
"""构建数据库连接字符串"""
222+
<<<<<<< HEAD
223223
separator = ':' if db_type != 'sqlserver' else ','
224224
db_type = db_type if db_type != 'sqlserver' else 'mssql'
225225
extra_info = '' if driver_info == None else f'?driver={driver_info}'
226226

227227
return f"{db_type}+{driver}://{username}:{password}@{host}{separator}{port}/{database}{extra_info}"
228+
=======
229+
db_type = db_type if db_type != 'sqlserver' else 'mssql'
230+
return f"{db_type}+{driver}://{username}:{password}@{host}:{port}/{database}"
231+
>>>>>>> debug
228232

229233
def _process_result(result_proxy) -> Union[list[dict], dict, None]:
230234
"""处理执行结果"""
231235
if result_proxy.returns_rows:
232236
return [dict(row._mapping) for row in result_proxy]
233-
return {"rowcount": result_proxy.rowcount}
237+
return {"rowcount": result_proxy.rowcount}

0 commit comments

Comments
 (0)