Skip to content

Commit 5d56c14

Browse files
authored
refactor(agui_protocol): remove copilotkit compatibility mode and simplify tool call handling (#21)
this change removes the copilotkit_compatibility mode from the AGUI protocol implementation, simplifying the tool call state management and event processing logic. the uuid mapping and event queuing functionality has been removed, allowing for true parallel tool call support according to the AG-UI specification. test cases have been updated to reflect the new parallel behavior and remove compatibility mode specific tests. 移除 AGUI 协议中的 copilotkit_compatibility 模式并简化工具调用处理 此更改从 AGUI 协议实现中删除了 copilotkit_compatibility 模式,简化了工具调用状态管理和事件处理逻辑。uuid 映射和事件队列功能已被删除,允许根据 AG-UI 规范支持真正的并行工具调用。测试用例已更新以反映新的并行行为并删除兼容性模式特定的测试。 Change-Id: I099353b835d4761f15fdec4b42f82c9514930a5e Signed-off-by: OhYee <[email protected]>
1 parent 9b47c50 commit 5d56c14

File tree

5 files changed

+365
-631
lines changed

5 files changed

+365
-631
lines changed

agentrun/server/agui_protocol.py

Lines changed: 12 additions & 258 deletions
Original file line numberDiff line numberDiff line change
@@ -95,45 +95,10 @@ class ToolCallState:
9595

9696
@dataclass
9797
class StreamStateMachine:
98-
copilotkit_compatibility: bool
9998
text: TextState = field(default_factory=TextState)
10099
tool_call_states: Dict[str, ToolCallState] = field(default_factory=dict)
101100
tool_result_chunks: Dict[str, List[str]] = field(default_factory=dict)
102-
uuid_to_tool_call_id: Dict[str, str] = field(default_factory=dict)
103101
run_errored: bool = False
104-
active_tool_id: Optional[str] = None
105-
pending_events: List["AgentEvent"] = field(default_factory=list)
106-
107-
@staticmethod
108-
def _is_uuid_like(value: Optional[str]) -> bool:
109-
if not value:
110-
return False
111-
try:
112-
uuid.UUID(str(value))
113-
return True
114-
except (ValueError, TypeError, AttributeError):
115-
return False
116-
117-
def resolve_tool_id(self, tool_id: str, tool_name: str) -> str:
118-
"""将 UUID 形式的 ID 映射到已有的 call_xxx ID,避免模糊匹配。"""
119-
if not tool_id:
120-
return ""
121-
if not self._is_uuid_like(tool_id):
122-
return tool_id
123-
if tool_id in self.uuid_to_tool_call_id:
124-
return self.uuid_to_tool_call_id[tool_id]
125-
126-
candidates = [
127-
existing_id
128-
for existing_id, state in self.tool_call_states.items()
129-
if not self._is_uuid_like(existing_id)
130-
and state.started
131-
and (state.name == tool_name or not tool_name)
132-
]
133-
if len(candidates) == 1:
134-
self.uuid_to_tool_call_id[tool_id] = candidates[0]
135-
return candidates[0]
136-
return tool_id
137102

138103
def end_all_tools(
139104
self, encoder: EventEncoder, exclude: Optional[str] = None
@@ -205,10 +170,6 @@ class AGUIProtocolHandler(BaseProtocolHandler):
205170
def __init__(self, config: Optional[ServerConfig] = None):
206171
self._config = config.agui if config else None
207172
self._encoder = EventEncoder()
208-
# 是否串行化工具调用(兼容 CopilotKit 等前端)
209-
self._copilotkit_compatibility = pydash.get(
210-
self._config, "copilotkit_compatibility", False
211-
)
212173

213174
def get_prefix(self) -> str:
214175
"""AG-UI 协议建议使用 /ag-ui/agent 前缀"""
@@ -402,9 +363,7 @@ async def _format_stream(
402363
Yields:
403364
SSE 格式的字符串
404365
"""
405-
state = StreamStateMachine(
406-
copilotkit_compatibility=self._copilotkit_compatibility
407-
)
366+
state = StreamStateMachine()
408367

409368
# 发送 RUN_STARTED
410369
yield self._encoder.encode(
@@ -414,37 +373,6 @@ async def _format_stream(
414373
)
415374
)
416375

417-
# 辅助函数:处理队列中的所有事件
418-
def process_pending_queue() -> Iterator[str]:
419-
"""处理队列中的所有待处理事件"""
420-
while state.pending_events:
421-
pending_event = state.pending_events.pop(0)
422-
pending_tool_id = (
423-
pending_event.data.get("id", "")
424-
if pending_event.data
425-
else ""
426-
)
427-
428-
# 如果是新的工具调用,设置为活跃
429-
if (
430-
pending_event.event == EventType.TOOL_CALL_CHUNK
431-
or pending_event.event == EventType.TOOL_CALL
432-
) and state.active_tool_id is None:
433-
state.active_tool_id = pending_tool_id
434-
435-
for sse_data in self._process_event_with_boundaries(
436-
pending_event,
437-
context,
438-
state,
439-
):
440-
if sse_data:
441-
yield sse_data
442-
443-
# 如果处理的是 TOOL_RESULT,检查是否需要继续处理队列
444-
if pending_event.event == EventType.TOOL_RESULT:
445-
if pending_tool_id == state.active_tool_id:
446-
state.active_tool_id = None
447-
448376
async for event in event_stream:
449377
# RUN_ERROR 后不再处理任何事件
450378
if state.run_errored:
@@ -454,95 +382,6 @@ def process_pending_queue() -> Iterator[str]:
454382
if event.event == EventType.ERROR:
455383
state.run_errored = True
456384

457-
# 在 copilotkit_compatibility=True 模式下,实现严格的工具调用序列化
458-
# 当一个工具调用正在进行时,其他工具的事件会被放入队列
459-
if self._copilotkit_compatibility and not state.run_errored:
460-
original_tool_id = (
461-
event.data.get("id", "") if event.data else ""
462-
)
463-
tool_name = event.data.get("name", "") if event.data else ""
464-
resolved_tool_id = state.resolve_tool_id(
465-
original_tool_id, tool_name
466-
)
467-
if resolved_tool_id and event.data is not None:
468-
event.data["id"] = resolved_tool_id
469-
tool_id = resolved_tool_id
470-
else:
471-
tool_id = original_tool_id
472-
473-
# 处理 TOOL_CALL_CHUNK 事件
474-
if event.event == EventType.TOOL_CALL_CHUNK:
475-
if state.active_tool_id is None:
476-
# 没有活跃的工具调用,直接处理
477-
state.active_tool_id = tool_id
478-
elif tool_id != state.active_tool_id:
479-
# 有其他活跃的工具调用,放入队列
480-
state.pending_events.append(event)
481-
continue
482-
# 如果是同一个工具调用,继续处理
483-
484-
# 处理 TOOL_CALL 事件
485-
elif event.event == EventType.TOOL_CALL:
486-
# TOOL_CALL 事件主要用于 UUID 到 call_xxx ID 的映射
487-
# 在 copilotkit 模式下:
488-
# 1. 结束当前活跃的工具调用(如果有)
489-
# 2. 处理队列中的事件
490-
# 3. 不将 TOOL_CALL 事件的 tool_id 设置为活跃工具
491-
if self._copilotkit_compatibility:
492-
if state.active_tool_id is not None:
493-
for sse_data in state.end_all_tools(self._encoder):
494-
yield sse_data
495-
state.active_tool_id = None
496-
# 处理队列中的事件
497-
if state.pending_events:
498-
for sse_data in process_pending_queue():
499-
yield sse_data
500-
501-
# 处理 TOOL_RESULT 事件
502-
elif event.event == EventType.TOOL_RESULT:
503-
actual_tool_id = resolved_tool_id or tool_id
504-
505-
# 如果不是当前活跃工具的结果,放入队列
506-
if (
507-
state.active_tool_id is not None
508-
and actual_tool_id != state.active_tool_id
509-
):
510-
state.pending_events.append(event)
511-
continue
512-
513-
# 标记工具调用已有结果
514-
if (
515-
actual_tool_id
516-
and actual_tool_id in state.tool_call_states
517-
):
518-
state.tool_call_states[actual_tool_id].has_result = True
519-
520-
# 处理当前事件
521-
for sse_data in self._process_event_with_boundaries(
522-
event,
523-
context,
524-
state,
525-
):
526-
if sse_data:
527-
yield sse_data
528-
529-
# 如果这是当前活跃工具的结果,处理队列中的事件
530-
if actual_tool_id == state.active_tool_id:
531-
state.active_tool_id = None
532-
# 处理队列中的事件
533-
for sse_data in process_pending_queue():
534-
yield sse_data
535-
continue
536-
537-
# 处理非工具相关事件(如 TEXT)
538-
# 需要先处理队列中的所有事件
539-
elif event.event == EventType.TEXT:
540-
# 先处理队列中的所有事件
541-
for sse_data in process_pending_queue():
542-
yield sse_data
543-
# 清除活跃工具 ID(因为我们要处理文本了)
544-
state.active_tool_id = None
545-
546385
# 处理边界事件注入
547386
for sse_data in self._process_event_with_boundaries(
548387
event,
@@ -552,15 +391,6 @@ def process_pending_queue() -> Iterator[str]:
552391
if sse_data:
553392
yield sse_data
554393

555-
# 在 copilotkit 兼容模式下,如果当前没有活跃工具且队列中有事件,处理队列
556-
if (
557-
self._copilotkit_compatibility
558-
and state.active_tool_id is None
559-
and state.pending_events
560-
):
561-
for sse_data in process_pending_queue():
562-
yield sse_data
563-
564394
# RUN_ERROR 后不发送任何清理事件
565395
if state.run_errored:
566396
return
@@ -629,49 +459,19 @@ def _process_event_with_boundaries(
629459

630460
# TOOL_CALL_CHUNK 事件:在首个 CHUNK 前注入 TOOL_CALL_START
631461
if event.event == EventType.TOOL_CALL_CHUNK:
632-
tool_id_raw = event.data.get("id", "")
462+
tool_id = event.data.get("id", "")
633463
tool_name = event.data.get("name", "")
634-
resolved_tool_id = state.resolve_tool_id(tool_id_raw, tool_name)
635-
tool_id = resolved_tool_id or tool_id_raw
636-
if tool_id and event.data is not None:
637-
event.data["id"] = tool_id
638464

639465
for sse_data in state.end_text_if_open(self._encoder):
640466
yield sse_data
641467

642-
if (
643-
state.copilotkit_compatibility
644-
and state._is_uuid_like(tool_id_raw)
645-
and tool_name
646-
):
647-
for existing_id, call_state in state.tool_call_states.items():
648-
if (
649-
not state._is_uuid_like(existing_id)
650-
and call_state.name == tool_name
651-
and call_state.started
652-
):
653-
if not call_state.ended:
654-
args_delta = event.data.get("args_delta", "")
655-
if args_delta:
656-
yield self._encoder.encode(
657-
ToolCallArgsEvent(
658-
tool_call_id=existing_id,
659-
delta=args_delta,
660-
)
661-
)
662-
return
663-
664468
need_start = False
665469
current_state = state.tool_call_states.get(tool_id)
666470
if tool_id:
667471
if current_state is None or current_state.ended:
668472
need_start = True
669473

670474
if need_start:
671-
if state.copilotkit_compatibility:
672-
for sse_data in state.end_all_tools(self._encoder):
673-
yield sse_data
674-
675475
yield self._encoder.encode(
676476
ToolCallStartEvent(
677477
tool_call_id=tool_id,
@@ -694,51 +494,20 @@ def _process_event_with_boundaries(
694494

695495
# TOOL_CALL 事件:完整的工具调用事件
696496
if event.event == EventType.TOOL_CALL:
697-
tool_id_raw = event.data.get("id", "")
497+
tool_id = event.data.get("id", "")
698498
tool_name = event.data.get("name", "")
699499
tool_args = event.data.get("args", "")
700-
resolved_tool_id = state.resolve_tool_id(tool_id_raw, tool_name)
701-
tool_id = resolved_tool_id or tool_id_raw
702-
if tool_id and event.data is not None:
703-
event.data["id"] = tool_id
704500

705501
for sse_data in state.end_text_if_open(self._encoder):
706502
yield sse_data
707503

708-
# 在 CopilotKit 兼容模式下,检查 UUID 映射
709-
if (
710-
state.copilotkit_compatibility
711-
and state._is_uuid_like(tool_id_raw)
712-
and tool_name
713-
):
714-
for existing_id, call_state in state.tool_call_states.items():
715-
if (
716-
not state._is_uuid_like(existing_id)
717-
and call_state.name == tool_name
718-
and call_state.started
719-
):
720-
if not call_state.ended:
721-
# UUID 事件可能包含参数,发送参数事件
722-
if tool_args:
723-
yield self._encoder.encode(
724-
ToolCallArgsEvent(
725-
tool_call_id=existing_id,
726-
delta=tool_args,
727-
)
728-
)
729-
return # UUID 事件已完成处理,不创建新的工具调用
730-
731504
need_start = False
732505
current_state = state.tool_call_states.get(tool_id)
733506
if tool_id:
734507
if current_state is None or current_state.ended:
735508
need_start = True
736509

737510
if need_start:
738-
if state.copilotkit_compatibility:
739-
for sse_data in state.end_all_tools(self._encoder):
740-
yield sse_data
741-
742511
yield self._encoder.encode(
743512
ToolCallStartEvent(
744513
tool_call_id=tool_id,
@@ -838,60 +607,45 @@ def _process_event_with_boundaries(
838607
if event.event == EventType.TOOL_RESULT:
839608
tool_id = event.data.get("id", "")
840609
tool_name = event.data.get("name", "")
841-
actual_tool_id = (
842-
state.resolve_tool_id(tool_id, tool_name)
843-
if state.copilotkit_compatibility
844-
else tool_id
845-
)
846-
if actual_tool_id and event.data is not None:
847-
event.data["id"] = actual_tool_id
848610

849611
for sse_data in state.end_text_if_open(self._encoder):
850612
yield sse_data
851613

852-
if state.copilotkit_compatibility:
853-
for sse_data in state.end_all_tools(
854-
self._encoder, exclude=actual_tool_id
855-
):
856-
yield sse_data
857-
858614
tool_state = (
859-
state.tool_call_states.get(actual_tool_id)
860-
if actual_tool_id
861-
else None
615+
state.tool_call_states.get(tool_id) if tool_id else None
862616
)
863-
if actual_tool_id and tool_state is None:
617+
if tool_id and tool_state is None:
864618
yield self._encoder.encode(
865619
ToolCallStartEvent(
866-
tool_call_id=actual_tool_id,
620+
tool_call_id=tool_id,
867621
tool_call_name=tool_name or "",
868622
)
869623
)
870624
tool_state = ToolCallState(
871625
name=tool_name, started=True, ended=False
872626
)
873-
state.tool_call_states[actual_tool_id] = tool_state
627+
state.tool_call_states[tool_id] = tool_state
874628

875629
if tool_state and tool_state.started and not tool_state.ended:
876630
yield self._encoder.encode(
877-
ToolCallEndEvent(tool_call_id=actual_tool_id)
631+
ToolCallEndEvent(tool_call_id=tool_id)
878632
)
879633
tool_state.ended = True
880634

881635
final_result = event.data.get("content") or event.data.get(
882636
"result", ""
883637
)
884-
if actual_tool_id:
885-
cached_chunks = state.pop_tool_result_chunks(actual_tool_id)
638+
if tool_id:
639+
cached_chunks = state.pop_tool_result_chunks(tool_id)
886640
if cached_chunks:
887641
final_result = cached_chunks + final_result
888642

889643
yield self._encoder.encode(
890644
ToolCallResultEvent(
891645
message_id=event.data.get(
892-
"message_id", f"tool-result-{actual_tool_id}"
646+
"message_id", f"tool-result-{tool_id}"
893647
),
894-
tool_call_id=actual_tool_id,
648+
tool_call_id=tool_id,
895649
content=final_result,
896650
role="tool",
897651
)

0 commit comments

Comments
 (0)