@@ -774,7 +774,7 @@ def check_save_sql(self, session: Session, res: str) -> str:
774774
775775 return sql
776776
777- def check_save_chart (self , session : Session , res : str ) -> Dict [str , Any ]:
777+ def check_save_chart (self , session : Session , res : str , sql_prase : str ) -> Dict [str , Any ]:
778778
779779 json_str = extract_nested_json (res )
780780 if json_str is None :
@@ -814,7 +814,7 @@ def check_save_chart(self, session: Session, res: str) -> Dict[str, Any]:
814814 if error :
815815 raise SingleMessageError (message )
816816
817- save_chart (session = session , chart = orjson .dumps (chart ).decode (), record_id = self .record .id )
817+ save_chart (session = session , chart = orjson .dumps (chart ).decode (), record_id = self .record .id , sql_prase = sql_prase )
818818
819819 return chart
820820
@@ -989,6 +989,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
989989
990990 use_dynamic_ds : bool = self .current_assistant and self .current_assistant .type in dynamic_ds_types
991991 is_page_embedded : bool = self .current_assistant and self .current_assistant .type == 4
992+ is_assistant_embedded : bool = self .current_assistant and self .current_assistant .type == 1
992993 dynamic_sql_result = None
993994 sqlbot_temp_sql_text = None
994995 assistant_dynamic_sql = None
@@ -1087,7 +1088,16 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
10871088
10881089 # filter chart
10891090 SQLBotLogUtil .info (full_chart_text )
1090- chart = self .check_save_chart (session = _session , res = full_chart_text )
1091+
1092+ # sql prase
1093+ if is_assistant_embedded :
1094+ sql_prase = self .generate_sql_paras (_session ,real_execute_sql ,full_chart_text )
1095+ if in_chat :
1096+ yield 'data:' + orjson .dumps (
1097+ {'content' : sql_prase ,
1098+ 'type' : 'sql_prase' }).decode () + '\n \n '
1099+
1100+ chart = self .check_save_chart (session = _session , res = full_chart_text ,sql_prase = sql_prase )
10911101 SQLBotLogUtil .info (chart )
10921102
10931103 if not stream :
@@ -1333,6 +1343,49 @@ def validate_history_ds(self, session: Session):
13331343 except Exception as e :
13341344 raise SingleMessageError (f"ds is invalid [{ str (e )} ]" )
13351345
1346+ def generate_sql_paras (self , _session : Session , real_execute_sql : Optional [str ] = '' ,chart : Optional [str ] = '' ):
1347+ # prase sql
1348+ prase_sql_msg : List [Union [BaseMessage , dict [str , Any ]]] = []
1349+ prase_sql_msg .append (SystemMessage (self .chat_question .prase_sql_sys_question ()))
1350+ prase_sql_msg .append (HumanMessage (self .chat_question .prase_sql_user_question (real_execute_sql ,chart )))
1351+ self .current_logs [OperationEnum .PRASE_SQL ] = start_log (session = _session ,
1352+ ai_modal_id = self .chat_question .ai_modal_id ,
1353+ ai_modal_name = self .chat_question .ai_modal_name ,
1354+ operate = OperationEnum .PRASE_SQL ,
1355+ record_id = self .record .id ,
1356+ full_message = [{'type' : msg .type ,
1357+ 'content' : msg .content }
1358+ for
1359+ msg in prase_sql_msg ])
1360+
1361+ token_usage = {}
1362+ prase_res = process_stream (self .llm .stream (prase_sql_msg ), token_usage )
1363+ prase_full_thinking_text = ''
1364+ prase_full_text = ''
1365+ for chunk in prase_res :
1366+ if chunk .get ('content' ):
1367+ prase_full_text += chunk .get ('content' )
1368+ if chunk .get ('reasoning_content' ):
1369+ prase_full_thinking_text += chunk .get ('reasoning_content' )
1370+ prase_sql_msg .append (AIMessage (prase_full_text ))
1371+
1372+ self .current_logs [OperationEnum .PRASE_SQL ] = end_log (session = _session ,
1373+ log = self .current_logs [
1374+ OperationEnum .PRASE_SQL ],
1375+ full_message = [
1376+ {'type' : msg .type ,
1377+ 'content' : msg .content }
1378+ for msg in prase_sql_msg ],
1379+ reasoning_content = prase_full_thinking_text ,
1380+ token_usage = token_usage )
1381+
1382+ prase_json_str = extract_nested_json (prase_full_text )
1383+ return prase_json_str
1384+ # if prase_json_str is None:
1385+ # raise SingleMessageError(f'Cannot parse datasource from answer: {prase_full_text}')
1386+ # ds = orjson.loads(prase_json_str)
1387+ # return ds['info']
1388+
13361389
13371390def execute_sql_with_db (db : SQLDatabase , sql : str ) -> str :
13381391 """Execute SQL query using SQLDatabase
@@ -1505,7 +1558,6 @@ def process_stream(res: Iterator[BaseMessageChunk],
15051558 }
15061559 get_token_usage (chunk , token_usage )
15071560
1508-
15091561def get_lang_name (lang : str ):
15101562 if not lang :
15111563 return '简体中文'
0 commit comments