Skip to content

Commit 55fb063

Browse files
committed
feat: improve generate sql template
1. Add the GENERATE_SQL_QUERY_LIMIT_ENABLED setting (default: True), which controls whether generated SQL is limited to returning 1000 rows. 2. Use a different SQL example template for each data source type.
1 parent 67f446e commit 55fb063

File tree

18 files changed

+1086
-93
lines changed

18 files changed

+1086
-93
lines changed

backend/apps/chat/models/chat_model.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import datetime
22
from enum import Enum
3-
from typing import List, Optional
3+
from typing import List, Optional, Union
44

55
from fastapi import Body
66
from pydantic import BaseModel
@@ -9,13 +9,14 @@
99
from sqlalchemy.dialects.postgresql import JSONB
1010
from sqlmodel import SQLModel, Field
1111

12+
from apps.db.constant import DB
1213
from apps.template.filter.generator import get_permissions_template
1314
from apps.template.generate_analysis.generator import get_analysis_template
1415
from apps.template.generate_chart.generator import get_chart_template
1516
from apps.template.generate_dynamic.generator import get_dynamic_template
1617
from apps.template.generate_guess_question.generator import get_guess_question_template
1718
from apps.template.generate_predict.generator import get_predict_template
18-
from apps.template.generate_sql.generator import get_sql_template
19+
from apps.template.generate_sql.generator import get_sql_template, get_sql_example_template
1920
from apps.template.select_datasource.generator import get_datasource_template
2021

2122

@@ -182,10 +183,27 @@ class AiModelQuestion(BaseModel):
182183
custom_prompt: str = ""
183184
error_msg: str = ""
184185

185-
def sql_sys_question(self):
186+
def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True):
    """Build the system prompt used for SQL generation.

    Dialect-specific rules and examples are taken from the SQL example
    template selected by ``db_type``; the surrounding prompt text comes
    from the shared SQL template.

    :param db_type: data source type (a type string or a ``DB`` member)
        used to select the SQL example template.
    :param enable_query_limit: when True, the shared ``query_limit``
        section and the ``*_with_limit`` example answers are inserted;
        when False, the section is empty and the plain example answers
        are used.
    :return: the formatted system prompt string.
    """
    dialect = get_sql_example_template(db_type)

    # The three rule sections are concatenated in a fixed order.
    rules = dialect['quot_rule'] + dialect['limit_rule'] + dialect['other_rule']
    limit_section = get_sql_template()['query_limit'] if enable_query_limit else ''
    examples = dialect['basic_example']
    engine_name = dialect['example_engine']

    # Every example answer exists in two variants; the key suffix picks one.
    variant = '_with_limit' if enable_query_limit else ''
    answers = {
        f'example_answer_{idx}': dialect[f'example_answer_{idx}{variant}']
        for idx in (1, 2, 3)
    }

    return get_sql_template()['system'].format(
        engine=self.engine,
        schema=self.db_schema,
        question=self.question,
        lang=self.lang,
        terminologies=self.terminologies,
        data_training=self.data_training,
        custom_prompt=self.custom_prompt,
        base_sql_rules=rules,
        query_limit=limit_section,
        basic_sql_examples=examples,
        example_engine=engine_name,
        **answers,
    )
189207

190208
def sql_user_question(self, current_time: str):
191209
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,

backend/apps/chat/task/llm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ def init_messages(self):
171171

172172
self.sql_message = []
173173
# add sys prompt
174-
self.sql_message.append(SystemMessage(content=self.chat_question.sql_sys_question()))
174+
self.sql_message.append(SystemMessage(
175+
content=self.chat_question.sql_sys_question(self.ds.type, settings.GENERATE_SQL_QUERY_LIMIT_ENABLED)))
175176
if last_sql_messages is not None and len(last_sql_messages) > 0:
176177
# limit count
177178
for last_sql_message in last_sql_messages[count_limit:]:

backend/apps/db/constant.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,33 @@ def __init__(self, type_name):
1313

1414

1515
class DB(Enum):
    """Supported data source types.

    Each member's value is the tuple
    ``(type, db_name, prefix, suffix, connect_type, template_name)``,
    which ``__init__`` unpacks into attributes of the same names.
    """

    excel = ('excel', 'Excel/CSV', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL')
    redshift = ('redshift', 'AWS Redshift', '"', '"', ConnectType.py_driver, 'AWS_Redshift')
    ck = ('ck', 'ClickHouse', '"', '"', ConnectType.sqlalchemy, 'ClickHouse')
    dm = ('dm', '达梦', '"', '"', ConnectType.py_driver, 'DM')
    doris = ('doris', 'Apache Doris', '`', '`', ConnectType.py_driver, 'Doris')
    es = ('es', 'Elasticsearch', '"', '"', ConnectType.py_driver, 'Elasticsearch')
    kingbase = ('kingbase', 'Kingbase', '"', '"', ConnectType.py_driver, 'Kingbase')
    sqlServer = ('sqlServer', 'Microsoft SQL Server', '[', ']', ConnectType.sqlalchemy, 'Microsoft_SQL_Server')
    mysql = ('mysql', 'MySQL', '`', '`', ConnectType.sqlalchemy, 'MySQL')
    oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy, 'Oracle')
    pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy, 'PostgreSQL')
    starrocks = ('starrocks', 'StarRocks', '`', '`', ConnectType.py_driver, 'StarRocks')

    def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType, template_name: str):
        # `type` is the short key that get_db() matches against.
        self.type = type
        self.db_name = db_name
        # NOTE(review): prefix/suffix look like the identifier quote
        # characters of each dialect -- confirm against their callers.
        self.prefix = prefix
        self.suffix = suffix
        self.connect_type = connect_type
        # Base name of the SQL example template associated with this member.
        self.template_name = template_name

    @classmethod
    def get_db(cls, type, default_if_none=False):
        """Return the member whose ``type`` attribute equals ``type``.

        :param type: type key to look up.
        :param default_if_none: when True, return ``DB.pg`` instead of
            raising if no member matches.
        :raises ValueError: if nothing matches and ``default_if_none`` is
            False.
        """
        found = next((member for member in cls if member.type == type), None)
        if found is not None:
            return found
        if not default_if_none:
            raise ValueError(f"Invalid db type: {type}")
        return DB.pg
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
1-
from apps.template.template import get_base_template
1+
from typing import Union
2+
3+
from apps.db.constant import DB
4+
from apps.template.template import get_base_template, get_sql_template as get_base_sql_template
25

36

47
def get_sql_template():
    """Return the ``template -> sql`` section of the base template."""
    return get_base_template()['template']['sql']
10+
11+
12+
def get_sql_example_template(db_type: Union[str, DB]):
    """Return the ``template`` section of the SQL example template for ``db_type``.

    :param db_type: data source type, passed through unchanged to the
        template loader.
    """
    return get_base_sql_template(db_type)['template']

backend/apps/template/template.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,64 @@
11
import yaml
2+
from pathlib import Path
3+
from functools import cache
4+
from typing import Union
25

3-
base_template = None
6+
from apps.db.constant import DB
47

8+
# 基础路径配置
9+
PROJECT_ROOT = Path(__file__).parent.parent.parent
10+
TEMPLATES_DIR = PROJECT_ROOT / 'templates'
11+
BASE_TEMPLATE_PATH = TEMPLATES_DIR / 'template.yaml'
12+
SQL_TEMPLATES_DIR = TEMPLATES_DIR / 'sql_examples'
513

6-
def load():
7-
with open('./template.yaml', 'r', encoding='utf-8') as f:
8-
global base_template
9-
base_template = yaml.load(f, Loader=yaml.SafeLoader)
14+
15+
@cache
def _load_template_file(file_path: Path):
    """Load and parse a YAML template file.

    Successful results are memoised by ``functools.cache`` keyed on
    ``file_path``: a path is read at most once until the cache is cleared,
    and every caller receives the same parsed object, so callers should
    treat the result as read-only.

    :param file_path: location of the YAML file.
    :return: the document as parsed by ``yaml.safe_load``.
    :raises FileNotFoundError: if ``file_path`` does not exist.
    :raises ValueError: if the file content is not valid YAML.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)
    except FileNotFoundError as err:
        # Chain explicitly so the original OS-level error stays attached.
        raise FileNotFoundError(f"Template file not found at {file_path}") from err
    except yaml.YAMLError as err:
        # Callers see a ValueError; the parser error is kept as the cause.
        raise ValueError(f"Error parsing YAML file {file_path}: {err}") from err
1025

1126

1227
def get_base_template():
    """Return the parsed base template (cached by the underlying loader)."""
    base_path = BASE_TEMPLATE_PATH
    return _load_template_file(base_path)
30+
31+
32+
def get_sql_template(db_type: Union[str, DB]):
    """Return the parsed SQL example template for a data source type.

    :param db_type: a ``DB`` member, or a type string that is resolved with
        ``DB.get_db``. Unknown strings and values of any other kind fall
        back to ``DB.pg``.
    :return: the parsed content of
        ``SQL_TEMPLATES_DIR / "<template_name>.yaml"``.
    """
    if isinstance(db_type, DB):
        resolved = db_type
    elif isinstance(db_type, str):
        # Unknown type strings resolve to the default member, not an error.
        resolved = DB.get_db(db_type, default_if_none=True)
    else:
        resolved = DB.pg

    # The member's template_name is the file name (without extension).
    return _load_template_file(SQL_TEMPLATES_DIR / f"{resolved.template_name}.yaml")
46+
47+
48+
def get_all_sql_templates():
    """Return the SQL example templates of all supported data sources.

    :return: dict mapping each member's ``type`` to its parsed template.
        Members whose template file does not exist are left out.
    """
    loaded = {}
    for member in DB:
        try:
            member_template = get_sql_template(member)
        except FileNotFoundError:
            # No template file for this data source: skip it.
            continue
        loaded[member.type] = member_template
    return loaded
58+
59+
60+
def reload_all_templates():
    """Clear the template cache so later lookups load the files again."""
    _load_template_file.cache_clear()
63+
64+

backend/common/core/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
9696
EMBEDDING_TERMINOLOGY_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT
9797
EMBEDDING_DATA_TRAINING_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT
9898

99+
GENERATE_SQL_QUERY_LIMIT_ENABLED: bool = True
100+
99101
PARSE_REASONING_BLOCK_ENABLED: bool = True
100102
DEFAULT_REASONING_CONTENT_START: str = '<think>'
101103
DEFAULT_REASONING_CONTENT_END: str = '</think>'
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
template:
2+
quot_rule: |
3+
<rule>
4+
必须对数据库名、表名、字段名、别名外层加双引号(")。
5+
<note>
6+
1. 点号(.)不能包含在引号内,必须写成 "schema"."table"
7+
2. 即使标识符不含特殊字符或非关键字,也需强制加双引号
8+
3. Redshift 默认将未加引号的标识符转为小写
9+
</note>
10+
</rule>
11+
12+
limit_rule: |
13+
<rule>
14+
使用 LIMIT 限制行数(Redshift 不支持 FETCH FIRST ... ROWS ONLY 语法)
15+
<note>
16+
1. 标准写法:LIMIT 100
17+
2. 分页写法:LIMIT 100 OFFSET 0
18+
</note>
19+
</rule>
20+
21+
other_rule: |
22+
<rule>必须为每个表生成别名(不加AS)</rule>
23+
<rule>禁止使用星号(*),必须明确字段名</rule>
24+
<rule>中文/特殊字符字段需保留原名并添加英文别名</rule>
25+
<rule>函数字段必须加别名</rule>
26+
<rule>百分比字段保留两位小数并以%结尾(使用ROUND+CONCAT)</rule>
27+
<rule>避免与Redshift关键字冲突(如USER/GROUP/ORDER等)</rule>
28+
29+
basic_example: |
30+
<basic-examples>
31+
<intro>
32+
📌 以下示例严格遵循<Rules>中的 AWS Redshift 规范,展示符合要求的 SQL 写法与典型错误案例。
33+
⚠️ 注意:示例中的表名、字段名均为演示虚构,实际使用时需替换为用户提供的真实标识符。
34+
🔍 重点观察:
35+
1. 双引号包裹所有数据库对象的规范用法
36+
2. 中英别名/百分比/函数等特殊字段的处理
37+
3. 关键字冲突的规避方式
38+
</intro>
39+
<example>
40+
<input>查询 TEST.SALES 表的前100条订单(含百分比计算)</input>
41+
<output-bad>
42+
SELECT * FROM TEST.SALES LIMIT 100 -- 错误:未加引号、使用星号
43+
SELECT "订单ID", "金额" FROM "TEST"."SALES" "t1" FETCH FIRST 100 ROWS ONLY -- 错误:缺少英文别名
44+
SELECT COUNT("订单ID") FROM "TEST"."SALES" "t1" -- 错误:函数未加别名
45+
</output-bad>
46+
<output-good>
47+
SELECT
48+
"t1"."订单ID" AS "order_id",
49+
"t1"."金额" AS "amount",
50+
COUNT("t1"."订单ID") OVER () AS "total_orders",
51+
CONCAT(ROUND("t1"."折扣率" * 100, 2), '%') AS "discount_percent"
52+
FROM "TEST"."SALES" "t1"
53+
LIMIT 100
54+
</output-good>
55+
</example>
56+
57+
<example>
58+
<input>统计用户表 PUBLIC.USERS(含关键字字段user)的活跃占比</input>
59+
<output-bad>
60+
SELECT user, status FROM PUBLIC.USERS -- 错误:未处理关键字和引号
61+
SELECT "user", ROUND(active_ratio) FROM "PUBLIC"."USERS" -- 错误:百分比格式错误
62+
</output-bad>
63+
<output-good>
64+
SELECT
65+
"u"."user" AS "user_account",
66+
CONCAT(ROUND("u"."active_ratio" * 100, 2), '%') AS "active_percent"
67+
FROM "PUBLIC"."USERS" "u"
68+
WHERE "u"."status" = 1
69+
LIMIT 1000
70+
</output-good>
71+
</example>
72+
</basic-examples>
73+
74+
example_engine: AWS Redshift 1.0
75+
example_answer_1: |
76+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\"","tables":["sample_country_gdp"],"chart-type":"line"}
77+
example_answer_1_with_limit: |
78+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\" LIMIT 1000 OFFSET 0","tables":["sample_country_gdp"],"chart-type":"line"}
79+
example_answer_2: |
80+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC","tables":["sample_country_gdp"],"chart-type":"pie"}
81+
example_answer_2_with_limit: |
82+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC LIMIT 1000 OFFSET 0","tables":["sample_country_gdp"],"chart-type":"pie"}
83+
example_answer_3: |
84+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国'","tables":["sample_country_gdp"],"chart-type":"table"}
85+
example_answer_3_with_limit: |
86+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国' LIMIT 1000 OFFSET 0","tables":["sample_country_gdp"],"chart-type":"table"}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
template:
2+
quot_rule: |
3+
<rule>
4+
必须对数据库名、表名、字段名、别名外层加双引号(")。
5+
<note>
6+
1. 点号(.)不能包含在引号内,必须写成 "database"."table"
7+
2. ClickHouse 严格区分大小写,必须通过引号保留原始大小写
8+
3. 嵌套字段使用点号连接:`"json_column.field"`
9+
</note>
10+
</rule>
11+
12+
limit_rule: |
13+
<rule>
14+
行数限制使用标准SQL语法:
15+
<note>
16+
1. 标准写法:LIMIT [count]
17+
2. 分页写法:LIMIT [count] OFFSET [start]
18+
3. 禁止使用原生 `topk()` 等函数替代
19+
</note>
20+
</rule>
21+
22+
other_rule: |
23+
<rule>必须为每个表生成简短别名(如t1/t2)</rule>
24+
<rule>禁止使用星号(*),必须明确字段名</rule>
25+
<rule>JSON字段需用点号语法访问:`"column.field"`</rule>
26+
<rule>函数字段必须加别名</rule>
27+
<rule>百分比显示为:`ROUND(x*100,2) || '%'`</rule>
28+
<rule>避免与ClickHouse关键字冲突(如`timestamp`/`default`)</rule>
29+
30+
basic_example: |
31+
<basic-examples>
32+
<intro>
33+
📌 以下示例严格遵循<Rules>中的 ClickHouse 规范,展示符合要求的 SQL 写法与典型错误案例。
34+
⚠️ 注意:示例中的表名、字段名均为演示虚构,实际使用时需替换为用户提供的真实标识符。
35+
🔍 重点观察:
36+
1. 双引号包裹所有数据库对象的规范用法
37+
2. 中英别名/百分比/函数等特殊字段的处理
38+
3. 关键字冲突的规避方式
39+
</intro>
40+
<example>
41+
<input>查询 events 表的前100条错误日志(含JSON字段)</input>
42+
<output-bad>
43+
SELECT * FROM default.events LIMIT 100 -- 错误1:使用星号
44+
SELECT message FROM "default"."events" WHERE level = 'error' -- 错误2:未处理JSON字段
45+
SELECT "message", "extra.error_code" FROM events LIMIT 100 -- 错误3:表名未加引号
46+
</output-bad>
47+
<output-good>
48+
SELECT
49+
"e"."message" AS "log_content",
50+
"e"."extra"."error_code" AS "error_id",
51+
toDateTime("e"."timestamp") AS "log_time"
52+
FROM "default"."events" "e"
53+
WHERE "e"."level" = 'error'
54+
LIMIT 100
55+
</output-good>
56+
</example>
57+
58+
<example>
59+
<input>统计各地区的错误率Top 5(含百分比)</input>
60+
<output-bad>
61+
SELECT region, COUNT(*) FROM events GROUP BY region -- 错误1:未加引号、函数未加别名
62+
SELECT "region", MAX("count") FROM "events" GROUP BY 1 -- 错误2:使用序号分组
63+
</output-bad>
64+
<output-good>
65+
SELECT
66+
"e"."region" AS "area",
67+
COUNT(*) AS "total",
68+
COUNTIf("e"."level" = 'error') AS "error_count",
69+
ROUND(error_count * 100.0 / total, 2) || '%' AS "error_rate"
70+
FROM "default"."events" "e"
71+
GROUP BY "e"."region"
72+
ORDER BY "error_rate" DESC
73+
LIMIT 5
74+
</output-good>
75+
</example>
76+
</basic-examples>
77+
78+
example_engine: ClickHouse 23.3
79+
example_answer_1: |
80+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\"","tables":["sample_country_gdp"],"chart-type":"line"}
81+
example_answer_1_with_limit: |
82+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\" LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"line"}
83+
example_answer_2: |
84+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC","tables":["sample_country_gdp"],"chart-type":"pie"}
85+
example_answer_2_with_limit: |
86+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"pie"}
87+
example_answer_3: |
88+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国'","tables":["sample_country_gdp"],"chart-type":"table"}
89+
example_answer_3_with_limit: |
90+
{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp_usd\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国' LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"table"}

0 commit comments

Comments
 (0)