@@ -38,10 +38,18 @@ async def execute(
3838 query : str ,
3939 system_prompt : str ,
4040 db_config : dict [str , Any ] | None = None ,
41+ history : list [dict [str , str ]] | None = None ,
4142 stop_checker : Callable [[], bool ] | None = None ,
4243 ) -> AsyncGenerator [SSEEvent , None ]:
4344 """
4445 执行查询并流式返回结果
46+
47+ Args:
48+ query: 用户查询
49+ system_prompt: 系统提示
50+ db_config: 数据库配置
51+ history: 对话历史消息列表 [{"role": "user/assistant", "content": "..."}]
52+ stop_checker: 停止检查函数
4553 """
4654 logger .info ("GptmeEngine.execute called" , model = self .model , query_preview = query [:50 ])
4755
@@ -62,6 +70,7 @@ async def execute(
6270 query = query ,
6371 system_prompt = system_prompt ,
6472 db_config = db_config ,
73+ history = history ,
6574 stop_checker = stop_checker ,
6675 ):
6776 yield event
@@ -74,6 +83,7 @@ async def _execute_with_litellm(
7483 query : str ,
7584 system_prompt : str ,
7685 db_config : dict [str , Any ] | None = None ,
86+ history : list [dict [str , str ]] | None = None ,
7787 stop_checker : Callable [[], bool ] | None = None ,
7888 ) -> AsyncGenerator [SSEEvent , None ]:
7989 """使用 LiteLLM 执行查询"""
@@ -88,6 +98,14 @@ async def _execute_with_litellm(
8898 db_context = self ._build_db_context (db_config )
8999 messages .append ({"role" : "system" , "content" : db_context })
90100
101+ # 添加对话历史(不包括当前查询,因为当前查询会单独添加)
102+ if history :
103+ # 过滤掉最后一条用户消息(如果和当前查询相同)
104+ for msg in history :
105+ if msg .get ("role" ) in ("user" , "assistant" ) and msg .get ("content" ):
106+ messages .append ({"role" : msg ["role" ], "content" : msg ["content" ]})
107+ logger .info (f"Added { len (history )} history messages to context" )
108+
91109 messages .append ({"role" : "user" , "content" : query })
92110
93111 yield SSEEvent .progress ("generating" , "正在生成响应..." )
@@ -117,7 +135,6 @@ async def _execute_with_litellm(
117135 data = None
118136 rows_count = None
119137 execution_time = None
120- visualization = None
121138
122139 if sql_code and db_config :
123140 yield SSEEvent .progress ("executing" , "正在执行 SQL 查询..." )
@@ -126,26 +143,40 @@ async def _execute_with_litellm(
126143 try :
127144 data , rows_count = await self ._execute_sql (sql_code , db_config )
128145 execution_time = time .time () - start_time
129-
130- # 尝试生成可视化
131- if data and len (data ) > 0 :
132- visualization = self ._generate_visualization (data , query )
133146 except Exception as e :
134147 full_content += f"\n \n ⚠️ SQL 执行错误: { str (e )} "
135148
149+ # 从 AI 输出中提取图表配置
150+ chart_config = self ._extract_chart_config (full_content )
151+
152+ # 移除图表配置代码块,使输出更干净
153+ clean_content = re .sub (r"```chart\s*\n?[\s\S]*?\n?```" , "" , full_content ).strip ()
154+
136155 yield SSEEvent .result (
137- content = full_content ,
156+ content = clean_content ,
138157 sql = sql_code ,
139158 data = data ,
140159 rows_count = rows_count ,
141160 execution_time = execution_time ,
142161 )
143162
144- if visualization :
145- yield SSEEvent .visualization (
146- chart_type = visualization .get ("type" , "bar" ),
147- chart_data = visualization .get ("data" , {}),
148- )
163+ # 如果 AI 提供了图表配置且有数据,生成可视化
164+ if chart_config and data and len (data ) > 0 :
165+ # 构建图表数据
166+ visualization = self ._build_chart_from_config (chart_config , data )
167+ if visualization :
168+ yield SSEEvent .visualization (
169+ chart_type = visualization .get ("type" , "bar" ),
170+ chart_data = visualization ,
171+ )
172+ elif data and len (data ) > 0 :
173+ # 如果 AI 没有提供图表配置,使用后备的自动生成逻辑
174+ visualization = self ._generate_visualization (data , query )
175+ if visualization :
176+ yield SSEEvent .visualization (
177+ chart_type = visualization .get ("type" , "bar" ),
178+ chart_data = visualization .get ("data" , {}),
179+ )
149180
150181 except Exception as e :
151182 yield SSEEvent .error ("LITELLM_ERROR" , str (e ))
@@ -160,8 +191,66 @@ async def _execute_sql(
160191 result = db_manager .execute_query (sql , read_only = True )
161192 return result .data , result .rows_count
162193
194+ def _build_chart_from_config (
195+ self , config : dict , data : list [dict ]
196+ ) -> dict | None :
197+ """根据 AI 提供的配置构建图表数据
198+
199+ Args:
200+ config: AI 生成的图表配置 {"type", "title", "xKey", "yKeys"}
201+ data: SQL 查询结果数据
202+
203+ Returns:
204+ 完整的图表配置,包含数据
205+ """
206+ if not data or len (data ) == 0 :
207+ return None
208+
209+ chart_type = config .get ("type" , "bar" )
210+ title = config .get ("title" , "" )
211+ x_key = config .get ("xKey" )
212+ y_keys = config .get ("yKeys" , [])
213+
214+ columns = list (data [0 ].keys ())
215+
216+ # 如果 AI 没有指定 xKey,使用第一列
217+ if not x_key or x_key not in columns :
218+ x_key = columns [0 ]
219+
220+ # 如果 AI 没有指定 yKeys,自动检测数值列
221+ if not y_keys :
222+ for col in columns :
223+ if col != x_key :
224+ try :
225+ float (data [0 ][col ])
226+ y_keys .append (col )
227+ except (ValueError , TypeError ):
228+ pass
229+
230+ if not y_keys :
231+ return None
232+
233+ # 构建图表数据
234+ chart_data = []
235+ for row in data [:50 ]: # 限制最多 50 条数据
236+ item = {"name" : str (row .get (x_key , "" ))}
237+ for y_key in y_keys :
238+ try :
239+ item [y_key ] = float (row .get (y_key , 0 ))
240+ except (ValueError , TypeError ):
241+ item [y_key ] = 0
242+ chart_data .append (item )
243+
244+ return {
245+ "type" : chart_type ,
246+ "title" : title ,
247+ "data" : chart_data ,
248+ "xKey" : "name" ,
249+ "yKeys" : y_keys ,
250+ }
251+
163252 def _generate_visualization (self , data : list [dict ], query : str ) -> dict | None :
164- """根据数据和查询生成可视化配置 """
253+ """根据数据和查询自动生成可视化配置(后备方案) """
165254 if not data or len (data ) == 0 :
166255 return None
167256
@@ -245,6 +334,36 @@ def _extract_sql(self, content: str) -> str | None:
245334
246335 return None
247336
337+ def _extract_chart_config (self , content : str ) -> dict | None :
338+ """从 AI 输出中提取图表配置
339+
340+ Args:
341+ content: AI 输出的完整内容
342+
343+ Returns:
344+ 图表配置字典,如果没有找到则返回 None
345+ """
346+ import json
347+
348+ # 匹配 ```chart ... ``` 代码块
349+ pattern = r"```chart\s*\n?([\s\S]*?)\n?```"
350+ match = re .search (pattern , content , re .IGNORECASE )
351+
352+ if match :
353+ try :
354+ config_str = match .group (1 ).strip ()
355+ config = json .loads (config_str )
356+
357+ # 验证必要字段
358+ if "type" in config :
359+ logger .info (f"Extracted chart config: type={ config .get ('type' )} " )
360+ return config
361+ except json .JSONDecodeError as e :
362+ logger .warning (f"Failed to parse chart config: { e } " )
363+ return None
364+
365+ return None
366+
248367
249368# 全局引擎实例
250369_engine : GptmeEngine | None = None
0 commit comments