Skip to content

Commit e86c0d6

Browse files
authored
fix: Fixed the issue where the conversation title was not adjusted when the first sentence failed (#543)
1 parent 2ff3e10 commit e86c0d6

File tree

3 files changed

+26
-10
lines changed

3 files changed

+26
-10
lines changed

backend/apps/chat/curd/chat.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def get_chat_record_by_id(session: SessionDep, record_id: int):
2828
engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by)
2929
return record
3030

31+
def get_chat(session: SessionDep, chat_id: int) -> Chat:
32+
statement = select(Chat).where(Chat.id == chat_id)
33+
chat = session.exec(statement).scalars().first()
34+
return chat
3135

3236
def list_chats(session: SessionDep, current_user: CurrentUser) -> List[Chat]:
3337
oid = current_user.oid if current_user.oid is not None else 1
@@ -57,6 +61,7 @@ def rename_chat(session: SessionDep, rename_object: RenameChat) -> str:
5761
raise Exception(f"Chat with id {rename_object.id} not found")
5862

5963
chat.brief = rename_object.brief.strip()[:20]
64+
chat.brief_generate = rename_object.brief_generate
6065
session.add(chat)
6166
session.flush()
6267
session.refresh(chat)
@@ -340,6 +345,13 @@ def format_record(record: ChatRecordResult):
340345

341346
return _dict
342347

348+
def get_chat_brief_generate(session: SessionDep, chat_id: int):
349+
chat = get_chat(session=session,chat_id=chat_id)
350+
if chat is not None and chat.brief_generate is not None:
351+
return chat.brief_generate
352+
else:
353+
return False
354+
343355

344356
def list_generate_sql_logs(session: SessionDep, chart_id: int) -> List[ChatLog]:
345357
stmt = select(ChatLog).where(

backend/apps/chat/models/chat_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ class Chat(SQLModel, table=True):
7878
datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
7979
engine_type: str = Field(max_length=64)
8080
origin: Optional[int] = Field(
81-
sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant
81+
sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant
82+
brief_generate: bool = Field(default=False)
8283

8384

8485
class ChatRecord(SQLModel, table=True):
@@ -149,6 +150,7 @@ class CreateChat(BaseModel):
149150
class RenameChat(BaseModel):
150151
id: int = None
151152
brief: str = ''
153+
brief_generate: bool = True
152154

153155

154156
class ChatInfo(BaseModel):

backend/apps/chat/task/llm.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
save_select_datasource_answer, save_recommend_question_answer, \
3030
get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \
3131
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
32-
get_last_execute_sql_error, format_json_data, format_chart_fields
32+
get_last_execute_sql_error, format_json_data, format_chart_fields, get_chat_brief_generate
3333
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
3434
ChatFinishStep, AxisObj
3535
from apps.data_training.curd.data_training import get_training_template
@@ -117,7 +117,7 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
117117
self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id)
118118
self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id)
119119

120-
self.change_title = len(self.generate_sql_logs) == 0
120+
self.change_title = not get_chat_brief_generate(session=session, chat_id=chat_id)
121121

122122
chat_question.lang = get_lang_name(current_user.language)
123123

@@ -528,7 +528,8 @@ def select_datasource(self, _session: Session):
528528
def generate_sql(self, _session: Session):
529529
# append current question
530530
self.sql_message.append(HumanMessage(
531-
self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),change_title = self.change_title)))
531+
self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
532+
change_title=self.change_title)))
532533

533534
self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=_session,
534535
ai_modal_id=self.chat_question.ai_modal_id,
@@ -997,11 +998,13 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
997998
# return title
998999
if self.change_title:
9991000
llm_brief = self.get_brief_from_sql_answer(full_sql_text)
1000-
if (llm_brief and llm_brief != '') or (self.chat_question.question and self.chat_question.question.strip() != ''):
1001-
save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[:20]
1001+
llm_brief_generated = bool(llm_brief)
1002+
if llm_brief_generated or (self.chat_question.question and self.chat_question.question.strip() != ''):
1003+
save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[
1004+
:20]
10021005
brief = rename_chat(session=_session,
10031006
rename_object=RenameChat(id=self.get_record().chat_id,
1004-
brief=save_brief))
1007+
brief=save_brief, brief_generate=llm_brief_generated))
10051008
if in_chat:
10061009
yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n'
10071010
if not stream:
@@ -1084,7 +1087,8 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
10841087
for field in result.get('fields'):
10851088
_column_list.append(AxisObj(name=field, value=field))
10861089

1087-
md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list, result.get('data'))
1090+
md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list,
1091+
result.get('data'))
10881092

10891093
# data, _fields_list, col_formats = self.format_pd_data(_column_list, result.get('data'))
10901094

@@ -1203,8 +1207,6 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
12031207
self.finish(_session)
12041208
session_maker.remove()
12051209

1206-
1207-
12081210
def run_recommend_questions_task_async(self):
12091211
self.future = executor.submit(self.run_recommend_questions_task_cache)
12101212

0 commit comments

Comments
 (0)