1313from ...messages import (
1414 BuiltinToolCallPart ,
1515 BuiltinToolReturnPart ,
16- FinalResultEvent ,
17- FunctionToolCallEvent ,
1816 FunctionToolResultEvent ,
17+ ModelResponsePart ,
1918 TextPart ,
2019 TextPartDelta ,
2120 ThinkingPart ,
@@ -79,8 +78,7 @@ class AGUIEventStream(BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT]):
7978 def __init__ (self , request : RunAgentInput ) -> None :
8079 """Initialize AG-UI event stream state."""
8180 super ().__init__ (request )
82- self .part_end : BaseEvent | None = None
83- self .thinking : bool = False
81+ self .thinking_text = False
8482 self .builtin_tool_call_ids : dict [str , str ] = {}
8583
8684 def encode_event (self , event : BaseEvent , accept : str | None = None ) -> str :
@@ -105,104 +103,99 @@ async def before_stream(self) -> AsyncIterator[BaseEvent]:
105103
106104 async def after_stream (self ) -> AsyncIterator [BaseEvent ]:
107105 """Handle an AgentRunResultEvent, cleaning up any pending state."""
108- # Emit any pending part end event
109- if self .part_end : # pragma: no branch
110- yield self .part_end
111- self .part_end = None
112-
113- # End thinking mode if still active
114- if self .thinking :
115- yield ThinkingEndEvent (
116- type = EventType .THINKING_END ,
117- )
118- self .thinking = False
119-
120- # Emit finish event
121106 yield RunFinishedEvent (
122107 thread_id = self .request .thread_id ,
123108 run_id = self .request .run_id ,
124109 )
125110
126111 async def on_error (self , error : Exception ) -> AsyncIterator [BaseEvent ]:
127112 """Handle errors during streaming."""
128- # Try to get code from exception if it has one, otherwise use class name
129- code = getattr (error , 'code' , error .__class__ .__name__ )
130- yield RunErrorEvent (message = str (error ), code = code )
113+ yield RunErrorEvent (message = str (error ))
131114
132- # Granular handlers implementation
133-
134- async def handle_text_start ( self , part : TextPart ) -> AsyncIterator [BaseEvent ]:
115+ async def handle_text_start (
116+ self , part : TextPart , previous_part : ModelResponsePart | None = None
117+ ) -> AsyncIterator [BaseEvent ]:
135118 """Handle a TextPart at start."""
136- if self .part_end :
137- yield self .part_end
138- self .part_end = None
139-
140- if self .thinking :
141- yield ThinkingEndEvent (type = EventType .THINKING_END )
142- self .thinking = False
119+ if isinstance (previous_part , TextPart ):
120+ message_id = previous_part .message_id
121+ else :
122+ message_id = self .new_message_id ()
123+ yield TextMessageStartEvent (message_id = message_id )
143124
144- message_id = self .new_message_id ()
145- yield TextMessageStartEvent (message_id = message_id )
146125 if part .content : # pragma: no branch
147126 yield TextMessageContentEvent (message_id = message_id , delta = part .content )
148- self .part_end = TextMessageEndEvent (message_id = message_id )
149127
150128 async def handle_text_delta (self , delta : TextPartDelta ) -> AsyncIterator [BaseEvent ]:
151129 """Handle a TextPartDelta."""
152130 if delta .content_delta : # pragma: no branch
153131 yield TextMessageContentEvent (message_id = self .message_id , delta = delta .content_delta )
154132
155- async def handle_thinking_start (self , part : ThinkingPart ) -> AsyncIterator [BaseEvent ]:
156- """Handle a ThinkingPart at start."""
157- if self .part_end :
158- yield self .part_end
159- self .part_end = None
133+ async def handle_text_end (
134+ self , part : TextPart , next_part : ModelResponsePart | None = None
135+ ) -> AsyncIterator [BaseEvent ]:
136+ """Handle a TextPart at end."""
137+ if not isinstance (next_part , TextPart ):
138+ yield TextMessageEndEvent (message_id = self .message_id )
160139
161- if not self .thinking :
140+ async def handle_thinking_start (
141+ self , part : ThinkingPart , previous_part : ModelResponsePart | None = None
142+ ) -> AsyncIterator [BaseEvent ]:
143+ """Handle a ThinkingPart at start."""
144+ if not isinstance (previous_part , ThinkingPart ):
162145 yield ThinkingStartEvent (type = EventType .THINKING_START )
163- self .thinking = True
164146
165147 if part .content :
166148 yield ThinkingTextMessageStartEvent (type = EventType .THINKING_TEXT_MESSAGE_START )
167149 yield ThinkingTextMessageContentEvent (type = EventType .THINKING_TEXT_MESSAGE_CONTENT , delta = part .content )
168- self .part_end = ThinkingTextMessageEndEvent ( type = EventType . THINKING_TEXT_MESSAGE_END )
150+ self .thinking_text = True
169151
170152 async def handle_thinking_delta (self , delta : ThinkingPartDelta ) -> AsyncIterator [BaseEvent ]:
171153 """Handle a ThinkingPartDelta."""
172- if delta .content_delta : # pragma: no branch
173- if not isinstance (self .part_end , ThinkingTextMessageEndEvent ):
174- yield ThinkingTextMessageStartEvent (type = EventType .THINKING_TEXT_MESSAGE_START )
175- self .part_end = ThinkingTextMessageEndEvent (type = EventType .THINKING_TEXT_MESSAGE_END )
154+ if not delta .content_delta :
155+ return
176156
177- yield ThinkingTextMessageContentEvent (
178- type = EventType .THINKING_TEXT_MESSAGE_CONTENT , delta = delta . content_delta
179- )
157+ if not self . thinking_text :
158+ yield ThinkingTextMessageStartEvent ( type = EventType .THINKING_TEXT_MESSAGE_START )
159+ self . thinking_text = True
180160
181- async def handle_tool_call_start (self , part : ToolCallPart | BuiltinToolCallPart ) -> AsyncIterator [BaseEvent ]:
182- """Handle a ToolCallPart or BuiltinToolCallPart at start."""
183- if self .part_end :
184- yield self .part_end
185- self .part_end = None
161+ yield ThinkingTextMessageContentEvent (type = EventType .THINKING_TEXT_MESSAGE_CONTENT , delta = delta .content_delta )
186162
187- if self .thinking :
163+ async def handle_thinking_end (
164+ self , part : ThinkingPart , next_part : ModelResponsePart | None = None
165+ ) -> AsyncIterator [BaseEvent ]:
166+ """Handle a ThinkingPart at end."""
167+ if self .thinking_text :
168+ yield ThinkingTextMessageEndEvent (type = EventType .THINKING_TEXT_MESSAGE_END )
169+ self .thinking_text = False
170+
171+ if not isinstance (next_part , ThinkingPart ):
188172 yield ThinkingEndEvent (type = EventType .THINKING_END )
189- self .thinking = False
190173
174+ async def handle_tool_call_start (self , part : ToolCallPart | BuiltinToolCallPart ) -> AsyncIterator [BaseEvent ]:
175+ """Handle a ToolCallPart or BuiltinToolCallPart at start."""
176+ async for e in self ._handle_tool_call_start (part ):
177+ yield e
178+
179+ async def handle_builtin_tool_call_start (self , part : BuiltinToolCallPart ) -> AsyncIterator [BaseEvent ]:
180+ """Handle a BuiltinToolCallPart at start."""
191181 tool_call_id = part .tool_call_id
192- if isinstance (part , BuiltinToolCallPart ):
193- builtin_tool_call_id = '|' .join ([BUILTIN_TOOL_CALL_ID_PREFIX , part .provider_name or '' , tool_call_id ])
194- self .builtin_tool_call_ids [tool_call_id ] = builtin_tool_call_id
195- tool_call_id = builtin_tool_call_id
182+ builtin_tool_call_id = '|' .join ([BUILTIN_TOOL_CALL_ID_PREFIX , part .provider_name or '' , tool_call_id ])
183+ self .builtin_tool_call_ids [tool_call_id ] = builtin_tool_call_id
184+ tool_call_id = builtin_tool_call_id
196185
186+ async for e in self ._handle_tool_call_start (part , tool_call_id ):
187+ yield e
188+
189+ async def _handle_tool_call_start (
190+ self , part : ToolCallPart | BuiltinToolCallPart , tool_call_id : str | None = None
191+ ) -> AsyncIterator [BaseEvent ]:
192+ """Handle a ToolCallPart or BuiltinToolCallPart at start."""
193+ tool_call_id = tool_call_id or part .tool_call_id
197194 message_id = self .message_id or self .new_message_id ()
195+
198196 yield ToolCallStartEvent (tool_call_id = tool_call_id , tool_call_name = part .tool_name , parent_message_id = message_id )
199197 if part .args :
200198 yield ToolCallArgsEvent (tool_call_id = tool_call_id , delta = part .args_as_json_str ())
201- self .part_end = ToolCallEndEvent (tool_call_id = tool_call_id )
202-
203- def handle_builtin_tool_call_start (self , part : BuiltinToolCallPart ) -> AsyncIterator [BaseEvent ]:
204- """Handle a BuiltinToolCallPart at start."""
205- return self .handle_tool_call_start (part )
206199
207200 async def handle_tool_call_delta (self , delta : ToolCallPartDelta ) -> AsyncIterator [BaseEvent ]:
208201 """Handle a ToolCallPartDelta."""
@@ -215,13 +208,16 @@ async def handle_tool_call_delta(self, delta: ToolCallPartDelta) -> AsyncIterato
215208 delta = delta .args_delta if isinstance (delta .args_delta , str ) else json .dumps (delta .args_delta ),
216209 )
217210
211+ async def handle_tool_call_end (self , part : ToolCallPart ) -> AsyncIterator [BaseEvent ]:
212+ """Handle a ToolCallPart at end."""
213+ yield ToolCallEndEvent (tool_call_id = part .tool_call_id )
214+
215+ async def handle_builtin_tool_call_end (self , part : BuiltinToolCallPart ) -> AsyncIterator [BaseEvent ]:
216+ """Handle a BuiltinToolCallPart at end."""
217+ yield ToolCallEndEvent (tool_call_id = self .builtin_tool_call_ids [part .tool_call_id ])
218+
218219 async def handle_builtin_tool_return (self , part : BuiltinToolReturnPart ) -> AsyncIterator [BaseEvent ]:
219220 """Handle a BuiltinToolReturnPart."""
220- # Emit any pending part_end event (e.g., TOOL_CALL_END) before the result
221- if self .part_end :
222- yield self .part_end
223- self .part_end = None
224-
225221 tool_call_id = self .builtin_tool_call_ids [part .tool_call_id ]
226222 yield ToolCallResultEvent (
227223 message_id = self .new_message_id (),
@@ -231,26 +227,13 @@ async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> Async
231227 content = part .model_response_str (),
232228 )
233229
234- async def handle_function_tool_call (self , event : FunctionToolCallEvent ) -> AsyncIterator [BaseEvent ]:
235- """Handle a FunctionToolCallEvent.
236-
237- This event is emitted when a function tool is called, but no AG-UI events
238- are needed at this stage since tool calls are handled in PartStartEvent.
239- """
240- return
241- yield # Make this an async generator
242-
243230 async def handle_function_tool_result (self , event : FunctionToolResultEvent ) -> AsyncIterator [BaseEvent ]:
244231 """Handle a FunctionToolResultEvent, emitting tool result events."""
245232 result = event .result
246233 if not isinstance (result , ToolReturnPart ):
234+ # TODO (DouweM): Stream RetryPromptParts to the frontend as well?
247235 return
248236
249- # Emit any pending part_end event (e.g., TOOL_CALL_END) before the result
250- if self .part_end :
251- yield self .part_end
252- self .part_end = None
253-
254237 yield ToolCallResultEvent (
255238 message_id = self .new_message_id (),
256239 type = EventType .TOOL_CALL_RESULT ,
@@ -271,11 +254,4 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A
271254 if isinstance (item , BaseEvent ): # pragma: no branch
272255 yield item
273256
274- async def handle_final_result (self , event : FinalResultEvent ) -> AsyncIterator [BaseEvent ]:
275- """Handle a FinalResultEvent.
276-
277- This event is emitted when the agent produces a final result, but no AG-UI events
278- are needed at this stage.
279- """
280- return
281- yield # Make this an async generator
257+ # TODO (DouweM): Stream ToolCallResultEvent.content as user parts?
0 commit comments