Skip to content

Commit 73013e0

Browse files
MKY508 and claude committed
fix: 修复对话上下文丢失 + AI 自主生成图表
P0-1: 修复对话上下文丢失
- ExecutionService 新增 _get_conversation_history() 方法
- GptmeEngine.execute() 接收 history 参数
- 历史消息正确注入到 LiteLLM 消息列表

P0-2: AI 自主生成图表
- 更新系统提示,指导 AI 输出 ```chart 配置块
- 新增 _extract_chart_config() 解析 AI 图表配置
- 新增 _build_chart_from_config() 构建图表数据
- 保留自动生成逻辑作为后备方案

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent b95f515 commit 73013e0

File tree

2 files changed

+204
-19
lines changed

2 files changed

+204
-19
lines changed

apps/api/app/services/execution.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from app.core import encryptor
1515
from app.core.config import settings
16-
from app.db.tables import Connection, Model, SemanticTerm, User
16+
from app.db.tables import Connection, Message, Model, SemanticTerm, User
1717
from app.models import SemanticContext, SemanticTermResponse, SSEEvent
1818

1919
logger = structlog.get_logger()
@@ -152,6 +152,35 @@ async def _get_semantic_context(self) -> SemanticContext:
152152
terms=[SemanticTermResponse.model_validate(t) for t in terms]
153153
)
154154

155+
async def _get_conversation_history(
    self, conversation_id: UUID, limit: int = 10
) -> list[dict[str, str]]:
    """Load the most recent messages of a conversation as chat history.

    Args:
        conversation_id: ID of the conversation whose messages are loaded.
        limit: Maximum number of rows fetched (the newest ``limit`` messages).

    Returns:
        Messages in chronological order (oldest first), each shaped as
        ``{"role": "user" | "assistant", "content": "..."}``. Rows with any
        other role or with empty content are dropped *after* the limit has
        been applied, so fewer than ``limit`` entries may be returned.
    """
    newest_first = (
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.desc())
        .limit(limit)
    )
    rows = (await self.db.execute(newest_first)).scalars().all()

    # Rows arrive newest-first; flip them so the oldest message leads.
    history = [
        {"role": row.role, "content": row.content}
        for row in reversed(rows)
        if row.role in ("user", "assistant") and row.content
    ]

    logger.info(f"Loaded {len(history)} history messages for conversation {conversation_id}")
    return history
183+
155184
async def execute_stream(
156185
self,
157186
query: str,
@@ -178,6 +207,10 @@ async def execute_stream(
178207
semantic_context = await self._get_semantic_context()
179208
logger.info(f"Semantic terms count: {len(semantic_context.terms)}")
180209

210+
# 加载对话历史(不包括当前查询,因为当前查询还未保存)
211+
logger.info("Getting conversation history...")
212+
history = await self._get_conversation_history(conversation_id, limit=10)
213+
181214
system_prompt = self._build_system_prompt(db_config, semantic_context)
182215

183216
engine = GptmeEngine(
@@ -191,6 +224,7 @@ async def execute_stream(
191224
query=query,
192225
system_prompt=system_prompt,
193226
db_config=db_config,
227+
history=history,
194228
stop_checker=stop_checker,
195229
):
196230
logger.info(f"Yielding event: {event.type}")
@@ -208,18 +242,50 @@ def _build_system_prompt(
208242
209243
请遵循以下规则:
210244
1. 只生成只读 SQL(SELECT、SHOW、DESCRIBE)
211-
2. 使用 pandas 处理数据
212-
3. 使用 plotly 生成可视化图表
213-
4. 用中文回复用户
245+
2. 用中文回复用户
246+
3. 如果查询结果适合可视化,在回复末尾添加图表配置(使用 ```chart 代码块):
247+
248+
```chart
249+
{
250+
"type": "bar",
251+
"title": "图表标题",
252+
"xKey": "x轴字段名",
253+
"yKeys": ["y轴字段名1", "y轴字段名2"]
254+
}
255+
```
256+
257+
图表类型选择指南:
258+
- bar: 比较不同类别的数值(如各地区销售额)
259+
- line: 展示趋势变化(如月度增长)
260+
- pie: 展示占比分布(如市场份额)
261+
- area: 展示累积趋势
262+
263+
注意:只有当数据适合可视化时才添加图表配置,简单的单值查询不需要图表。
214264
"""
215265
else:
216266
base_prompt = """You are QueryGPT data analysis assistant, helping users query and analyze database data.
217267
218268
Follow these rules:
219269
1. Only generate read-only SQL (SELECT, SHOW, DESCRIBE)
220-
2. Use pandas for data processing
221-
3. Use plotly for visualization
222-
4. Reply in English
270+
2. Reply in English
271+
3. If query results are suitable for visualization, add chart config at the end (using ```chart code block):
272+
273+
```chart
274+
{
275+
"type": "bar",
276+
"title": "Chart Title",
277+
"xKey": "x_axis_field",
278+
"yKeys": ["y_axis_field1", "y_axis_field2"]
279+
}
280+
```
281+
282+
Chart type guide:
283+
- bar: Compare values across categories
284+
- line: Show trends over time
285+
- pie: Show proportions/percentages
286+
- area: Show cumulative trends
287+
288+
Note: Only add chart config when data is suitable for visualization.
223289
"""
224290

225291
if db_config:

apps/api/app/services/gptme_engine.py

Lines changed: 131 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,18 @@ async def execute(
3838
query: str,
3939
system_prompt: str,
4040
db_config: dict[str, Any] | None = None,
41+
history: list[dict[str, str]] | None = None,
4142
stop_checker: Callable[[], bool] | None = None,
4243
) -> AsyncGenerator[SSEEvent, None]:
4344
"""
4445
执行查询并流式返回结果
46+
47+
Args:
48+
query: 用户查询
49+
system_prompt: 系统提示
50+
db_config: 数据库配置
51+
history: 对话历史消息列表 [{"role": "user/assistant", "content": "..."}]
52+
stop_checker: 停止检查函数
4553
"""
4654
logger.info("GptmeEngine.execute called", model=self.model, query_preview=query[:50])
4755

@@ -62,6 +70,7 @@ async def execute(
6270
query=query,
6371
system_prompt=system_prompt,
6472
db_config=db_config,
73+
history=history,
6574
stop_checker=stop_checker,
6675
):
6776
yield event
@@ -74,6 +83,7 @@ async def _execute_with_litellm(
7483
query: str,
7584
system_prompt: str,
7685
db_config: dict[str, Any] | None = None,
86+
history: list[dict[str, str]] | None = None,
7787
stop_checker: Callable[[], bool] | None = None,
7888
) -> AsyncGenerator[SSEEvent, None]:
7989
"""使用 LiteLLM 执行查询"""
@@ -88,6 +98,14 @@ async def _execute_with_litellm(
8898
db_context = self._build_db_context(db_config)
8999
messages.append({"role": "system", "content": db_context})
90100

101+
# 添加对话历史(不包括当前查询,因为当前查询会单独添加)
102+
if history:
103+
# 过滤掉最后一条用户消息(如果和当前查询相同)
104+
for msg in history:
105+
if msg.get("role") in ("user", "assistant") and msg.get("content"):
106+
messages.append({"role": msg["role"], "content": msg["content"]})
107+
logger.info(f"Added {len(history)} history messages to context")
108+
91109
messages.append({"role": "user", "content": query})
92110

93111
yield SSEEvent.progress("generating", "正在生成响应...")
@@ -117,7 +135,6 @@ async def _execute_with_litellm(
117135
data = None
118136
rows_count = None
119137
execution_time = None
120-
visualization = None
121138

122139
if sql_code and db_config:
123140
yield SSEEvent.progress("executing", "正在执行 SQL 查询...")
@@ -126,26 +143,40 @@ async def _execute_with_litellm(
126143
try:
127144
data, rows_count = await self._execute_sql(sql_code, db_config)
128145
execution_time = time.time() - start_time
129-
130-
# 尝试生成可视化
131-
if data and len(data) > 0:
132-
visualization = self._generate_visualization(data, query)
133146
except Exception as e:
134147
full_content += f"\n\n⚠️ SQL 执行错误: {str(e)}"
135148

149+
# 从 AI 输出中提取图表配置
150+
chart_config = self._extract_chart_config(full_content)
151+
152+
# 移除图表配置代码块,使输出更干净
153+
clean_content = re.sub(r"```chart\s*\n?[\s\S]*?\n?```", "", full_content).strip()
154+
136155
yield SSEEvent.result(
137-
content=full_content,
156+
content=clean_content,
138157
sql=sql_code,
139158
data=data,
140159
rows_count=rows_count,
141160
execution_time=execution_time,
142161
)
143162

144-
if visualization:
145-
yield SSEEvent.visualization(
146-
chart_type=visualization.get("type", "bar"),
147-
chart_data=visualization.get("data", {}),
148-
)
163+
# 如果 AI 提供了图表配置且有数据,生成可视化
164+
if chart_config and data and len(data) > 0:
165+
# 构建图表数据
166+
visualization = self._build_chart_from_config(chart_config, data)
167+
if visualization:
168+
yield SSEEvent.visualization(
169+
chart_type=visualization.get("type", "bar"),
170+
chart_data=visualization,
171+
)
172+
elif data and len(data) > 0:
173+
# 如果 AI 没有提供图表配置,使用后备的自动生成逻辑
174+
visualization = self._generate_visualization(data, query)
175+
if visualization:
176+
yield SSEEvent.visualization(
177+
chart_type=visualization.get("type", "bar"),
178+
chart_data=visualization.get("data", {}),
179+
)
149180

150181
except Exception as e:
151182
yield SSEEvent.error("LITELLM_ERROR", str(e))
@@ -160,8 +191,66 @@ async def _execute_sql(
160191
result = db_manager.execute_query(sql, read_only=True)
161192
return result.data, result.rows_count
162193

194+
def _build_chart_from_config(
195+
self, config: dict, data: list[dict]
196+
) -> dict | None:
197+
"""根据 AI 提供的配置构建图表数据
198+
199+
Args:
200+
config: AI 生成的图表配置 {"type", "title", "xKey", "yKeys"}
201+
data: SQL 查询结果数据
202+
203+
Returns:
204+
完整的图表配置,包含数据
205+
"""
206+
if not data or len(data) == 0:
207+
return None
208+
209+
chart_type = config.get("type", "bar")
210+
title = config.get("title", "")
211+
x_key = config.get("xKey")
212+
y_keys = config.get("yKeys", [])
213+
214+
columns = list(data[0].keys())
215+
216+
# 如果 AI 没有指定 xKey,使用第一列
217+
if not x_key or x_key not in columns:
218+
x_key = columns[0]
219+
220+
# 如果 AI 没有指定 yKeys,自动检测数值列
221+
if not y_keys:
222+
for col in columns:
223+
if col != x_key:
224+
try:
225+
float(data[0][col])
226+
y_keys.append(col)
227+
except (ValueError, TypeError):
228+
pass
229+
230+
if not y_keys:
231+
return None
232+
233+
# 构建图表数据
234+
chart_data = []
235+
for row in data[:50]: # 限制最多 50 条数据
236+
item = {"name": str(row.get(x_key, ""))}
237+
for y_key in y_keys:
238+
try:
239+
item[y_key] = float(row.get(y_key, 0))
240+
except (ValueError, TypeError):
241+
item[y_key] = 0
242+
chart_data.append(item)
243+
244+
return {
245+
"type": chart_type,
246+
"title": title,
247+
"data": chart_data,
248+
"xKey": "name",
249+
"yKeys": y_keys,
250+
}
251+
163252
def _generate_visualization(self, data: list[dict], query: str) -> dict | None:
164-
"""根据数据和查询生成可视化配置"""
253+
"""根据数据和查询自动生成可视化配置(后备方案)"""
165254
if not data or len(data) == 0:
166255
return None
167256

@@ -245,6 +334,36 @@ def _extract_sql(self, content: str) -> str | None:
245334

246335
return None
247336

337+
def _extract_chart_config(self, content: str) -> dict | None:
338+
"""从 AI 输出中提取图表配置
339+
340+
Args:
341+
content: AI 输出的完整内容
342+
343+
Returns:
344+
图表配置字典,如果没有找到则返回 None
345+
"""
346+
import json
347+
348+
# 匹配 ```chart ... ``` 代码块
349+
pattern = r"```chart\s*\n?([\s\S]*?)\n?```"
350+
match = re.search(pattern, content, re.IGNORECASE)
351+
352+
if match:
353+
try:
354+
config_str = match.group(1).strip()
355+
config = json.loads(config_str)
356+
357+
# 验证必要字段
358+
if "type" in config:
359+
logger.info(f"Extracted chart config: type={config.get('type')}")
360+
return config
361+
except json.JSONDecodeError as e:
362+
logger.warning(f"Failed to parse chart config: {e}")
363+
return None
364+
365+
return None
366+
248367

249368
# 全局引擎实例
250369
_engine: GptmeEngine | None = None

0 commit comments

Comments
 (0)