Skip to content

Commit 8a03ea8

Browse files
committed
feat: improve chat template
1 parent 9c75512 commit 8a03ea8

File tree

3 files changed

+310
-104
lines changed

3 files changed

+310
-104
lines changed

backend/apps/chat/models/chat_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,16 @@ def sql_sys_question(self):
176176
return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question,
177177
lang=self.lang, terminologies=self.terminologies)
178178

179-
def sql_user_question(self):
179+
def sql_user_question(self, current_time: str):
180180
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
181-
rule=self.rule)
181+
rule=self.rule, current_time=current_time)
182182

183183
def chart_sys_question(self):
184184
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)
185185

186-
def chart_user_question(self):
187-
return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule)
186+
def chart_user_question(self, chart_type: Optional[str] = None):
187+
return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule,
188+
chart_type=chart_type)
188189

189190
def analysis_sys_question(self):
190191
return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies)

backend/apps/chat/task/llm.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import traceback
44
import warnings
55
from concurrent.futures import ThreadPoolExecutor, Future
6+
from datetime import datetime
67
from typing import Any, List, Optional, Union, Dict
78

89
import numpy as np
@@ -70,7 +71,8 @@ class LLMService:
7071
future: Future
7172

7273
def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
73-
current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False, config: LLMConfig = None):
74+
current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False,
75+
config: LLMConfig = None):
7476
self.chunk_list = []
7577
engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
7678
session_maker = sessionmaker(bind=engine)
@@ -126,7 +128,7 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
126128
self.llm = llm_instance.llm
127129

128130
self.init_messages()
129-
131+
130132
@classmethod
131133
async def create(cls, *args, **kwargs):
132134
config: LLMConfig = await get_default_config()
@@ -503,7 +505,8 @@ def select_datasource(self):
503505

504506
def generate_sql(self):
505507
# append current question
506-
self.sql_message.append(HumanMessage(self.chat_question.sql_user_question()))
508+
self.sql_message.append(HumanMessage(
509+
self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))))
507510

508511
self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=self.session,
509512
ai_modal_id=self.chat_question.ai_modal_id,
@@ -670,9 +673,9 @@ def generate_assistant_filter(self, sql, tables: List):
670673
return None
671674
return self.build_table_filter(sql=sql, filters=filters)
672675

673-
def generate_chart(self):
676+
def generate_chart(self, chart_type: Optional[str] = ''):
674677
# append current question
675-
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question()))
678+
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type)))
676679

677680
self.current_logs[OperationEnum.GENERATE_CHART] = start_log(session=self.session,
678681
ai_modal_id=self.chat_question.ai_modal_id,
@@ -714,7 +717,8 @@ def generate_chart(self):
714717
reasoning_content=full_thinking_text,
715718
token_usage=token_usage)
716719

717-
def check_sql(self, res: str) -> tuple[any]:
720+
@staticmethod
721+
def check_sql(res: str) -> tuple[str, Optional[list]]:
718722
json_str = extract_nested_json(res)
719723
if json_str is None:
720724
raise SingleMessageError(orjson.dumps({'message': 'Cannot parse sql from answer',
@@ -739,6 +743,26 @@ def check_sql(self, res: str) -> tuple[any]:
739743
raise SingleMessageError("SQL query is empty")
740744
return sql, data.get('tables')
741745

746+
@staticmethod
747+
def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
748+
json_str = extract_nested_json(res)
749+
if json_str is None:
750+
return None
751+
752+
chart_type: Optional[str]
753+
data: dict
754+
try:
755+
data = orjson.loads(json_str)
756+
757+
if data['success']:
758+
chart_type = data['chart-type']
759+
else:
760+
return None
761+
except Exception:
762+
return None
763+
764+
return chart_type
765+
742766
def check_save_sql(self, res: str) -> str:
743767
sql, *_ = self.check_sql(res=res)
744768
save_sql(session=self.session, sql=sql, record_id=self.record.id)
@@ -921,6 +945,9 @@ def run_task(self, in_chat: bool = True):
921945

922946
# filter sql
923947
SQLBotLogUtil.info(full_sql_text)
948+
949+
chart_type = self.get_chart_type_from_sql_answer(full_sql_text)
950+
924951
use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type in dynamic_ds_types
925952

926953
# todo row permission
@@ -962,7 +989,7 @@ def run_task(self, in_chat: bool = True):
962989
yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n'
963990

964991
# generate chart
965-
chart_res = self.generate_chart()
992+
chart_res = self.generate_chart(chart_type)
966993
full_chart_text = ''
967994
for chunk in chart_res:
968995
full_chart_text += chunk.get('content')

0 commit comments

Comments
 (0)