1717from apps .chat .curd .chat import save_question , save_full_sql_message , save_full_sql_message_and_answer , save_sql , \
1818 save_error_message , save_sql_exec_data , save_full_chart_message , save_full_chart_message_and_answer , save_chart , \
1919 finish_record , save_full_analysis_message_and_answer , save_full_predict_message_and_answer , save_predict_data , \
20- save_full_select_datasource_message_and_answer , list_records
20+ save_full_select_datasource_message_and_answer , list_records , save_full_recommend_question_message_and_answer , \
21+ get_old_questions
2122from apps .chat .models .chat_model import ChatQuestion , ChatRecord , Chat
2223from apps .datasource .crud .datasource import get_table_schema
2324from apps .datasource .models .datasource import CoreDatasource
2425from apps .db .db import exec_sql
2526from common .core .config import settings
2627from common .core .deps import SessionDep , CurrentUser
28+ from common .utils .utils import extract_nested_json
2729
2830warnings .filterwarnings ("ignore" )
2931
@@ -59,7 +61,6 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question
5961
6062 chat_question .engine = ds .type_name if ds .type != 'excel' else 'PostgreSQL'
6163
62-
6364 history_records : List [ChatRecord ] = list (
6465 map (lambda x : ChatRecord (** x .model_dump ()), filter (lambda r : True if r .first_chat != True else False ,
6566 list_records (session = self .session ,
@@ -75,7 +76,6 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question
7576 self .chat_question = chat_question
7677 self .config = get_default_config ()
7778 self .chat_question .ai_modal_id = self .config .model_id
78-
7979
8080 # Create LLM instance through factory
8181 llm_instance = LLMFactory .create_llm (self .config )
@@ -176,7 +176,7 @@ def get_fields_from_chart(self):
176176 fields .append (column_str )
177177 return fields
178178
179- def generate_analysis (self , session : SessionDep ):
179+ def generate_analysis (self ):
180180 fields = self .get_fields_from_chart ()
181181
182182 self .chat_question .fields = orjson .dumps (fields ).decode ()
@@ -189,7 +189,7 @@ def generate_analysis(self, session: SessionDep):
189189 if self .record .full_analysis_message and self .record .full_analysis_message .strip () != '' :
190190 history_msg = orjson .loads (self .record .full_analysis_message )
191191
192- self .record = save_full_analysis_message_and_answer (session = session , record_id = self .record .id , answer = '' ,
192+ self .record = save_full_analysis_message_and_answer (session = self . session , record_id = self .record .id , answer = '' ,
193193 full_message = orjson .dumps (history_msg +
194194 [{'type' : msg .type ,
195195 'content' : msg .content } for msg
@@ -210,7 +210,7 @@ def generate_analysis(self, session: SessionDep):
210210 continue
211211
212212 analysis_msg .append (AIMessage (full_analysis_text ))
213- self .record = save_full_analysis_message_and_answer (session = session , record_id = self .record .id ,
213+ self .record = save_full_analysis_message_and_answer (session = self . session , record_id = self .record .id ,
214214 answer = full_analysis_text ,
215215 full_message = orjson .dumps (history_msg +
216216 [{'type' : msg .type ,
@@ -261,6 +261,47 @@ def generate_predict(self):
261261 in
262262 predict_msg ]).decode ())
263263
264+ def generate_recommend_questions_task (self ):
265+
266+ # get schema
267+ if self .ds and not self .chat_question .db_schema :
268+ self .chat_question .db_schema = get_table_schema (session = self .session , ds = self .ds )
269+
270+ guess_msg : List [Union [BaseMessage , dict [str , Any ]]] = []
271+ guess_msg .append (SystemMessage (content = self .chat_question .guess_sys_question ()))
272+ # todo old questions
273+ old_questions = list (map (lambda q : q [0 ].strip (), get_old_questions (self .session , self .record .datasource )))
274+ guess_msg .append (HumanMessage (content = self .chat_question .guess_user_question (orjson .dumps (old_questions ).decode ())))
275+
276+ self .record = save_full_recommend_question_message_and_answer (session = self .session , record_id = self .record .id ,
277+ answer = '' ,
278+ full_message = orjson .dumps ([{'type' : msg .type ,
279+ 'content' : msg .content }
280+ for msg
281+ in
282+ guess_msg ]).decode ())
283+
284+ full_guess_text = ''
285+ res = self .llm .stream (guess_msg )
286+ for chunk in res :
287+ print (chunk )
288+ if isinstance (chunk , dict ):
289+ full_guess_text += chunk ['content' ]
290+ continue
291+ if isinstance (chunk , AIMessageChunk ):
292+ full_guess_text += chunk .content
293+ continue
294+
295+ guess_msg .append (AIMessage (full_guess_text ))
296+ self .record = save_full_recommend_question_message_and_answer (session = self .session , record_id = self .record .id ,
297+ answer = full_guess_text ,
298+ full_message = orjson .dumps ([{'type' : msg .type ,
299+ 'content' : msg .content }
300+ for msg
301+ in
302+ guess_msg ]).decode ())
303+ return self .record .recommended_question
304+
264305 def select_datasource (self ):
265306 datasource_msg : List [Union [BaseMessage , dict [str , Any ]]] = []
266307 datasource_msg .append (SystemMessage (self .chat_question .datasource_sys_question ()))
@@ -486,33 +527,6 @@ def execute_sql(self, sql: str):
486527 return exec_sql (self .ds , sql )
487528
488529
def extract_nested_json(text):
    """Extract the first complete, valid JSON object or array embedded in *text*.

    Scans character by character, using a stack of opening brackets to find
    balanced ``{...}`` / ``[...]`` spans, validates each candidate with orjson,
    and returns the first valid JSON substring, or ``None`` when none is found.
    """
    stack = []
    start_index = -1
    results = []

    for i, char in enumerate(text):
        if char in '{[':
            if not stack:  # remember where the outermost bracket opens
                start_index = i
            stack.append(char)
        elif char in '}]':
            if stack and ((char == '}' and stack[-1] == '{') or (char == ']' and stack[-1] == '[')):
                stack.pop()
                if not stack:  # stack drained -> a balanced candidate span
                    json_str = text[start_index:i + 1]
                    try:
                        orjson.loads(json_str)  # keep only syntactically valid JSON
                        results.append(json_str)
                    except orjson.JSONDecodeError:
                        pass
            else:
                stack = []  # mismatched closing bracket: discard and restart scanning

    # Candidate strings are never empty, so the first hit is the answer.
    return results[0] if results else None
515-
516530def execute_sql_with_db (db : SQLDatabase , sql : str ) -> str :
517531 """Execute SQL query using SQLDatabase
518532
@@ -647,6 +661,42 @@ def run_task(llm_service: LLMService, session: SessionDep, in_chat: bool = True)
647661 yield f'> ❌ **ERROR**\n \n > \n \n > { str (e )} 。'
648662
649663
def run_analysis_or_predict_task(llm_service: LLMService, action_type: str):
    """Stream an analysis or prediction run as newline-delimited JSON events.

    :param llm_service: service wrapping the chat record and LLM calls
    :param action_type: ``'analysis'`` or ``'predict'``; any other value yields nothing
    :return: generator of JSON strings, each terminated by a blank line
    """
    try:
        if action_type == 'analysis':
            # Stream analysis chunks, then signal completion.
            for chunk in llm_service.generate_analysis():
                yield orjson.dumps({'content': chunk, 'type': 'analysis-result'}).decode() + '\n\n'
            yield orjson.dumps({'type': 'info', 'msg': 'analysis generated'}).decode() + '\n\n'

            yield orjson.dumps({'type': 'analysis_finish'}).decode() + '\n\n'

        elif action_type == 'predict':
            # Stream prediction chunks while accumulating the full text for validation.
            full_text = ''
            for chunk in llm_service.generate_predict():
                yield orjson.dumps({'content': chunk, 'type': 'predict-result'}).decode() + '\n\n'
                full_text += chunk
            yield orjson.dumps({'type': 'info', 'msg': 'predict generated'}).decode() + '\n\n'

            # Validate/persist the accumulated prediction and surface the parsed data.
            _data = llm_service.check_save_predict_data(res=full_text)
            yield orjson.dumps({'type': 'predict', 'content': _data}).decode() + '\n\n'

            yield orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n'

    except Exception as e:
        # Boundary handler: report the failure on the stream instead of raising.
        traceback.print_exc()
        yield orjson.dumps({'content': str(e), 'type': 'error'}).decode() + '\n\n'
def run_recommend_questions_task(llm_service: LLMService):
    """Delegate to the LLM service and return its recommended follow-up questions."""
    return llm_service.generate_recommend_questions_task()
699+
650700def request_picture (chat_id : int , record_id : int , chart : dict , data : dict ):
651701 file_name = f'c_{ chat_id } _r_{ record_id } '
652702
0 commit comments