1717from apps .chat .curd .chat import save_question , save_full_sql_message , save_full_sql_message_and_answer , save_sql , \
1818 save_error_message , save_sql_exec_data , save_full_chart_message , save_full_chart_message_and_answer , save_chart , \
1919 finish_record , save_full_analysis_message_and_answer , save_full_predict_message_and_answer , save_predict_data , \
20- save_full_select_datasource_message_and_answer , list_records , save_full_recommend_question_message_and_answer , \
21- get_old_questions
20+ save_full_select_datasource_message_and_answer , save_full_recommend_question_message_and_answer , \
21+ get_old_questions , save_analysis_predict_record , list_base_records
2222from apps .chat .models .chat_model import ChatQuestion , ChatRecord , Chat
2323from apps .datasource .crud .datasource import get_table_schema
2424from apps .datasource .models .datasource import CoreDatasource
@@ -63,9 +63,9 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question
6363
6464 history_records : List [ChatRecord ] = list (
6565 map (lambda x : ChatRecord (** x .model_dump ()), filter (lambda r : True if r .first_chat != True else False ,
66- list_records (session = self .session ,
67- current_user = current_user ,
68- chart_id = chat_question .chat_id ))))
66+ list_base_records (session = self .session ,
67+ current_user = current_user ,
68+ chart_id = chat_question .chat_id ))))
6969 # get schema
7070 if ds :
7171 chat_question .db_schema = get_table_schema (session = self .session , ds = ds )
@@ -606,7 +606,7 @@ def execute_sql_with_db(db: SQLDatabase, sql: str) -> str:
606606 raise RuntimeError (error_msg )
607607
608608
609- def run_task (llm_service : LLMService , session : SessionDep , in_chat : bool = True ):
609+ def run_task (llm_service : LLMService , in_chat : bool = True ):
610610 try :
611611 # return id
612612 if in_chat :
@@ -626,7 +626,7 @@ def run_task(llm_service: LLMService, session: SessionDep, in_chat: bool = True)
626626 yield orjson .dumps ({'id' : llm_service .ds .id , 'datasource_name' : llm_service .ds .name ,
627627 'engine_type' : llm_service .ds .type_name , 'type' : 'datasource' }).decode () + '\n \n '
628628
629- llm_service .chat_question .db_schema = get_table_schema (session = session , ds = llm_service .ds )
629+ llm_service .chat_question .db_schema = get_table_schema (session = llm_service . session , ds = llm_service .ds )
630630
631631 # generate sql
632632 sql_res = llm_service .generate_sql ()
@@ -720,8 +720,10 @@ def run_task(llm_service: LLMService, session: SessionDep, in_chat: bool = True)
720720 yield f'> ❌ **ERROR**\n \n > \n \n > { str (e )} 。'
721721
722722
723- def run_analysis_or_predict_task (llm_service : LLMService , action_type : str ):
723+ def run_analysis_or_predict_task (llm_service : LLMService , action_type : str , base_record : ChatRecord ):
724724 try :
725+ llm_service .set_record (save_analysis_predict_record (llm_service .session , base_record , action_type ))
726+
725727 if action_type == 'analysis' :
726728 # generate analysis
727729 analysis_res = llm_service .generate_analysis ()
@@ -752,10 +754,10 @@ def run_analysis_or_predict_task(llm_service: LLMService, action_type: str):
752754
753755 yield orjson .dumps ({'type' : 'predict_finish' }).decode () + '\n \n '
754756
755-
757+ llm_service . finish ()
756758 except Exception as e :
757759 traceback .print_exc ()
758- # llm_service.save_error(session=session, message=str(e))
760+ llm_service .save_error (message = str (e ))
759761 yield orjson .dumps ({'content' : str (e ), 'type' : 'error' }).decode () + '\n \n '
760762 finally :
761763 # end
0 commit comments