Skip to content

Commit ad78a95

Browse files
committed
feat(agui): add error event handling and improve protocol event sequencing
Added comprehensive error event handling for LangGraph integrations including on_tool_error, on_llm_error, on_chain_error, and on_retriever_error events. Enhanced AG-UI protocol event sequencing with proper boundary management, ensuring correct order of TEXT_MESSAGE and TOOL_CALL events. Implemented RUN_ERROR handling that properly terminates event streams without sending subsequent events. Updated AgentRequest to use raw_request object instead of separate body/headers fields for better request access.

新增了 LangGraph 集成的全面错误事件处理,包括 on_tool_error、on_llm_error、on_chain_error 和 on_retriever_error 事件。增强了 AG-UI 协议事件序列,确保 TEXT_MESSAGE 和 TOOL_CALL 事件的正确顺序。实现了 RUN_ERROR 处理,在发生错误时正确终止事件流,不再发送后续事件。更新了 AgentRequest 使用 raw_request 对象替代独立的 body/headers 字段以更好地访问请求。

BREAKING CHANGE: AgentRequest body and headers fields replaced with raw_request object
重大变更:AgentRequest 的 body 和 headers 字段被 raw_request 对象替代

Change-Id: Ibc612068239977c3d01a338ba8d34992b988e451
Signed-off-by: OhYee <[email protected]>
1 parent 14496d3 commit ad78a95

File tree

9 files changed

+1714
-64
lines changed

9 files changed

+1714
-64
lines changed

agentrun/integration/langgraph/agent_converter.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,109 @@ def _convert_astream_events_event(
689689
# 无状态模式下不处理,避免重复
690690
pass
691691

692+
# 6. 工具错误
693+
elif event_type == "on_tool_error":
694+
run_id = event_dict.get("run_id", "")
695+
error = data.get("error")
696+
tool_input_raw = data.get("input", {})
697+
tool_name = event_dict.get("name", "")
698+
# 优先使用 runtime 中的原始 tool_call_id
699+
tool_call_id = _extract_tool_call_id(tool_input_raw) or run_id
700+
701+
# 格式化错误信息
702+
error_message = ""
703+
if error is not None:
704+
if isinstance(error, Exception):
705+
error_message = f"{type(error).__name__}: {str(error)}"
706+
elif isinstance(error, str):
707+
error_message = error
708+
else:
709+
error_message = str(error)
710+
711+
# 发送 ERROR 事件
712+
yield AgentResult(
713+
event=EventType.ERROR,
714+
data={
715+
"message": (
716+
f"Tool '{tool_name}' error: {error_message}"
717+
if tool_name
718+
else error_message
719+
),
720+
"code": "TOOL_ERROR",
721+
"tool_call_id": tool_call_id,
722+
},
723+
)
724+
725+
# 7. LLM 错误
726+
elif event_type == "on_llm_error":
727+
error = data.get("error")
728+
error_message = ""
729+
if error is not None:
730+
if isinstance(error, Exception):
731+
error_message = f"{type(error).__name__}: {str(error)}"
732+
elif isinstance(error, str):
733+
error_message = error
734+
else:
735+
error_message = str(error)
736+
737+
yield AgentResult(
738+
event=EventType.ERROR,
739+
data={
740+
"message": f"LLM error: {error_message}",
741+
"code": "LLM_ERROR",
742+
},
743+
)
744+
745+
# 8. Chain 错误
746+
elif event_type == "on_chain_error":
747+
error = data.get("error")
748+
chain_name = event_dict.get("name", "")
749+
error_message = ""
750+
if error is not None:
751+
if isinstance(error, Exception):
752+
error_message = f"{type(error).__name__}: {str(error)}"
753+
elif isinstance(error, str):
754+
error_message = error
755+
else:
756+
error_message = str(error)
757+
758+
yield AgentResult(
759+
event=EventType.ERROR,
760+
data={
761+
"message": (
762+
f"Chain '{chain_name}' error: {error_message}"
763+
if chain_name
764+
else error_message
765+
),
766+
"code": "CHAIN_ERROR",
767+
},
768+
)
769+
770+
# 9. Retriever 错误
771+
elif event_type == "on_retriever_error":
772+
error = data.get("error")
773+
retriever_name = event_dict.get("name", "")
774+
error_message = ""
775+
if error is not None:
776+
if isinstance(error, Exception):
777+
error_message = f"{type(error).__name__}: {str(error)}"
778+
elif isinstance(error, str):
779+
error_message = error
780+
else:
781+
error_message = str(error)
782+
783+
yield AgentResult(
784+
event=EventType.ERROR,
785+
data={
786+
"message": (
787+
f"Retriever '{retriever_name}' error: {error_message}"
788+
if retriever_name
789+
else error_message
790+
),
791+
"code": "RETRIEVER_ERROR",
792+
},
793+
)
794+
692795

693796
# =============================================================================
694797
# 主要 API

agentrun/server/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,18 @@
5959
... yield f"当前时间: {result}"
6060
6161
Example (访问原始请求):
62-
>>> def invoke_agent(request: AgentRequest):
62+
>>> async def invoke_agent(request: AgentRequest):
6363
... # 访问当前协议
6464
... protocol = request.protocol # "openai" 或 "agui"
6565
...
6666
... # 访问原始请求头
67-
... auth = request.headers.get("Authorization")
67+
... auth = request.raw_request.headers.get("Authorization")
68+
...
69+
... # 访问查询参数
70+
... params = request.raw_request.query_params
6871
...
69-
... # 访问原始请求体
70-
... custom_field = request.body.get("custom_field")
72+
... # 访问客户端 IP
73+
... client_ip = request.raw_request.client.host if request.raw_request.client else None
7174
...
7275
... return "Hello, world!"
7376
"""

agentrun/server/agui_protocol.py

Lines changed: 77 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -192,17 +192,13 @@ async def parse_request(
192192
# 解析工具列表
193193
tools = self._parse_tools(request_data.get("tools"))
194194

195-
# 提取原始请求头
196-
raw_headers = dict(request.headers)
197-
198195
# 构建 AgentRequest
199196
agent_request = AgentRequest(
200197
protocol="agui", # 设置协议名称
201198
messages=messages,
202199
stream=True, # AG-UI 总是流式
203200
tools=tools,
204-
body=request_data,
205-
headers=raw_headers,
201+
raw_request=request, # 保留原始请求对象
206202
)
207203

208204
return agent_request, context
@@ -295,19 +291,26 @@ async def _format_stream(
295291
- TEXT_MESSAGE_START / TEXT_MESSAGE_END(文本边界)
296292
- TOOL_CALL_START / TOOL_CALL_END(工具调用边界)
297293
294+
注意:RUN_ERROR 之后不能再发送任何事件(包括 RUN_FINISHED)
295+
298296
Args:
299297
event_stream: AgentEvent 流
300298
context: 上下文信息
301299
302300
Yields:
303301
SSE 格式的字符串
304302
"""
305-
message_id = str(uuid.uuid4())
306-
307-
# 状态追踪
308-
text_started = False
303+
# 状态追踪(使用可变容器以便在 _process_event_with_boundaries 中更新)
304+
# text_state: {"started": bool, "ended": bool, "message_id": str}
305+
text_state: Dict[str, Any] = {
306+
"started": False,
307+
"ended": False,
308+
"message_id": str(uuid.uuid4()),
309+
}
309310
# 工具调用状态:{tool_id: {"started": bool, "ended": bool}}
310311
tool_call_states: Dict[str, Dict[str, bool]] = {}
312+
# 错误状态:RUN_ERROR 后不能再发送任何事件
313+
run_errored = False
311314

312315
# 发送 RUN_STARTED
313316
yield self._encoder.encode(
@@ -318,24 +321,24 @@ async def _format_stream(
318321
)
319322

320323
async for event in event_stream:
324+
# RUN_ERROR 后不再处理任何事件
325+
if run_errored:
326+
continue
327+
328+
# 检查是否是错误事件
329+
if event.event == EventType.ERROR:
330+
run_errored = True
331+
321332
# 处理边界事件注入
322333
for sse_data in self._process_event_with_boundaries(
323-
event, context, message_id, text_started, tool_call_states
334+
event, context, text_state, tool_call_states
324335
):
325336
if sse_data:
326337
yield sse_data
327338

328-
# 更新状态
329-
if event.event == EventType.TEXT:
330-
text_started = True
331-
elif event.event == EventType.TOOL_CALL_CHUNK:
332-
tool_id = event.data.get("id", "")
333-
if tool_id:
334-
if tool_id not in tool_call_states:
335-
tool_call_states[tool_id] = {
336-
"started": True,
337-
"ended": False,
338-
}
339+
# RUN_ERROR 后不发送任何清理事件
340+
if run_errored:
341+
return
339342

340343
# 结束所有未结束的工具调用
341344
for tool_id, state in tool_call_states.items():
@@ -344,10 +347,10 @@ async def _format_stream(
344347
ToolCallEndEvent(tool_call_id=tool_id)
345348
)
346349

347-
# 发送 TEXT_MESSAGE_END(如果有文本消息
348-
if text_started:
350+
# 发送 TEXT_MESSAGE_END(如果有文本消息且未结束)
351+
if text_state["started"] and not text_state["ended"]:
349352
yield self._encoder.encode(
350-
TextMessageEndEvent(message_id=message_id)
353+
TextMessageEndEvent(message_id=text_state["message_id"])
351354
)
352355

353356
# 发送 RUN_FINISHED
@@ -362,17 +365,15 @@ def _process_event_with_boundaries(
362365
self,
363366
event: AgentEvent,
364367
context: Dict[str, Any],
365-
message_id: str,
366-
text_started: bool,
368+
text_state: Dict[str, Any],
367369
tool_call_states: Dict[str, Dict[str, bool]],
368370
) -> Iterator[str]:
369371
"""处理事件并注入边界事件
370372
371373
Args:
372374
event: 用户事件
373375
context: 上下文
374-
message_id: 消息 ID
375-
text_started: 文本是否已开始
376+
text_state: 文本状态 {"started": bool, "ended": bool, "message_id": str}
376377
tool_call_states: 工具调用状态
377378
378379
Yields:
@@ -391,17 +392,31 @@ def _process_event_with_boundaries(
391392

392393
# TEXT 事件:在首个 TEXT 前注入 TEXT_MESSAGE_START
393394
if event.event == EventType.TEXT:
394-
if not text_started:
395+
# AG-UI 协议要求:发送 TEXT_MESSAGE_START 前必须先结束所有未结束的 TOOL_CALL
396+
for tool_id, state in tool_call_states.items():
397+
if state["started"] and not state["ended"]:
398+
yield self._encoder.encode(
399+
ToolCallEndEvent(tool_call_id=tool_id)
400+
)
401+
state["ended"] = True
402+
403+
# 如果文本消息未开始,或者之前已结束(需要重新开始新消息)
404+
if not text_state["started"] or text_state["ended"]:
405+
# 每个新文本消息需要新的 messageId
406+
if text_state["ended"]:
407+
text_state["message_id"] = str(uuid.uuid4())
395408
yield self._encoder.encode(
396409
TextMessageStartEvent(
397-
message_id=message_id,
410+
message_id=text_state["message_id"],
398411
role="assistant",
399412
)
400413
)
414+
text_state["started"] = True
415+
text_state["ended"] = False
401416

402417
# 发送 TEXT_MESSAGE_CONTENT
403418
agui_event = TextMessageContentEvent(
404-
message_id=message_id,
419+
message_id=text_state["message_id"],
405420
delta=event.data.get("delta", ""),
406421
)
407422
if event.addition:
@@ -422,6 +437,14 @@ def _process_event_with_boundaries(
422437
tool_id = event.data.get("id", "")
423438
tool_name = event.data.get("name", "")
424439

440+
# 如果文本消息未结束,先结束文本消息
441+
# AG-UI 协议要求:发送 TOOL_CALL_START 前必须先结束 TEXT_MESSAGE
442+
if text_state["started"] and not text_state["ended"]:
443+
yield self._encoder.encode(
444+
TextMessageEndEvent(message_id=text_state["message_id"])
445+
)
446+
text_state["ended"] = True
447+
425448
if tool_id and tool_id not in tool_call_states:
426449
# 首次见到这个工具调用,发送 TOOL_CALL_START
427450
yield self._encoder.encode(
@@ -445,6 +468,14 @@ def _process_event_with_boundaries(
445468
if event.event == EventType.TOOL_RESULT:
446469
tool_id = event.data.get("id", "")
447470

471+
# 如果文本消息未结束,先结束文本消息
472+
# AG-UI 协议要求:发送 TOOL_CALL_START 前必须先结束 TEXT_MESSAGE
473+
if text_state["started"] and not text_state["ended"]:
474+
yield self._encoder.encode(
475+
TextMessageEndEvent(message_id=text_state["message_id"])
476+
)
477+
text_state["ended"] = True
478+
448479
# 如果工具调用未开始,先补充 START
449480
if tool_id and tool_id not in tool_call_states:
450481
yield self._encoder.encode(
@@ -482,6 +513,21 @@ def _process_event_with_boundaries(
482513

483514
# ERROR 事件
484515
if event.event == EventType.ERROR:
516+
# AG-UI 协议要求:发送 RUN_ERROR 前必须先结束所有未结束的 TOOL_CALL
517+
for tool_id, state in tool_call_states.items():
518+
if state["started"] and not state["ended"]:
519+
yield self._encoder.encode(
520+
ToolCallEndEvent(tool_call_id=tool_id)
521+
)
522+
state["ended"] = True
523+
524+
# AG-UI 协议要求:发送 RUN_ERROR 前必须先结束文本消息
525+
if text_state["started"] and not text_state["ended"]:
526+
yield self._encoder.encode(
527+
TextMessageEndEvent(message_id=text_state["message_id"])
528+
)
529+
text_state["ended"] = True
530+
485531
yield self._encoder.encode(
486532
RunErrorEvent(
487533
message=event.data.get("message", ""),

agentrun/server/invoker.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -307,19 +307,3 @@ def _is_iterator(self, obj: Any) -> bool:
307307
if isinstance(obj, (str, bytes, dict, list, AgentEvent)):
308308
return False
309309
return hasattr(obj, "__iter__") or hasattr(obj, "__aiter__")
310-
311-
def _get_thread_id(self, request: AgentRequest) -> str:
312-
"""获取 thread ID"""
313-
return (
314-
request.body.get("threadId")
315-
or request.body.get("thread_id")
316-
or str(uuid.uuid4())
317-
)
318-
319-
def _get_run_id(self, request: AgentRequest) -> str:
320-
"""获取 run ID"""
321-
return (
322-
request.body.get("runId")
323-
or request.body.get("run_id")
324-
or str(uuid.uuid4())
325-
)

0 commit comments

Comments
 (0)