|
29 | 29 | save_select_datasource_answer, save_recommend_question_answer, \ |
30 | 30 | get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \ |
31 | 31 | get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \ |
32 | | - get_last_execute_sql_error, format_json_data, format_chart_fields |
| 32 | + get_last_execute_sql_error, format_json_data, format_chart_fields, get_chat_brief_generate |
33 | 33 | from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ |
34 | 34 | ChatFinishStep, AxisObj |
35 | 35 | from apps.data_training.curd.data_training import get_training_template |
@@ -117,7 +117,7 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C |
117 | 117 | self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id) |
118 | 118 | self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id) |
119 | 119 |
|
120 | | - self.change_title = len(self.generate_sql_logs) == 0 |
| 120 | + self.change_title = not get_chat_brief_generate(session=session, chat_id=chat_id) |
121 | 121 |
|
122 | 122 | chat_question.lang = get_lang_name(current_user.language) |
123 | 123 |
|
@@ -528,7 +528,8 @@ def select_datasource(self, _session: Session): |
528 | 528 | def generate_sql(self, _session: Session): |
529 | 529 | # append current question |
530 | 530 | self.sql_message.append(HumanMessage( |
531 | | - self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),change_title = self.change_title))) |
| 531 | + self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'), |
| 532 | + change_title=self.change_title))) |
532 | 533 |
|
533 | 534 | self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=_session, |
534 | 535 | ai_modal_id=self.chat_question.ai_modal_id, |
@@ -997,11 +998,13 @@ def run_task(self, in_chat: bool = True, stream: bool = True, |
997 | 998 | # return title |
998 | 999 | if self.change_title: |
999 | 1000 | llm_brief = self.get_brief_from_sql_answer(full_sql_text) |
1000 | | - if (llm_brief and llm_brief != '') or (self.chat_question.question and self.chat_question.question.strip() != ''): |
1001 | | - save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[:20] |
| 1001 | + llm_brief_generated = bool(llm_brief) |
| 1002 | + if llm_brief_generated or (self.chat_question.question and self.chat_question.question.strip() != ''): |
| 1003 | + save_brief = llm_brief if (llm_brief and llm_brief != '') else self.chat_question.question.strip()[ |
| 1004 | + :20] |
1002 | 1005 | brief = rename_chat(session=_session, |
1003 | 1006 | rename_object=RenameChat(id=self.get_record().chat_id, |
1004 | | - brief=save_brief)) |
| 1007 | + brief=save_brief, brief_generate=llm_brief_generated)) |
1005 | 1008 | if in_chat: |
1006 | 1009 | yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n' |
1007 | 1010 | if not stream: |
@@ -1084,7 +1087,8 @@ def run_task(self, in_chat: bool = True, stream: bool = True, |
1084 | 1087 | for field in result.get('fields'): |
1085 | 1088 | _column_list.append(AxisObj(name=field, value=field)) |
1086 | 1089 |
|
1087 | | - md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list, result.get('data')) |
| 1090 | + md_data, _fields_list = DataFormat.convert_object_array_for_pandas(_column_list, |
| 1091 | + result.get('data')) |
1088 | 1092 |
|
1089 | 1093 | # data, _fields_list, col_formats = self.format_pd_data(_column_list, result.get('data')) |
1090 | 1094 |
|
@@ -1203,8 +1207,6 @@ def run_task(self, in_chat: bool = True, stream: bool = True, |
1203 | 1207 | self.finish(_session) |
1204 | 1208 | session_maker.remove() |
1205 | 1209 |
|
1206 | | - |
1207 | | - |
1208 | 1210 | def run_recommend_questions_task_async(self): |
1209 | 1211 | self.future = executor.submit(self.run_recommend_questions_task_cache) |
1210 | 1212 |
|
|
0 commit comments