Skip to content

Commit b071e75

Browse files
authored
Merge pull request #2 from jihe520/dev
Dev
2 parents 6747b55 + d85e2e8 commit b071e75

File tree

10 files changed

+170
-45
lines changed

10 files changed

+170
-45
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
.vscode/
22
backend/app/.DS_Store
33
*.DS_Store
4-
.DS_Store
4+
.DS_Store
5+
.DS_Store?
6+
._*
7+
.Spotlight-V100
8+
.Trashes

backend/.DS_Store

-6 KB
Binary file not shown.

backend/app/.DS_Store

-6 KB
Binary file not shown.

backend/app/example/.DS_Store

-6 KB
Binary file not shown.

backend/app/schemas/response.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal
1+
from typing import List, Literal, Union
22
from app.utils.enums import AgentType
33
from pydantic import BaseModel, Field
44
from uuid import uuid4
@@ -20,17 +20,52 @@ class AgentMessage(Message):
2020
agent_type: AgentType # CoderAgent | WriterAgent
2121

2222

23-
class CodeExecutionResult(BaseModel):
24-
res_type: str
25-
msg: str
23+
class CodeExecution(BaseModel):
24+
res_type: Literal["stdout", "stderr", "result", "error"]
25+
msg: str | None = None
26+
27+
28+
class StdOutModel(CodeExecution):
29+
res_type: str = "stdout"
30+
31+
32+
class StdErrModel(CodeExecution):
33+
res_type: str = "stderr"
34+
35+
36+
class ResultModel(CodeExecution):
37+
res_type: str = "result"
38+
format: Literal[
39+
"text",
40+
"html",
41+
"markdown",
42+
"png",
43+
"jpeg",
44+
"svg",
45+
"pdf",
46+
"latex",
47+
"json",
48+
"javascript",
49+
]
50+
51+
52+
class ErrorModel(CodeExecution):
53+
res_type: str = "error"
54+
name: str
55+
value: str
56+
traceback: str
57+
58+
59+
# 总返回类型
60+
OutputItem = Union[StdOutModel, StdErrModel, ResultModel, ErrorModel]
2661

2762

2863
# 1. 只带 code
2964
# 2. 只带 code result
3065
class CoderMessage(AgentMessage):
3166
agent_type: AgentType = AgentType.CODER
3267
code: str | None = None
33-
code_result: CodeExecutionResult | None = None
68+
code_results: list[OutputItem] | None = None
3469

3570

3671
class WriterMessage(AgentMessage):

backend/app/tools/code_interpreter.py

Lines changed: 124 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
import os
22
import re
33
from e2b_code_interpreter import AsyncSandbox
4-
from app.schemas.response import CodeExecutionResult, CoderMessage, SystemMessage
4+
from app.schemas.response import (
5+
CoderMessage,
6+
ErrorModel,
7+
OutputItem,
8+
ResultModel,
9+
StdErrModel,
10+
StdOutModel,
11+
SystemMessage,
12+
)
513
from app.utils.enums import AgentType
614
from app.utils.redis_manager import redis_manager
715
from app.utils.notebook_serializer import NotebookSerializer
@@ -122,7 +130,7 @@ async def execute_code(self, code: str) -> tuple[str, bool, str]:
122130
self.notebook_serializer.add_code_cell_to_notebook(code)
123131

124132
text_to_gpt: list[str] = []
125-
content_to_display = []
133+
content_to_display: list[OutputItem] | None = []
126134
error_occurred: bool = False
127135
error_message: str = ""
128136

@@ -147,7 +155,13 @@ async def execute_code(self, code: str) -> tuple[str, bool, str]:
147155
error_message = self._truncate_text(error_message)
148156
logger.error(f"执行错误: {error_message}")
149157
text_to_gpt.append(delete_color_control_char(error_message))
150-
content_to_display.append(("error", error_message))
158+
content_to_display.append(
159+
ErrorModel(
160+
name=execution.error.name,
161+
value=execution.error.value,
162+
traceback=execution.error.traceback,
163+
)
164+
)
151165
# 处理标准输出和标准错误
152166

153167
if execution.logs:
@@ -156,48 +170,124 @@ async def execute_code(self, code: str) -> tuple[str, bool, str]:
156170
stdout_str = self._truncate_text(stdout_str)
157171
logger.info(f"标准输出: {stdout_str}")
158172
text_to_gpt.append(stdout_str)
159-
content_to_display.append(("text", stdout_str))
173+
content_to_display.append(
174+
StdOutModel(msg="\n".join(execution.logs.stdout))
175+
)
160176
self.notebook_serializer.add_code_cell_output_to_notebook(stdout_str)
161177

162178
if execution.logs.stderr:
163179
stderr_str = "\n".join(execution.logs.stderr)
164180
stderr_str = self._truncate_text(stderr_str)
165181
logger.warning(f"标准错误: {stderr_str}")
166182
text_to_gpt.append(stderr_str)
167-
content_to_display.append(("error", stderr_str))
183+
content_to_display.append(
184+
StdErrModel(msg="\n".join(execution.logs.stderr))
185+
)
168186

169187
# 处理执行结果
170188
if execution.results:
171189
for result in execution.results:
172-
# 处理主要结果
173-
if result.is_main_result and result.text:
174-
result_text = self._truncate_text(result.text)
175-
logger.info(f"主要结果: {result_text}")
176-
text_to_gpt.append(result_text)
177-
content_to_display.append(("text", result_text))
178-
self.notebook_serializer.add_code_cell_output_to_notebook(
179-
result_text
190+
# 1. 文本格式
191+
if str(result):
192+
content_to_display.append(
193+
ResultModel(type="result", format="text", msg=str(result))
194+
)
195+
# 2. HTML格式
196+
if result._repr_html_():
197+
content_to_display.append(
198+
ResultModel(
199+
type="result", format="html", msg=result._repr_html_()
200+
)
201+
)
202+
# 3. Markdown格式
203+
if result._repr_markdown_():
204+
content_to_display.append(
205+
ResultModel(
206+
type="result",
207+
format="markdown",
208+
msg=result._repr_markdown_(),
209+
)
210+
)
211+
# 4. PNG图片(base64字符串,前端可直接渲染)
212+
if result._repr_png_():
213+
content_to_display.append(
214+
ResultModel(
215+
type="result", format="png", msg=result._repr_png_()
216+
)
217+
)
218+
# 5. JPEG图片
219+
if result._repr_jpeg_():
220+
content_to_display.append(
221+
ResultModel(
222+
type="result", format="jpeg", msg=result._repr_jpeg_()
223+
)
224+
)
225+
# 6. SVG
226+
if result._repr_svg_():
227+
content_to_display.append(
228+
ResultModel(
229+
type="result", format="svg", msg=result._repr_svg_()
230+
)
231+
)
232+
# 7. PDF
233+
if result._repr_pdf_():
234+
content_to_display.append(
235+
ResultModel(
236+
type="result", format="pdf", msg=result._repr_pdf_()
237+
)
238+
)
239+
# 8. LaTeX
240+
if result._repr_latex_():
241+
content_to_display.append(
242+
ResultModel(
243+
type="result", format="latex", msg=result._repr_latex_()
244+
)
245+
)
246+
# 9. JSON
247+
if result._repr_json_():
248+
content_to_display.append(
249+
ResultModel(
250+
type="result",
251+
format="json",
252+
msg=json.dumps(result._repr_json_()),
253+
)
180254
)
255+
# 10. JavaScript
256+
if result._repr_javascript_():
257+
content_to_display.append(
258+
ResultModel(
259+
type="result",
260+
format="javascript",
261+
msg=result._repr_javascript_(),
262+
)
263+
)
264+
# 处理主要结果
265+
# if result.is_main_result and result.text:
266+
# result_text = self._truncate_text(result.text)
267+
# logger.info(f"主要结果: {result_text}")
268+
# text_to_gpt.append(result_text)
269+
# self.notebook_serializer.add_code_cell_output_to_notebook(
270+
# result_text
271+
# )
181272

182-
# 处理图表结果
183-
if result.chart:
184-
logger.info("发现图表结果")
185-
chart_data = result.chart.to_dict()
186-
chart_str = str(chart_data)
187-
# if len(chart_str) > 1000: # 限制图表数据大小
188-
# chart_str = "图表数据过大,已省略"
189-
text_to_gpt.append(chart_str)
190-
content_to_display.append(("chart", chart_data))
191-
192-
# 保存到分段内容
193-
for val in content_to_display:
194-
self.add_section(val[0])
195-
self.add_content(val[0], val[1])
196-
await self._push_to_websocket(val[0], val[1])
273+
# 限制返回的文本总长度
197274

198-
logger.info("执行结果已推送到WebSocket")
275+
for item in content_to_display:
276+
if isinstance(item, dict):
277+
if item.get("type") in ["stdout", "stderr", "error"]:
278+
text_to_gpt.append(item.get("content") or item.get("value") or "")
279+
elif isinstance(item, ResultModel):
280+
if item.format in ["text", "html", "markdown", "json"]:
281+
text_to_gpt.append(f"[{item.format}]\n{item.msg}")
282+
elif item.format in ["png", "jpeg", "svg", "pdf"]:
283+
text_to_gpt.append(
284+
f"[{item.format} 图片已生成,内容为 base64,未展示]"
285+
)
286+
287+
# 保存到分段内容
288+
## TODO: Base64 等图像需要优化
289+
await self._push_to_websocket(content_to_display)
199290

200-
# 限制返回的文本总长度
201291
combined_text = "\n".join(text_to_gpt)
202292

203293
return (
@@ -206,17 +296,12 @@ async def execute_code(self, code: str) -> tuple[str, bool, str]:
206296
error_message,
207297
)
208298

209-
async def _push_to_websocket(self, res_type, msg):
210-
# 如果msg不是字符串,转为json字符串
211-
if not isinstance(msg, str):
212-
msg = json.dumps(msg, ensure_ascii=False)
213-
code_execution_result = CodeExecutionResult(
214-
res_type=res_type,
215-
msg=msg,
216-
)
299+
async def _push_to_websocket(self, content_to_display: list[OutputItem] | None):
300+
logger.info("执行结果已推送到WebSocket")
301+
217302
agent_msg = CoderMessage(
218303
agent_type=AgentType.CODER,
219-
code_result=code_execution_result,
304+
code_results=content_to_display,
220305
)
221306
logger.debug(f"发送消息: {agent_msg.model_dump_json()}")
222307
await redis_manager.publish_message(

frontend/src/pages/task/index.vue

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ const chatMessages = computed(() =>
6161
// 有 code_result 的 CoderAgent 消息不显示
6262
return false
6363
}
64+
// writer agent 不显示 ## TODO writer 应该显示
6465
// 其他 agent 或 system 消息正常显示
6566
return msg.msg_type === 'agent' && msg.content || msg.msg_type === 'system'
6667
}

0 commit comments

Comments
 (0)