1010from apps .ai_model .model_factory import LLMConfig , LLMFactory , get_llm_config
1111from apps .chat .curd .chat import save_question , save_full_sql_message , save_full_sql_message_and_answer , save_sql , \
1212 save_error_message , save_sql_exec_data , save_full_chart_message , save_full_chart_message_and_answer , save_chart , \
13- finish_record , save_full_analysis_message_and_answer
13+ finish_record , save_full_analysis_message_and_answer , save_full_predict_message_and_answer , save_predict_data
1414from apps .chat .models .chat_model import ChatQuestion , ChatRecord
1515from apps .datasource .models .datasource import CoreDatasource
1616from apps .db .db import exec_sql
@@ -112,7 +112,7 @@ def get_record(self):
112112 def set_record (self , record : ChatRecord ):
113113 self .record = record
114114
115- def generate_analysis (self , session : SessionDep ):
115+ def get_fields_from_chart (self ):
116116 chart_info = orjson .loads (self .record .chart )
117117 fields = []
118118 if chart_info .get ('columns' ) and len (chart_info .get ('columns' )) > 0 :
@@ -129,6 +129,10 @@ def generate_analysis(self, session: SessionDep):
129129 if column .get ('value' ) != column .get ('name' ):
130130 column_str = column_str + '(' + column .get ('name' ) + ')'
131131 fields .append (column_str )
132+ return fields
133+
134+ def generate_analysis (self , session : SessionDep ):
135+ fields = self .get_fields_from_chart ()
132136
133137 self .chat_question .fields = orjson .dumps (fields ).decode ()
134138 self .chat_question .data = orjson .dumps (orjson .loads (self .record .data ).get ('data' )).decode ()
@@ -169,6 +173,49 @@ def generate_analysis(self, session: SessionDep):
169173 in
170174 analysis_msg ]).decode ())
171175
176+ def generate_predict (self , session : SessionDep ):
177+ fields = self .get_fields_from_chart ()
178+
179+ self .chat_question .fields = orjson .dumps (fields ).decode ()
180+ self .chat_question .data = orjson .dumps (orjson .loads (self .record .data ).get ('data' )).decode ()
181+ predict_msg : List [Union [BaseMessage , dict [str , Any ]]] = []
182+ predict_msg .append (SystemMessage (content = self .chat_question .predict_sys_question ()))
183+ predict_msg .append (HumanMessage (content = self .chat_question .predict_user_question ()))
184+
185+ history_msg = []
186+ if self .record .full_predict_message and self .record .full_predict_message .strip () != '' :
187+ history_msg = orjson .loads (self .record .full_predict_message )
188+
189+ self .record = save_full_predict_message_and_answer (session = session , record_id = self .record .id , answer = '' ,
190+ data = '' ,
191+ full_message = orjson .dumps (history_msg +
192+ [{'type' : msg .type ,
193+ 'content' : msg .content } for msg
194+ in
195+ predict_msg ]).decode ())
196+
197+ full_predict_text = ''
198+ res = self .llm .stream (predict_msg )
199+ for chunk in res :
200+ print (chunk )
201+ if isinstance (chunk , dict ):
202+ full_predict_text += chunk ['content' ]
203+ yield chunk ['content' ]
204+ continue
205+ if isinstance (chunk , AIMessageChunk ):
206+ full_predict_text += chunk .content
207+ yield chunk .content
208+ continue
209+
210+ predict_msg .append (AIMessage (full_predict_text ))
211+ self .record = save_full_predict_message_and_answer (session = session , record_id = self .record .id ,
212+ answer = full_predict_text , data = '' ,
213+ full_message = orjson .dumps (history_msg +
214+ [{'type' : msg .type ,
215+ 'content' : msg .content } for msg
216+ in
217+ predict_msg ]).decode ())
218+
172219 def generate_sql (self , session : SessionDep ):
173220 # append current question
174221 self .sql_message .append (HumanMessage (self .chat_question .sql_user_question ()))
@@ -274,6 +321,17 @@ def check_save_chart(self, session: SessionDep, res: str) -> Dict[str, Any]:
274321
275322 return chart
276323
324+ def check_save_predict_data (self , session : SessionDep , res : str ) -> Dict [str , Any ]:
325+
326+ json_str = extract_nested_json (res )
327+
328+ if not json_str :
329+ json_str = ''
330+
331+ save_predict_data (session = session , record_id = self .record .id , data = json_str )
332+
333+ return json_str
334+
277335 def save_error (self , session : SessionDep , message : str ):
278336 return save_error_message (session = session , record_id = self .record .id , message = message )
279337
0 commit comments