@@ -524,7 +524,7 @@ def select_datasource(self, _session: Session):
524524 def generate_sql (self , _session : Session ):
525525 # append current question
526526 self .sql_message .append (HumanMessage (
527- self .chat_question .sql_user_question (current_time = datetime .now ().strftime ('%Y-%m-%d %H:%M:%S' ))))
527+ self .chat_question .sql_user_question (current_time = datetime .now ().strftime ('%Y-%m-%d %H:%M:%S' ), change_title = self . change_title )))
528528
529529 self .current_logs [OperationEnum .GENERATE_SQL ] = start_log (session = _session ,
530530 ai_modal_id = self .chat_question .ai_modal_id ,
@@ -756,6 +756,26 @@ def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
756756
757757 return chart_type
758758
759+ @staticmethod
760+ def get_brief_from_sql_answer (res : str ) -> Optional [str ]:
761+ json_str = extract_nested_json (res )
762+ if json_str is None :
763+ return None
764+
765+ brief : Optional [str ]
766+ data : dict
767+ try :
768+ data = orjson .loads (json_str )
769+
770+ if data ['success' ]:
771+ brief = data ['brief' ]
772+ else :
773+ return None
774+ except Exception :
775+ return None
776+
777+ return brief
778+
759779 def check_save_sql (self , session : Session , res : str ) -> str :
760780 sql , * _ = self .check_sql (res = res )
761781 save_sql (session = session , sql = sql , record_id = self .record .id )
@@ -925,17 +945,6 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
925945 if not stream :
926946 json_result ['record_id' ] = self .get_record ().id
927947
928- # return title
929- if self .change_title :
930- if self .chat_question .question and self .chat_question .question .strip () != '' :
931- brief = rename_chat (session = _session ,
932- rename_object = RenameChat (id = self .get_record ().chat_id ,
933- brief = self .chat_question .question .strip ()[:20 ]))
934- if in_chat :
935- yield 'data:' + orjson .dumps ({'type' : 'brief' , 'brief' : brief }).decode () + '\n \n '
936- if not stream :
937- json_result ['title' ] = brief
938-
939948 # select datasource if datasource is none
940949 if not self .ds :
941950 ds_res = self .select_datasource (_session )
@@ -981,6 +990,19 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
981990
982991 chart_type = self .get_chart_type_from_sql_answer (full_sql_text )
983992
993+ # return title
994+ if self .change_title :
995+ llm_brief = self .get_brief_from_sql_answer (full_sql_text )
996+ if (llm_brief and llm_brief != '' ) or (self .chat_question .question and self .chat_question .question .strip () != '' ):
997+ save_brief = llm_brief if (llm_brief and llm_brief != '' ) else self .chat_question .question .strip ()[:20 ]
998+ brief = rename_chat (session = _session ,
999+ rename_object = RenameChat (id = self .get_record ().chat_id ,
1000+ brief = save_brief ))
1001+ if in_chat :
1002+ yield 'data:' + orjson .dumps ({'type' : 'brief' , 'brief' : brief }).decode () + '\n \n '
1003+ if not stream :
1004+ json_result ['title' ] = brief
1005+
9841006 use_dynamic_ds : bool = self .current_assistant and self .current_assistant .type in dynamic_ds_types
9851007 is_page_embedded : bool = self .current_assistant and self .current_assistant .type == 4
9861008 dynamic_sql_result = None
0 commit comments