55from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
66from ai .agents import oai
77from .agent import Agent
8+ import ast
9+ import re
10+ from ai .backend .util import base_util
811from ai .agents .code_utils import (
912 DEFAULT_MODEL ,
1013 UNKNOWN ,
2225try :
2326 from termcolor import colored
2427except ImportError :
25-
2628 def colored (x , * args , ** kwargs ):
2729 return x
2830
2931
32+ # 函数,用于精确到小数点后两位
33+ def format_decimal (value ):
34+ if isinstance (value , float ):
35+ return round (value , 2 )
36+ elif isinstance (value , int ):
37+ return value
38+ return value
39+
40+
3041class PythonProxyAgent (Agent ):
3142 """(In preview) A class for generic conversable agents which can be configured as assistant or user proxy.
3243
@@ -66,7 +77,6 @@ def __init__(
6677 db_id : Optional = None ,
6778 is_log_out : Optional [bool ] = True ,
6879 report_file_name : Optional [str ] = None ,
69-
7080 ):
7181 """
7282 Args:
@@ -112,6 +122,7 @@ def __init__(
112122 """
113123 super ().__init__ (name )
114124 # a dictionary of conversations, default value is list
125+ self .delay_messages = None
115126 self ._oai_messages = defaultdict (list )
116127 self ._oai_system_message = [{"content" : system_message , "role" : "system" }]
117128 self ._is_termination_msg = (
@@ -147,6 +158,7 @@ def __init__(
147158 self .db_id = db_id
148159 self .is_log_out = is_log_out
149160 self .report_file_name = report_file_name
161+ delay_messages = self .delay_messages
150162
151163 def register_reply (
152164 self ,
@@ -661,15 +673,16 @@ def generate_oai_reply(
661673
662674 return True , oai .ChatCompletion .extract_text_or_function_call (response )[0 ]
663675
664- def generate_code_execution_reply (
676+ async def generate_code_execution_reply (
665677 self ,
666678 messages : Optional [List [Dict ]] = None ,
667679 sender : Optional [Agent ] = None ,
668680 config : Optional [Any ] = None ,
681+
669682 ):
670683 """Generate a reply using code execution.
671684 """
672-
685+ from ai . agents . agent_instance_util import AgentInstanceUtil
673686 code_execution_config = config if config is not None else self ._code_execution_config
674687 # print('self._code_execution_config :', self._code_execution_config)
675688
@@ -678,6 +691,7 @@ def generate_code_execution_reply(
678691 if messages is None :
679692 messages = self ._oai_messages [sender ]
680693 last_n_messages = code_execution_config .pop ("last_n_messages" , 1 )
694+ base_content = []
681695
682696 # iterate through the last n messages reversly
683697 # if code blocks are found, execute the code blocks and return the output
@@ -693,6 +707,7 @@ def generate_code_execution_reply(
693707
694708 if len (code_blocks ) == 1 and code_blocks [0 ][0 ] != 'python' :
695709 continue
710+ code_blocks = self .regex_fix_date_format (code_blocks )
696711
697712 if self .db_id is not None :
698713 obj = database_util .Main (self .db_id )
@@ -703,28 +718,110 @@ def generate_code_execution_reply(
703718 code_blocks ]
704719
705720 # code_blocks = self.replace_ab_with_ac(code_blocks, db_info)
706- print ('new_code_blocks : ' , code_blocks )
721+ # print('new_code_blocks : ', code_blocks)
707722
708723 # found code blocks, execute code and push "last_n_messages" back
709724 exitcode , logs = self .execute_code_blocks (code_blocks )
710725 code_execution_config ["last_n_messages" ] = last_n_messages
711726 exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
712-
713727 length = 10000
714- length1 = 10001
715728 if not str (logs ).__contains__ ('echart_name' ):
716729 if len (logs ) > length :
717730 print (' ++++++++++ Length exceeds 10000 characters limit, cropped +++++++++++++++++' )
718731 logs = logs [:length ]
719- else :
720- if len (logs ) > length1 :
721- print (' ++++++++++ Length exceeds 10001 characters limit, cropped +++++++++++++++++' )
722- logs = "The echarts code is too long, please simplify the code or data (for example, only keep two decimal places), and ensure that the echarts code length does not exceed 10001"
732+ return True , f"exitcode: { exitcode } ({ exitcode2str } )\n Code output: { logs } "
723733
724734
725- return True , f"exitcode: { exitcode } ({ exitcode2str } )\n Code output: { logs } "
735+ else :
736+ try :
737+ if "'echart_name'" in str (logs ):
738+ logs = json .dumps (eval (str (logs )))
739+ logs = json .loads (str (logs ))
740+ except Exception as e :
741+ return True ,f"exitcode:exitcode failed\n Code output: There is an error in the JSON code causing parsing errors,Please modify the JSON code for me:{ traceback .format_exc ()} "
742+ for entry in logs :
743+ if 'echart_name' in entry and 'echart_code' in entry :
744+ if isinstance (entry ['echart_code' ], str ):
745+ entry ['echart_code' ] = json .loads (entry ['entry' ]['echart_code' ])
746+ if "series" in entry ['echart_code' ]:
747+ series_data = entry ['echart_code' ]['series' ]
748+ formatted_series_list = []
749+ for series_data in series_data :
750+ if series_data ['type' ] in ["bar" , "line" ]:
751+ formatted_series_data = [format_decimal (value ) for value in series_data ['data' ]]
752+ elif series_data ['type' ] in ["pie" , "gauge" , "funnel" ]:
753+ formatted_series_data = [{"name" : d ["name" ], "value" : format_decimal (d ["value" ])} for
754+ d in series_data ['data' ]]
755+ elif series_data ['type' ] in ['graph' ]:
756+ formatted_series_data = [
757+ {'name' : data_point ['name' ], 'symbolSize' : format_decimal (data_point ['symbolSize' ])}
758+ for data_point in series_data ['data' ]]
759+ elif series_data ['type' ] in ["Kline" , "radar" , "heatmap" , "scatter" , "themeRiver" ,
760+ 'parallel' , 'effectScatter' ]:
761+ formatted_series_data = [[format_decimal (value ) for value in sublist ] for sublist in
762+ series_data ['data' ]]
763+ else :
764+ formatted_series_data = series_data ['data' ]
765+ series_data ['data' ] = formatted_series_data
766+ formatted_series_list .append (series_data )
767+ entry ['echart_code' ]['series' ] = formatted_series_list
768+ base_content .append (entry )
769+
770+ agent_instance_util = AgentInstanceUtil (user_name = str (self .user_name ),
771+ delay_messages = self .delay_messages ,
772+ outgoing = self .outgoing ,
773+ incoming = self .incoming ,
774+ websocket = self .websocket
775+ )
776+ bi_proxy = agent_instance_util .get_agent_bi_proxy ()
777+ is_chart = False
778+ # Call the interface to generate pictures
779+ for img_str in base_content :
780+ echart_name = img_str .get ('echart_name' )
781+ echart_code = img_str .get ('echart_code' )
782+
783+ if len (echart_code ) > 0 and str (echart_code ).__contains__ ('x' ):
784+ is_chart = True
785+ print ("echart_name : " , echart_name )
786+ # 格式化echart_code
787+ # if base_util.is_json(str(echart_code)):
788+ # json_obj = json.loads(str(echart_code))
789+ # echart_code = json.dumps(json_obj)
790+ re_str = await bi_proxy .run_echart_code (str (echart_code ), echart_name )
791+ # 初始化一个空列表来保存每个echart的信息
792+ echarts_data = []
793+ # 遍历echarts_code列表,提取数据并构造字典
794+ for echart in base_content :
795+ echart_name = echart ['echart_name' ]
796+ series_data = []
797+ for serie in echart ['echart_code' ]['series' ]:
798+ try :
799+ seri_info = {
800+ 'type' : serie ['type' ],
801+ 'name' : serie ['name' ],
802+ 'data' : serie ['data' ]
803+ }
804+ except Exception as e :
805+ seri_info = {
806+ 'type' : serie ['type' ],
807+ 'data' : serie ['data' ]
808+ }
809+ series_data .append (seri_info )
810+ if "xAxis" in echart ["echart_code" ]:
811+ xAxis_data = echart ['echart_code' ]['xAxis' ][0 ]['data' ]
812+ echart_dict = {
813+ 'echart_name' : echart_name ,
814+ 'series' : series_data ,
815+ 'xAxis_data' : xAxis_data
816+ }
817+ else :
818+ echart_dict = {
819+ 'echart_name' : echart_name ,
820+ 'series' : series_data ,
821+ }
822+ echarts_data .append (echart_dict )
823+ return True , f"exitcode: { exitcode } ({ exitcode2str } )\n Code output: 图像已生成,请直接分析图表数据:{ echarts_data } "
726824
727- # no code blocks are found, push last_n_messages back and return.
728825 code_execution_config ["last_n_messages" ] = last_n_messages
729826
730827 return False , None
@@ -1138,3 +1235,24 @@ async def ask_user(self, q_str):
11381235
11391236 # return "i have no question."
11401237 return None
1238+
1239+ def regex_fix_date_format (self , code_blocks ):
1240+ # fix mysql generate %%Y %%m %%d code :list
1241+ pattern1 = r"%s"
1242+ patterns_replacements = [
1243+ (r"%%Y-%%m-%%d %%H" , "%Y-%m-%d %H" ),
1244+ (r"%%Y-%%m-%%d" , "%Y-%m-%d" ),
1245+ (r"%%Y-%%m" , "%Y-%m" ),
1246+ (r"%%H" , "%H" ),
1247+ (r"%%Y" , "%Y" ),
1248+ (r"%%Y-%%m-%%d %%H:%%i" , "%Y-%m-%d %H:%i" ),
1249+ (r"%%Y-%%m-%%d %%H:%%i:%%s" , "%Y-%m-%d %H:%i:%s" )]
1250+
1251+ if re .search (pattern1 , str (code_blocks )):
1252+ for pattern , replacement in patterns_replacements :
1253+ code_blocks = [(language , re .sub (replacement , pattern , code )) for language , code in code_blocks ]
1254+ else :
1255+ for pattern , replacement in patterns_replacements :
1256+ code_blocks = [(language , re .sub (pattern , replacement , code )) for language , code in code_blocks ]
1257+
1258+ return code_blocks
0 commit comments