Skip to content

Commit c2ac82b

Browse files
authored
Merge pull request #91 from jaguarliuu/dev/jagaurliu
Dev/jagaurliu
2 parents 20e91cd + ac226e0 commit c2ac82b

File tree

14 files changed

+704
-19
lines changed

14 files changed

+704
-19
lines changed

CLAUDE.md

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# CLAUDE.md
2+
3+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4+
5+
## Project Overview
6+
7+
**rookie_text2data** is a Dify plugin that converts natural language queries into secure, optimized SQL statements. It supports MySQL, PostgreSQL, Oracle, SQL Server, and GaussDB databases with built-in security mechanisms to prevent data leaks and SQL injection.
8+
9+
## Development Commands
10+
11+
### Running the Plugin
12+
```bash
13+
# Start the plugin with extended timeout (120s)
14+
python main.py
15+
```
16+
17+
### Testing
18+
```bash
19+
# Run test file
20+
python _test/test.py
21+
22+
# Test prompt loading
23+
python utils/prompt_loader.py
24+
```
25+
26+
### Dependencies
27+
```bash
28+
# Install dependencies
29+
pip install -r requirements.txt
30+
```
31+
32+
## Architecture Overview
33+
34+
### Core Components
35+
36+
1. **Two-Tool Architecture**
37+
- `rookie_text2data`: Generates SQL from natural language using LLM
38+
- `rookie_excute_sql`: Executes generated SQL with security validation
39+
40+
2. **Database Schema Inspection System** (`database_schema/`)
41+
- **Factory Pattern**: `InspectorFactory` creates database-specific inspectors
42+
- **Base Inspector**: `BaseInspector` (abstract) defines interface for metadata extraction
43+
- **Database-Specific Inspectors**: MySQL, PostgreSQL, Oracle, SQL Server, GaussDB implementations
44+
- **Connector**: `get_db_schema()` orchestrates schema extraction
45+
- **Formatter**: `format_schema_dsl()` compresses schema into LLM-friendly DSL format
46+
47+
3. **Prompt Template System** (`prompt_templates/sql_generation/`)
48+
- Jinja2-based templates per database type
49+
- `PromptLoader` injects database-specific syntax rules
50+
- Templates enforce security rules (SELECT-only, result limits, field validation)
51+
52+
4. **Database Client** (`utils/alchemy_db_client.py`)
53+
- SQLAlchemy-based execution with engine caching
54+
- Connection pooling with lifecycle tracking
55+
- Automatic cleanup via `atexit` handler
56+
57+
### Security Architecture
58+
59+
The plugin implements defense-in-depth:
60+
61+
- **SQL Generation Phase** (tools/rookie_text2data.py:13-66):
62+
- Schema whitelist validation via metadata inspection
63+
- LLM prompt enforces SELECT-only statements
64+
- Automatic LIMIT clause injection (default: 100 rows)
65+
- Database-specific syntax validation
66+
67+
- **Execution Phase** (tools/rookie_excute_sql.py:13-201):
68+
- SQL injection detection via `_contains_risk_commands()` (line 165)
69+
- Blocks DML operations: DROP, DELETE, TRUNCATE, ALTER, UPDATE, INSERT
70+
- Parameterized query execution
71+
- Empty result handling
72+
73+
### Schema DSL Format
74+
75+
The system uses compressed DSL to reduce LLM token usage:
76+
77+
```
78+
T:table_name(field1:i, field2:s, field3:dt)
79+
```
80+
81+
Type abbreviations:
82+
- `i` = INTEGER, INT, BIGINT, SMALLINT, TINYINT
83+
- `s` = VARCHAR, TEXT, CHAR, NVARCHAR, NCHAR
84+
- `dt` = DATETIME, DATE, TIMESTAMP, TIME
85+
- `f` = DECIMAL, NUMERIC, FLOAT, DOUBLE
86+
- `b` = BOOLEAN, BOOL
87+
- `j` = JSON, JSONB
88+
89+
### Database-Specific Handling
90+
91+
Each inspector handles database quirks:
92+
93+
- **MySQL**: Schema equals database name; uses `information_schema.COLUMNS` for comments
94+
- **PostgreSQL**: Supports custom schemas (default: `public`); uses `pg_catalog` for metadata
95+
- **Oracle**: Uses `ROWNUM` for limits; type normalization for Oracle-specific types
96+
- **SQL Server**: Schema defaults to `dbo`; uses `TOP n` syntax
97+
- **GaussDB**: Compatible with PostgreSQL protocol; uses special connection args (`gssencmode=disable`) to bypass SASL authentication; supports both Oracle and PostgreSQL compatibility modes
98+
99+
### Connection Management
100+
101+
The `alchemy_db_client.py` implements:
102+
- **Engine Caching**: Reuses connections via `_ENGINE_CACHE` keyed by `db_type://user@host:port/database/schema`
103+
- **Connection Pooling**: SQLAlchemy pool with pre-ping and 1-hour recycle
104+
- **Lifecycle Tracking**: Logs connection acquire/release, pool status, and uptime
105+
- **Automatic Cleanup**: `dispose_all_engines()` runs at program exit
106+
107+
Key functions:
108+
- `get_or_create_engine()`: Retrieves cached engine or creates new one
109+
- `execute_sql()`: Executes SQL with schema support, logging, and error handling
110+
- `log_connection_status()`: Tracks active connection count and pool state
111+
112+
## Key Implementation Details
113+
114+
### Natural Language → SQL Flow
115+
116+
1. User provides natural language query + database credentials
117+
2. `RookieText2dataTool._invoke()` extracts schema metadata via `get_db_schema()`
118+
3. Schema formatted to DSL via `format_schema_dsl()`
119+
4. `PromptLoader` builds database-specific system prompt
120+
5. LLM generates SQL (model configurable, Qwen-max recommended)
121+
6. SQL returned in JSON or text format
122+
123+
### SQL Execution Flow
124+
125+
1. `RookieExecuteSqlTool._invoke()` validates parameters
126+
2. `_contains_risk_commands()` blocks dangerous operations
127+
3. `execute_sql()` acquires connection from cache/pool
128+
4. For PostgreSQL: sets `search_path` to target schema
129+
5. Executes via SQLAlchemy's `text()` with parameterization
130+
6. Results formatted as JSON/CSV/HTML/text
131+
7. Connection returned to pool
132+
133+
### Prompt Template Customization
134+
135+
Templates located in `prompt_templates/sql_generation/`:
136+
- `base_prompt.jinja`: Common rules for all databases
137+
- `{db_type}_prompt.jinja`: Database-specific overrides
138+
139+
Context variables injected by `PromptLoader.get_prompt()`:
140+
- `db_type`: Database type (MySQL, PostgreSQL, etc.)
141+
- `meta_data`: Compressed DSL schema
142+
- `limit_clause`: Database-specific syntax (LIMIT n, TOP n, ROWNUM <= n, FETCH FIRST n ROWS ONLY)
143+
- `optimization_rules`: Database-specific performance tips
144+
- `user_custom_prompt`: User-provided custom instructions
145+
- `limit`: Maximum result rows (default: 100)
146+
147+
## Common Development Patterns
148+
149+
### Adding a New Database Type
150+
151+
1. Create inspector in `database_schema/inspectors/{dbtype}.py`:
152+
- Extend `BaseInspector`
153+
- Implement `build_conn_str()`, `get_table_names()`, `get_table_comment()`, `get_column_comment()`, `normalize_type()`
154+
2. Register in `database_schema/factory.py` mapping
155+
3. Add driver to `requirements.txt` and `_get_driver()` in `alchemy_db_client.py`
156+
4. Create prompt template: `prompt_templates/sql_generation/{dbtype}_prompt.jinja`
157+
5. Update `PromptLoader._get_limit_clause()` and `_get_optimization_rules()`
158+
6. Add to manifest.yaml if needed
159+
160+
### Modifying Security Rules
161+
162+
- SQL injection patterns: Edit `RookieExecuteSqlTool.RISK_KEYWORDS` (line 14)
163+
- Schema validation: Modify `get_db_schema()` in `database_schema/connector.py`
164+
- LLM constraints: Update prompt templates in `prompt_templates/sql_generation/`
165+
166+
### Extending Result Formats
167+
168+
Add format handler in `tools/rookie_excute_sql.py`:
169+
1. Add format to `SUPPORTED_FORMATS` (line 15)
170+
2. Implement `_handle_{format}()` method
171+
3. Update `_handle_result_format()` dispatch logic
172+
173+
## Database Connection Details
174+
175+
### Connection String Formats
176+
177+
Built by `BaseInspector.build_conn_str()` implementations:
178+
- **MySQL**: `mysql+pymysql://{user}:{pass}@{host}:{port}/{db}?charset=utf8mb4`
179+
- **PostgreSQL**: `postgresql+psycopg2://{user}:{pass}@{host}:{port}/{db}`
180+
- **SQL Server**: `mssql+pymssql://{user}:{pass}@{host}:{port}/{db}`
181+
- **Oracle**: `oracle+oracledb://{user}:{pass}@{host}:{port}/{db}` (uses python-oracledb, the modern replacement for cx_Oracle)
182+
- **GaussDB**: `postgresql+psycopg2://{user}:{pass}@{host}:{port}/{db}?sslmode=disable&gssencmode=disable` (uses PostgreSQL protocol with special auth params)
183+
184+
Passwords are URL-encoded via `urllib.parse.quote_plus()`.
185+
186+
### Schema Handling
187+
188+
- **MySQL**: `schema_name` forced to equal `database` (line 13 in mysql.py)
189+
- **PostgreSQL**: Defaults to `public`, customizable via constructor (line 13 in postgresql.py)
190+
- **SQL Server**: Defaults to `dbo` (line 49 in rookie_excute_sql.py)
191+
- **Oracle**: Schema-aware via `schema_name` parameter
192+
- **GaussDB**: Defaults to `public`, same as PostgreSQL (line 13 in gaussdb.py)
193+
194+
## Plugin Metadata
195+
196+
Defined in `manifest.yaml`:
197+
- Entry point: `main.py`
198+
- Python version: 3.12
199+
- Memory limit: 256MB (268435456 bytes)
200+
- Storage: 1MB
201+
- Permissions: tools, LLM models, app integration
202+
- Architectures: amd64, arm64
203+
204+
## Testing Notes
205+
206+
- The `_test/test.py` file contains integration tests
207+
- Prompt templates can be tested via `utils/prompt_loader.py`
208+
- Test different database types by varying `db_type` parameter
209+
- Deep-thinking LLM models are NOT supported (use Qwen-max, DeepSeek V3, ChatGLM-6B, etc.)

database_schema/factory.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
MySQLInspector,
44
SQLServerInspector,
55
PostgreSQLInspector,
6-
OracleInspector
6+
OracleInspector,
7+
GaussDBInspector,
8+
DMInspector
79
)
810

911
class InspectorFactory:
@@ -15,10 +17,12 @@ def create_inspector(db_type: str, **kwargs) -> object:
1517
'mysql': MySQLInspector,
1618
'sqlserver': SQLServerInspector,
1719
'postgresql': PostgreSQLInspector,
18-
'oracle': OracleInspector
20+
'oracle': OracleInspector,
21+
'gaussdb': GaussDBInspector,
22+
'dm': DMInspector
1923
}
20-
24+
2125
if db_type not in mapping:
2226
raise ValueError(f"Unsupported database type: {db_type}")
23-
27+
2428
return mapping[db_type](**kwargs)

database_schema/inspectors/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
from .sqlserver import SQLServerInspector
44
from .postgresql import PostgreSQLInspector
55
from .oracle import OracleInspector
6+
from .gaussdb import GaussDBInspector
7+
from .dm import DMInspector
68

79
__all__ = [
810
'MySQLInspector',
911
'SQLServerInspector',
1012
'PostgreSQLInspector',
11-
'OracleInspector'
13+
'OracleInspector',
14+
'GaussDBInspector',
15+
'DMInspector'
1216
]

database_schema/inspectors/dm.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# database_schema/inspectors/dm.py
2+
from sqlalchemy.sql import text
3+
from sqlalchemy.engine import reflection
4+
from .base import BaseInspector
5+
from urllib.parse import quote_plus
6+
7+
class DMInspector(BaseInspector):
8+
"""达梦数据库(DM Database)元数据获取实现
9+
10+
达梦数据库特点:
11+
1. 国产数据库,部分兼容 Oracle 语法
12+
2. 使用 dmPython 驱动
13+
3. 默认端口:5236
14+
4. 支持 Schema 概念,类似 Oracle
15+
5. 表名和列名默认大写
16+
"""
17+
18+
def __init__(self, host: str, port: int, database: str,
19+
username: str, password: str, schema_name: str = None, **kwargs):
20+
super().__init__(host, port, database, username, password, schema_name)
21+
# 达梦 schema 通常与用户名一致,且默认大写
22+
# 如果提供了 schema_name,使用它;否则使用用户名
23+
self.schema_name = (schema_name or username).upper()
24+
25+
def build_conn_str(self, host: str, port: int, database: str,
26+
username: str, password: str) -> str:
27+
"""构建达梦数据库连接字符串
28+
29+
达梦连接格式:dm+dmPython://username:password@host:port/?schema=SCHEMANAME
30+
"""
31+
encoded_username = quote_plus(username)
32+
encoded_password = quote_plus(password)
33+
34+
# 达梦数据库连接字符串
35+
# 注意:达梦的连接方式类似 Oracle
36+
return f"dm+dmPython://{encoded_username}:{encoded_password}@{host}:{port}/"
37+
38+
def get_table_names(self, inspector: reflection.Inspector) -> list[str]:
39+
"""获取指定 schema 下的所有表名"""
40+
return inspector.get_table_names(schema=self.schema_name)
41+
42+
def get_table_comment(self, inspector: reflection.Inspector,
43+
table_name: str) -> str:
44+
"""获取表注释
45+
46+
达梦使用类似 Oracle 的系统表结构
47+
"""
48+
with self.engine.connect() as conn:
49+
sql = text("""
50+
SELECT COMMENTS
51+
FROM ALL_TAB_COMMENTS
52+
WHERE OWNER = :owner
53+
AND TABLE_NAME = :table_name
54+
""")
55+
try:
56+
result = conn.execute(sql, {
57+
'owner': self.schema_name,
58+
'table_name': table_name.upper()
59+
}).scalar()
60+
return result or ""
61+
except Exception as e:
62+
print(f"获取表注释失败 {table_name}: {str(e)}")
63+
return ""
64+
65+
def get_column_comment(self, inspector: reflection.Inspector,
66+
table_name: str, column_name: str) -> str:
67+
"""获取列注释"""
68+
with self.engine.connect() as conn:
69+
sql = text("""
70+
SELECT COMMENTS
71+
FROM ALL_COL_COMMENTS
72+
WHERE OWNER = :owner
73+
AND TABLE_NAME = :table_name
74+
AND COLUMN_NAME = :column_name
75+
""")
76+
try:
77+
result = conn.execute(sql, {
78+
'owner': self.schema_name,
79+
'table_name': table_name.upper(),
80+
'column_name': column_name.upper()
81+
}).scalar()
82+
return result or ""
83+
except Exception as e:
84+
print(f"获取列注释失败 {table_name}.{column_name}: {str(e)}")
85+
return ""
86+
87+
def normalize_type(self, raw_type: str) -> str:
88+
"""标准化达梦数据类型
89+
90+
达梦支持多种数据类型,部分兼容 Oracle
91+
"""
92+
# 类型映射表
93+
type_map = {
94+
# 数值类型
95+
'NUMBER': 'NUMERIC',
96+
'NUMERIC': 'NUMERIC',
97+
'DECIMAL': 'DECIMAL',
98+
'INTEGER': 'INTEGER',
99+
'INT': 'INTEGER',
100+
'BIGINT': 'BIGINT',
101+
'SMALLINT': 'SMALLINT',
102+
'TINYINT': 'TINYINT',
103+
'FLOAT': 'FLOAT',
104+
'DOUBLE': 'DOUBLE',
105+
'REAL': 'FLOAT',
106+
107+
# 字符串类型
108+
'VARCHAR': 'VARCHAR',
109+
'VARCHAR2': 'VARCHAR',
110+
'CHAR': 'CHAR',
111+
'CHARACTER': 'CHAR',
112+
'TEXT': 'TEXT',
113+
'CLOB': 'TEXT',
114+
'NCHAR': 'NCHAR',
115+
'NVARCHAR': 'NVARCHAR',
116+
'NVARCHAR2': 'NVARCHAR',
117+
118+
# 日期时间类型
119+
'DATE': 'DATE',
120+
'TIME': 'TIME',
121+
'TIMESTAMP': 'TIMESTAMP',
122+
'DATETIME': 'DATETIME',
123+
124+
# 二进制类型
125+
'BLOB': 'BLOB',
126+
'BINARY': 'BINARY',
127+
'VARBINARY': 'VARBINARY',
128+
'IMAGE': 'BLOB',
129+
130+
# 其他类型
131+
'BIT': 'BOOLEAN',
132+
'BOOLEAN': 'BOOLEAN'
133+
}
134+
135+
# 提取基础类型(去除长度、精度等)
136+
base_type = raw_type.split('(')[0].strip().upper()
137+
138+
# 返回标准化类型
139+
return type_map.get(base_type, base_type)

0 commit comments

Comments
 (0)