Skip to content

Commit e881b3e

Browse files
authored
feat:The generation of dialogue titles is intelligently generated by LLM (#494)
1 parent c7693fc commit e881b3e

File tree

3 files changed

+46
-17
lines changed

3 files changed

+46
-17
lines changed

backend/apps/chat/models/chat_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,9 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T
210210
example_answer_2=_example_answer_2,
211211
example_answer_3=_example_answer_3)
212212

213-
def sql_user_question(self, current_time: str):
213+
def sql_user_question(self, current_time: str, change_title: bool):
214214
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
215-
rule=self.rule, current_time=current_time, error_msg=self.error_msg)
215+
rule=self.rule, current_time=current_time, error_msg=self.error_msg,change_title = change_title)
216216

217217
def chart_sys_question(self):
218218
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)

backend/apps/chat/task/llm.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ def select_datasource(self, _session: Session):
524524
def generate_sql(self, _session: Session):
525525
# append current question
526526
self.sql_message.append(HumanMessage(
527-
self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))))
527+
self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),change_title = self.change_title)))
528528

529529
self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=_session,
530530
ai_modal_id=self.chat_question.ai_modal_id,
@@ -756,6 +756,26 @@ def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
756756

757757
return chart_type
758758

759+
@staticmethod
760+
def get_brief_from_sql_answer(res: str) -> Optional[str]:
761+
json_str = extract_nested_json(res)
762+
if json_str is None:
763+
return None
764+
765+
brief: Optional[str]
766+
data: dict
767+
try:
768+
data = orjson.loads(json_str)
769+
770+
if data['success']:
771+
brief = data['brief']
772+
else:
773+
return None
774+
except Exception:
775+
return None
776+
777+
return brief
778+
759779
def check_save_sql(self, session: Session, res: str) -> str:
760780
sql, *_ = self.check_sql(res=res)
761781
save_sql(session=session, sql=sql, record_id=self.record.id)
@@ -925,17 +945,6 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
925945
if not stream:
926946
json_result['record_id'] = self.get_record().id
927947

928-
# return title
929-
if self.change_title:
930-
if self.chat_question.question and self.chat_question.question.strip() != '':
931-
brief = rename_chat(session=_session,
932-
rename_object=RenameChat(id=self.get_record().chat_id,
933-
brief=self.chat_question.question.strip()[:20]))
934-
if in_chat:
935-
yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n'
936-
if not stream:
937-
json_result['title'] = brief
938-
939948
# select datasource if datasource is none
940949
if not self.ds:
941950
ds_res = self.select_datasource(_session)
@@ -981,6 +990,19 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
981990

982991
chart_type = self.get_chart_type_from_sql_answer(full_sql_text)
983992

993+
# return title
994+
if self.change_title:
995+
llm_brief = self.get_brief_from_sql_answer(full_sql_text)
996+
if (llm_brief and llm_brief != '') or (self.chat_question.question and self.chat_question.question.strip() != ''):
997+
save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[:20]
998+
brief = rename_chat(session=_session,
999+
rename_object=RenameChat(id=self.get_record().chat_id,
1000+
brief=save_brief))
1001+
if in_chat:
1002+
yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n'
1003+
if not stream:
1004+
json_result['title'] = brief
1005+
9841006
use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type in dynamic_ds_types
9851007
is_page_embedded: bool = self.current_assistant and self.current_assistant.type == 4
9861008
dynamic_sql_result = None

backend/templates/template.yaml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ template:
1414
<step>4. 应用其他规则(引号、别名等)</step>
1515
<step>5. <strong>强制检查:检查语法是否正确?</strong></step>
1616
<step>6. 确定图表类型</step>
17-
<step>7. 返回JSON结果</step>
17+
<step>7. 确定对话标题</step>
18+
<step>8. 返回JSON结果</step>
1819
</SQL-Generation-Process>
1920
query_limit: |
2021
<rule priority="critical" id="data-limit-policy">
@@ -41,7 +42,7 @@ template:
4142
system: |
4243
<Instruction>
4344
你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。
44-
你当前的任务是根据给定的表结构和用户问题生成SQL语句、可能适合展示的图表类型以及该SQL中所用到的表名。
45+
你当前的任务是根据给定的表结构和用户问题生成SQL语句、对话标题、可能适合展示的图表类型以及该SQL中所用到的表名。
4546
我们会在<Info>块内提供给你信息,帮助你生成SQL:
4647
<Info>内有<db-engine><m-schema><terminologies>等信息;
4748
其中,<db-engine>:提供数据库引擎及版本信息;
@@ -72,7 +73,7 @@ template:
7273
</rule>
7374
<rule>
7475
请使用JSON格式返回你的回答:
75-
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}}
76+
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table","brief":"如何需要生成对话标题,在这里填写你生成的对话标题,否则不需要这个字段"}}
7677
若不能生成,则返回格式如:{{"success":false,"message":"说明无法生成SQL的原因"}}
7778
</rule>
7879
<rule>
@@ -112,6 +113,9 @@ template:
112113
<rule>
113114
我们目前的情况适用于单指标、多分类的场景(展示table除外)
114115
</rule>
116+
<rule>
117+
是否生成对话标题在<change-title>内,如果为True需要生成,否则不需要生成,生成的对话标题要求在20字以内
118+
</rule>
115119
</Rules>
116120
117121
{process_check}
@@ -251,6 +255,9 @@ template:
251255
<user-question>
252256
{question}
253257
</user-question>
258+
<change-title>
259+
{change_title}
260+
</change-title>
254261
255262
chart:
256263
system: |

0 commit comments

Comments
 (0)