Skip to content

Commit 75bd13a

Browse files
committed
feat(X-Pack): add custom prompt
#213
1 parent 1b25652 commit 75bd13a

File tree

10 files changed

+123
-22
lines changed

10 files changed

+123
-22
lines changed

backend/alembic/env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
# from apps.settings.models.setting_models import SQLModel
2727
# from apps.chat.models.chat_model import SQLModel
2828
from apps.terminology.models.terminology_model import SQLModel
29+
#from apps.custom_prompt.models.custom_prompt_model import SQLModel
2930
# from apps.data_training.models.data_training_model import SQLModel
3031
# from apps.dashboard.models.dashboard_model import SQLModel
3132
from common.core.config import settings # noqa
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""046_add_custom_prompt
2+
3+
Revision ID: 8855aea2dd61
4+
Revises: 45e7e52bf2b8
5+
Create Date: 2025-09-28 13:57:01.509249
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '8855aea2dd61'
15+
down_revision = '45e7e52bf2b8'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.create_table('custom_prompt',
23+
sa.Column('id', sa.BigInteger(), sa.Identity(always=True), nullable=False),
24+
sa.Column('oid', sa.BigInteger(), nullable=True),
25+
sa.Column('type', sa.Enum('GENERATE_SQL', 'ANALYSIS', 'PREDICT_DATA', name='customprompttypeenum', native_enum=False, length=20), nullable=True),
26+
sa.Column('create_time', sa.DateTime(), nullable=True),
27+
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True),
28+
sa.Column('prompt', sa.Text(), nullable=True),
29+
sa.Column('specific_ds', sa.Boolean(), nullable=True),
30+
sa.Column('datasource_ids', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
31+
sa.PrimaryKeyConstraint('id')
32+
)
33+
# ### end Alembic commands ###
34+
35+
36+
def downgrade():
37+
# ### commands auto generated by Alembic - please adjust! ###
38+
op.drop_table('custom_prompt')
39+
# ### end Alembic commands ###

backend/apps/chat/models/chat_model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@ class OperationEnum(Enum):
4040
CHOOSE_DATASOURCE = '6'
4141
GENERATE_DYNAMIC_SQL = '7'
4242

43+
4344
class ChatFinishStep(Enum):
4445
GENERATE_SQL = 1
4546
QUERY_DATA = 2
4647
GENERATE_CHART = 3
4748

49+
4850
# TODO choose table / check connection / generate description
4951

5052
class ChatLog(SQLModel, table=True):
@@ -177,12 +179,13 @@ class AiModelQuestion(BaseModel):
177179
sub_query: Optional[list[dict]] = None
178180
terminologies: str = ""
179181
data_training: str = ""
182+
custom_prompt: str = ""
180183
error_msg: str = ""
181184

182185
def sql_sys_question(self):
183186
return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question,
184187
lang=self.lang, terminologies=self.terminologies,
185-
data_training=self.data_training)
188+
data_training=self.data_training, custom_prompt=self.custom_prompt)
186189

187190
def sql_user_question(self, current_time: str):
188191
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
@@ -196,13 +199,14 @@ def chart_user_question(self, chart_type: Optional[str] = None):
196199
chart_type=chart_type)
197200

198201
def analysis_sys_question(self):
199-
return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies)
202+
return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies,
203+
custom_prompt=self.custom_prompt)
200204

201205
def analysis_user_question(self):
202206
return get_analysis_template()['user'].format(fields=self.fields, data=self.data)
203207

204208
def predict_sys_question(self):
205-
return get_predict_template()['system'].format(lang=self.lang)
209+
return get_predict_template()['system'].format(lang=self.lang, custom_prompt=self.custom_prompt)
206210

207211
def predict_user_question(self):
208212
return get_predict_template()['user'].format(fields=self.fields, data=self.data)

backend/apps/chat/task/llm.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
get_last_execute_sql_error
3131
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
3232
ChatFinishStep
33+
from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil
34+
from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts
35+
from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum
3336
from apps.data_training.curd.data_training import get_training_template
3437
from apps.datasource.crud.datasource import get_table_schema
3538
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
@@ -244,6 +247,9 @@ def generate_analysis(self):
244247
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
245248
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
246249
self.current_user.oid, ds_id)
250+
if SQLBotLicenseUtil.valid():
251+
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.ANALYSIS,
252+
self.current_user.oid, ds_id)
247253

248254
analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question()))
249255
analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question()))
@@ -288,6 +294,12 @@ def generate_predict(self):
288294
self.chat_question.fields = orjson.dumps(fields).decode()
289295
data = get_chat_chart_data(self.session, self.record.id)
290296
self.chat_question.data = orjson.dumps(data.get('data')).decode()
297+
298+
if SQLBotLicenseUtil.valid():
299+
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
300+
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.PREDICT_DATA,
301+
self.current_user.oid, ds_id)
302+
291303
predict_msg: List[Union[BaseMessage, dict[str, Any]]] = []
292304
predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question()))
293305
predict_msg.append(HumanMessage(content=self.chat_question.predict_user_question()))
@@ -509,6 +521,9 @@ def select_datasource(self):
509521
ds_id)
510522
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, ds_id,
511523
oid)
524+
if SQLBotLicenseUtil.valid():
525+
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL,
526+
oid, ds_id)
512527

513528
self.init_messages()
514529

@@ -902,6 +917,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
902917
oid, ds_id)
903918
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question,
904919
ds_id, oid)
920+
if SQLBotLicenseUtil.valid():
921+
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL,
922+
oid, ds_id)
905923

906924
self.init_messages()
907925

backend/apps/terminology/curd/terminology.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
from xml.dom.minidom import parseString
66

77
import dicttoxml
8-
from sqlalchemy import BigInteger
9-
from sqlalchemy import and_, or_, select, func, delete, update, union
10-
from sqlalchemy import text
8+
from sqlalchemy import and_, or_, select, func, delete, update, union, text, BigInteger
119
from sqlalchemy.orm import aliased
1210
from sqlalchemy.orm.session import Session
1311

backend/locales/en.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@
4747
"datasource_cannot_be_none": "Datasource cannot be none",
4848
"data_training_not_exists": "Example does not exists",
4949
"exists_in_db": "Question exists"
50+
},
51+
"i18n_custom_prompt": {
52+
"exists_in_db": "Prompt name exists",
53+
"not_exists": "Prompt does not exists"
5054
},
5155
"i18n_excel_export": {
5256
"data_is_empty": "The form data is empty, cannot export data"

backend/locales/zh-CN.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
"data_training_not_exists": "该示例不存在",
4949
"exists_in_db": "该问题已存在"
5050
},
51+
"i18n_custom_prompt": {
52+
"exists_in_db": "模版名称已存在",
53+
"not_exists": "该模版不存在"
54+
},
5155
"i18n_excel_export": {
5256
"data_is_empty": "表单数据为空,无法导出数据"
5357
}

backend/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ dependencies = [
3939
"pyyaml (>=6.0.2,<7.0.0)",
4040
"fastapi-mcp (>=0.3.4,<0.4.0)",
4141
"tabulate>=0.9.0",
42-
"sqlbot-xpack>=0.0.3.31,<1.0.0",
42+
"sqlbot-xpack>=0.0.3.36,<1.0.0",
4343
"fastapi-cache2>=0.2.2",
4444
"sqlparse>=0.5.3",
4545
"redis>=6.2.0",

backend/template.yaml

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ template:
1414
<Info>内有<db-engine><m-schema><terminologies>等信息;
1515
其中,<db-engine>:提供数据库引擎及版本信息;
1616
<m-schema>:以 M-Schema 格式提供数据库表结构信息;
17-
<terminologies>:提供一组术语,块内每一个<terminology>就是术语,其中同一个<words>内的多个<word>代表术语的多种叫法,也就是术语与它的同义词,<description>即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件
18-
<sql-examples>:提供一组SQL示例,你可以参考这些示例来生成你的回答,其中<question>内是提问,<suggestion-answer>内是对于该<question>提问的解释或者对应应该回答的SQL示例
17+
<terminologies>:提供一组术语,块内每一个<terminology>就是术语,其中同一个<words>内的多个<word>代表术语的多种叫法,也就是术语与它的同义词,<description>即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件;
18+
<sql-examples>:提供一组SQL示例,你可以参考这些示例来生成你的回答,其中<question>内是提问,<suggestion-answer>内是对于该<question>提问的解释或者对应应该回答的SQL示例。
19+
若有<Other-Infos>块,它会提供一组<content>,可能会是额外添加的背景信息,或者是额外的生成SQL的要求,请结合额外信息或要求后生成你的回答。
1920
用户的提问在<user-question>内,<error-msg>内则会提供上次执行你提供的SQL时会出现的错误信息,<background-infos>内的<current-time>会告诉你用户当前提问的时间
2021
</Instruction>
2122
@@ -219,7 +220,6 @@ template:
219220
</chat-examples>
220221
</example>
221222
222-
### 下面是提供的信息
223223
<Info>
224224
<db-engine> {engine} </db-engine>
225225
<m-schema>
@@ -229,6 +229,7 @@ template:
229229
{terminologies}
230230
{data_training}
231231
</Info>
232+
{custom_prompt}
232233
233234
### 响应, 请根据上述要求直接返回JSON结果:
234235
```json
@@ -394,7 +395,11 @@ template:
394395
你当前的任务是根据给定的数据分析数据,并给出你的分析结果。
395396
我们会在<Info>块内提供给你信息,帮助你进行分析:
396397
<Info>内有<terminologies>等信息;
397-
<terminologies>:提供一组术语,块内每一个<terminology>就是术语,其中同一个<words>内的多个<word>代表术语的多种叫法,也就是术语与它的同义词,<description>即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件
398+
<terminologies>:提供一组术语,块内每一个<terminology>就是术语,其中同一个<words>内的多个<word>代表术语的多种叫法,也就是术语与它的同义词,<description>即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件。
399+
若有<Other-Infos>块,它会提供一组<content>,可能会是额外添加的背景信息,或者是额外的分析要求,请结合额外信息或要求后生成你的回答。
400+
用户会在提问中提供给你信息:
401+
<data>块内是提供给你的数,以JSON格式给出;
402+
<fields>块内提供给你对应的字段或字段别名。
398403
</Instruction>
399404
400405
你必须遵守以下规则:
@@ -404,32 +409,60 @@ template:
404409
</rule>
405410
</Rules>
406411
407-
### 下面是提供的信息
408412
<Info>
409413
{terminologies}
410414
</Info>
415+
{custom_prompt}
411416
user: |
412-
### 字段(字段别名):
417+
<fields>
413418
{fields}
419+
</fields>
414420
415-
### 数据:
421+
<data>
416422
{data}
423+
</data>
417424
predict:
418425
system: |
419-
### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出
426+
<Instruction>
427+
你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。
428+
你当前的任务是根据给定的数据进行数据预测,并给出你的预测结果。
429+
若有<Other-Infos>块,它会提供一组<content>,可能会是额外添加的背景信息,或者是额外的分析要求,请结合额外信息或要求后生成你的回答。
430+
用户会在提问中提供给你信息:
431+
<data>块内是提供给你的数据,以JSON格式给出;
432+
<fields>块内提供给你对应的字段或字段别名。
433+
</Instruction>
420434
421-
### 说明:
422-
你是一个数据分析师,你的任务是根据给定的数据进行数据预测,我将以JSON格式给你一组数据,你帮我预测之后的数据(一段可以展示趋势的数据,至少2个周期),用json数组的格式返回,返回的格式需要与传入的数据格式保持一致。
435+
你必须遵守以下规则:
436+
<Rules>
437+
<rule>
438+
请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出
439+
</rule>
440+
<rule>
441+
预测的数据是一段可以展示趋势的数据,至少2个周期
442+
</rule>
443+
<rule>
444+
返回的预测数据必须与用户提供的数据同样的格式,使用JSON数组的形式返回
445+
</rule>
446+
<rule>
447+
无法预测或者不支持预测的数据请直接返回(不需要返回JSON格式):"抱歉,该数据无法进行预测。"(若有原因,则额外返回无法预测的原因)
448+
</rule>
449+
<rule>
450+
预测的数据不需要返回用户提供的原有数据,请直接返回你预测的部份
451+
</rule>
452+
</Rules>
453+
{custom_prompt}
454+
455+
### 响应, 请根据上述要求直接返回JSON结果:
423456
```json
424457
425-
无法预测或者不支持预测的数据请直接返回(不需要返回JSON格式,需要翻译为 {lang} 输出):"抱歉,该数据无法进行预测。(有原因则返回无法预测的原因)"
426-
如果可以预测,则不需要返回原有数据,直接返回预测的部份
427458
user: |
428-
### 字段(字段别名):
459+
<fields>
429460
{fields}
461+
</fields>
430462
431-
### 数据:
463+
<data>
432464
{data}
465+
</data>
433466
datasource:
434467
system: |
435468
### 请使用语言:{lang} 回答

frontend/src/views/system/prompt/index.vue

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ const search = () => {
202202
oldKeywords.value = keywords.value
203203
promptApi
204204
.getList(pageInfo.currentPage, pageInfo.pageSize, currentType.value, {
205-
question: keywords.value,
205+
name: keywords.value,
206206
})
207207
.then((res: any) => {
208208
toggleRowLoading.value = true

0 commit comments

Comments
 (0)