Skip to content

Commit ac226e0

Browse files
committed
feat:测试支持达梦
1 parent 05a509c commit ac226e0

File tree

10 files changed

+249
-8
lines changed

10 files changed

+249
-8
lines changed

database_schema/factory.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
SQLServerInspector,
55
PostgreSQLInspector,
66
OracleInspector,
7-
GaussDBInspector
7+
GaussDBInspector,
8+
DMInspector
89
)
910

1011
class InspectorFactory:
@@ -17,7 +18,8 @@ def create_inspector(db_type: str, **kwargs) -> object:
1718
'sqlserver': SQLServerInspector,
1819
'postgresql': PostgreSQLInspector,
1920
'oracle': OracleInspector,
20-
'gaussdb': GaussDBInspector
21+
'gaussdb': GaussDBInspector,
22+
'dm': DMInspector
2123
}
2224

2325
if db_type not in mapping:

database_schema/inspectors/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
from .postgresql import PostgreSQLInspector
55
from .oracle import OracleInspector
66
from .gaussdb import GaussDBInspector
7+
from .dm import DMInspector
78

89
__all__ = [
910
'MySQLInspector',
1011
'SQLServerInspector',
1112
'PostgreSQLInspector',
1213
'OracleInspector',
13-
'GaussDBInspector'
14+
'GaussDBInspector',
15+
'DMInspector'
1416
]

database_schema/inspectors/dm.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# database_schema/inspectors/dm.py
2+
from sqlalchemy.sql import text
3+
from sqlalchemy.engine import reflection
4+
from .base import BaseInspector
5+
from urllib.parse import quote_plus
6+
7+
class DMInspector(BaseInspector):
8+
"""达梦数据库(DM Database)元数据获取实现
9+
10+
达梦数据库特点:
11+
1. 国产数据库,部分兼容 Oracle 语法
12+
2. 使用 dmPython 驱动
13+
3. 默认端口:5236
14+
4. 支持 Schema 概念,类似 Oracle
15+
5. 表名和列名默认大写
16+
"""
17+
18+
def __init__(self, host: str, port: int, database: str,
19+
username: str, password: str, schema_name: str = None, **kwargs):
20+
super().__init__(host, port, database, username, password, schema_name)
21+
# 达梦 schema 通常与用户名一致,且默认大写
22+
# 如果提供了 schema_name,使用它;否则使用用户名
23+
self.schema_name = (schema_name or username).upper()
24+
25+
def build_conn_str(self, host: str, port: int, database: str,
26+
username: str, password: str) -> str:
27+
"""构建达梦数据库连接字符串
28+
29+
达梦连接格式:dm+dmPython://username:password@host:port/?schema=SCHEMANAME
30+
"""
31+
encoded_username = quote_plus(username)
32+
encoded_password = quote_plus(password)
33+
34+
# 达梦数据库连接字符串
35+
# 注意:达梦的连接方式类似 Oracle
36+
return f"dm+dmPython://{encoded_username}:{encoded_password}@{host}:{port}/"
37+
38+
def get_table_names(self, inspector: reflection.Inspector) -> list[str]:
39+
"""获取指定 schema 下的所有表名"""
40+
return inspector.get_table_names(schema=self.schema_name)
41+
42+
def get_table_comment(self, inspector: reflection.Inspector,
43+
table_name: str) -> str:
44+
"""获取表注释
45+
46+
达梦使用类似 Oracle 的系统表结构
47+
"""
48+
with self.engine.connect() as conn:
49+
sql = text("""
50+
SELECT COMMENTS
51+
FROM ALL_TAB_COMMENTS
52+
WHERE OWNER = :owner
53+
AND TABLE_NAME = :table_name
54+
""")
55+
try:
56+
result = conn.execute(sql, {
57+
'owner': self.schema_name,
58+
'table_name': table_name.upper()
59+
}).scalar()
60+
return result or ""
61+
except Exception as e:
62+
print(f"获取表注释失败 {table_name}: {str(e)}")
63+
return ""
64+
65+
def get_column_comment(self, inspector: reflection.Inspector,
66+
table_name: str, column_name: str) -> str:
67+
"""获取列注释"""
68+
with self.engine.connect() as conn:
69+
sql = text("""
70+
SELECT COMMENTS
71+
FROM ALL_COL_COMMENTS
72+
WHERE OWNER = :owner
73+
AND TABLE_NAME = :table_name
74+
AND COLUMN_NAME = :column_name
75+
""")
76+
try:
77+
result = conn.execute(sql, {
78+
'owner': self.schema_name,
79+
'table_name': table_name.upper(),
80+
'column_name': column_name.upper()
81+
}).scalar()
82+
return result or ""
83+
except Exception as e:
84+
print(f"获取列注释失败 {table_name}.{column_name}: {str(e)}")
85+
return ""
86+
87+
def normalize_type(self, raw_type: str) -> str:
88+
"""标准化达梦数据类型
89+
90+
达梦支持多种数据类型,部分兼容 Oracle
91+
"""
92+
# 类型映射表
93+
type_map = {
94+
# 数值类型
95+
'NUMBER': 'NUMERIC',
96+
'NUMERIC': 'NUMERIC',
97+
'DECIMAL': 'DECIMAL',
98+
'INTEGER': 'INTEGER',
99+
'INT': 'INTEGER',
100+
'BIGINT': 'BIGINT',
101+
'SMALLINT': 'SMALLINT',
102+
'TINYINT': 'TINYINT',
103+
'FLOAT': 'FLOAT',
104+
'DOUBLE': 'DOUBLE',
105+
'REAL': 'FLOAT',
106+
107+
# 字符串类型
108+
'VARCHAR': 'VARCHAR',
109+
'VARCHAR2': 'VARCHAR',
110+
'CHAR': 'CHAR',
111+
'CHARACTER': 'CHAR',
112+
'TEXT': 'TEXT',
113+
'CLOB': 'TEXT',
114+
'NCHAR': 'NCHAR',
115+
'NVARCHAR': 'NVARCHAR',
116+
'NVARCHAR2': 'NVARCHAR',
117+
118+
# 日期时间类型
119+
'DATE': 'DATE',
120+
'TIME': 'TIME',
121+
'TIMESTAMP': 'TIMESTAMP',
122+
'DATETIME': 'DATETIME',
123+
124+
# 二进制类型
125+
'BLOB': 'BLOB',
126+
'BINARY': 'BINARY',
127+
'VARBINARY': 'VARBINARY',
128+
'IMAGE': 'BLOB',
129+
130+
# 其他类型
131+
'BIT': 'BOOLEAN',
132+
'BOOLEAN': 'BOOLEAN'
133+
}
134+
135+
# 提取基础类型(去除长度、精度等)
136+
base_type = raw_type.split('(')[0].strip().upper()
137+
138+
# 返回标准化类型
139+
return type_map.get(base_type, base_type)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
{# prompt_templates/sql_generation/dm_prompt.jinja #}
2+
{% extends "base_prompt.jinja" %}
3+
4+
{% block optimization_rules %}
5+
## 达梦数据库(DM)优化原则:
6+
1. **索引策略**
7+
- 对高频查询字段创建 B-Tree 索引
8+
- 复合索引遵循最左前缀原则
9+
- 使用 `CREATE INDEX` 语句创建索引:
10+
```sql
11+
CREATE INDEX idx_table_column ON table_name (column_name);
12+
```
13+
14+
2. **查询优化**
15+
- 使用 `ROWNUM` 进行分页(类似 Oracle)
16+
- 优先使用 EXISTS 替代 IN 子查询
17+
- 避免在 WHERE 子句中对字段进行函数操作
18+
- 使用 CTE (WITH 子句) 优化复杂查询
19+
20+
3. **性能验证**
21+
- 使用 `EXPLAIN` 查看执行计划
22+
- 确保索引被正确使用
23+
- 关注全表扫描(TABLE ACCESS FULL)
24+
25+
4. **数据类型规范**
26+
- 字符串类型优先使用 `VARCHAR2` 或 `VARCHAR`
27+
- 数值类型根据精度需求选择 `NUMBER`, `INTEGER`, `BIGINT`
28+
- 日期时间类型使用 `DATE`, `TIMESTAMP`, `DATETIME`
29+
- 大对象使用 `CLOB`(文本)或 `BLOB`(二进制)
30+
31+
5. **达梦特有特性**
32+
- 部分兼容 Oracle 语法和函数
33+
- 支持 Schema 概念,表名和列名默认大写
34+
- 支持序列(SEQUENCE)和触发器
35+
- 默认端口:5236
36+
{% endblock %}
37+
38+
{% block validation_rules %}
39+
## 验证机制:
40+
1. **元数据验证**
41+
```sql
42+
-- 表存在性检查
43+
SELECT COUNT(*) FROM ALL_TABLES
44+
WHERE OWNER = 'YOUR_SCHEMA' AND TABLE_NAME = 'YOUR_TABLE';
45+
46+
-- 字段存在性检查
47+
SELECT COUNT(*) FROM ALL_TAB_COLUMNS
48+
WHERE OWNER = 'YOUR_SCHEMA'
49+
AND TABLE_NAME = 'YOUR_TABLE'
50+
AND COLUMN_NAME = 'YOUR_COLUMN';
51+
```
52+
53+
2. **执行计划验证**
54+
```sql
55+
EXPLAIN
56+
SELECT ... -- 生成的查询语句
57+
```
58+
59+
3. **安全规范**
60+
- 禁止使用 SELECT *,必须显式指定字段
61+
- 所有字符串比较使用参数化值
62+
- 结果集必须使用 ROWNUM 限制行数(例如:WHERE ROWNUM <= {{ limit }})
63+
- 严格验证所有表名和字段名均在提供的元数据中
64+
{% endblock %}
65+
66+
{% block example_section %}
67+
输出示例:
68+
SELECT
69+
o.ORDER_ID AS "订单编号",
70+
c.NAME AS "客户名称",
71+
o.AMOUNT AS "订单金额",
72+
o.ORDER_DATE AS "订单日期"
73+
FROM
74+
ORDERS o
75+
INNER JOIN
76+
CUSTOMERS c ON o.CUSTOMER_ID = c.ID
77+
WHERE
78+
c.REGION = 'Asia'
79+
AND c.ACTIVE = 1
80+
AND ROWNUM <= {{ limit }}
81+
ORDER BY
82+
o.ORDER_DATE DESC;
83+
{% endblock %}

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pyodbc>=4.0.39 # SQL Server新驱动
77
pymysql>=1.1.1 # MySQL驱动
88
psycopg2-binary>=2.9.10 # PostgreSQL驱动 (也用于GaussDB)
99
oracledb>=2.0.0 # Oracle驱动 (推荐使用 python-oracledb,替代 cx_Oracle)
10+
dmPython>=2.3.0 # 达梦数据库驱动
1011

1112
# 安全相关
1213
cryptography>=41.0.0,<43.0.0 # 使用范围版本以提高平台兼容性

rookie_text2data.difypkg

3.62 KB
Binary file not shown.

tools/rookie_excute_sql.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ parameters:
4242
en_US: GaussDB
4343
zh_Hans: 华为高斯数据库
4444
value: gaussdb
45+
- label:
46+
en_US: DM Database
47+
zh_Hans: 达梦数据库
48+
value: dm
4549
- name: host
4650
type: string
4751
required: true
@@ -115,7 +119,7 @@ parameters:
115119
form: llm
116120
label:
117121
en_US: Schema name
118-
zh_Hans: 数据库Schema PGSQL用户选填,默认为public
122+
zh_Hans: 数据库Schema(PGSQL默认为public,Oracle/DM默认为用户名大写)
119123
pt_BR: Schema name
120124
human_description:
121125
en_US: Schema name

tools/rookie_text2data.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ parameters:
4646
en_US: GaussDB
4747
zh_Hans: 华为高斯数据库
4848
value: gaussdb
49+
- label:
50+
en_US: DM Database
51+
zh_Hans: 达梦数据库
52+
value: dm
4953
- name: limit
5054
type: number
5155
required: false
@@ -144,7 +148,7 @@ parameters:
144148
form: llm
145149
label:
146150
en_US: Schema name
147-
zh_Hans: 数据库Schema PGSQL用户选填,默认为public
151+
zh_Hans: 数据库Schema(PGSQL默认为public,Oracle/DM默认为用户名大写)
148152
pt_BR: Schema name
149153
human_description:
150154
en_US: Schema name

utils/alchemy_db_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ def execute_sql(
234234
# 显式设置schema(部分数据库需要)
235235
if db_type.lower() in ('postgresql', 'gaussdb') and schema:
236236
conn.execute(text(f"SET search_path TO {schema}"))
237+
elif db_type.lower() in ('oracle', 'dm') and schema:
238+
# Oracle 和达梦数据库使用 ALTER SESSION 设置当前 schema
239+
conn.execute(text(f"ALTER SESSION SET CURRENT_SCHEMA = {schema}"))
237240

238241
result_proxy = conn.execute(text(sql), params)
239242
result = _process_result(result_proxy)
@@ -260,7 +263,8 @@ def _get_driver(db_type: str) -> str:
260263
'oracle': 'oracledb', # 使用新的 python-oracledb 驱动
261264
'sqlserver': 'pymssql',
262265
'postgresql': 'psycopg2',
263-
'gaussdb': 'psycopg2' # GaussDB 使用 psycopg2 驱动
266+
'gaussdb': 'psycopg2', # GaussDB 使用 psycopg2 驱动
267+
'dm': 'dmPython' # 达梦数据库使用 dmPython 驱动
264268
}
265269
return drivers.get(db_type.lower(), '')
266270

utils/prompt_loader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def _get_limit_clause(self, db_type: str) -> str:
3737
'oracle': "ROWNUM <= n",
3838
'sqlserver': "TOP n",
3939
'postgresql': "FETCH FIRST n ROWS ONLY",
40-
'gaussdb': "FETCH FIRST n ROWS ONLY"
40+
'gaussdb': "FETCH FIRST n ROWS ONLY",
41+
'dm': "ROWNUM <= n" # 达梦使用类似 Oracle 的 ROWNUM
4142
}
4243
return clauses.get(db_type.lower(), "LIMIT 100")
4344

@@ -47,7 +48,8 @@ def _get_optimization_rules(self, db_type: str) -> str:
4748
'oracle': "- 使用索引组织表(IOT)\n- 检查执行计划的COST值",
4849
'sqlserver': "- 使用INCLUDE索引策略\n- 查看实际执行计划",
4950
'postgresql': "- 使用INCLUDE索引列\n- 分析EXPLAIN ANALYZE结果",
50-
'gaussdb': "- 使用INCLUDE索引列\n- 分析EXPLAIN PERFORMANCE结果\n- 利用列存储特性优化宽表查询"
51+
'gaussdb': "- 使用INCLUDE索引列\n- 分析EXPLAIN PERFORMANCE结果\n- 利用列存储特性优化宽表查询",
52+
'dm': "- 使用EXPLAIN查看执行计划\n- 避免全表扫描\n- 利用ROWNUM进行分页"
5153
}
5254
return rules.get(db_type.lower(), "")
5355

0 commit comments

Comments
 (0)