Skip to content

Commit 6f51053

Browse files
committed
Refactor AG-UI streaming
1 parent 03862a5 commit 6f51053

File tree

7 files changed: 316 additions (+316) and 177 deletions (-177).

pydantic_ai_slim/pydantic_ai/ui/adapter.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def load_messages(cls, messages: Sequence[MessageT]) -> list[ModelMessage]:
138138
def dump_messages(self, messages: Sequence[ModelMessage]) -> list[MessageT]:
139139
"""Dump messages to the request and return the dumped messages."""
140140

141-
@cached_property
141+
@property
142142
@abstractmethod
143143
def event_stream(self) -> BaseEventStream[RunRequestT, EventT, AgentDepsT]:
144144
"""Create an event stream for the adapter."""
@@ -165,11 +165,6 @@ def raw_state(self) -> dict[str, Any] | None:
165165
"""Get the state of the agent run."""
166166
return None
167167

168-
@property
169-
def result(self) -> AgentRunResult | None:
170-
"""Get the result of the agent run."""
171-
return self.event_stream.result
172-
173168
@property
174169
def response_headers(self) -> Mapping[str, str] | None:
175170
"""Get the response headers for the adapter."""
@@ -283,6 +278,21 @@ async def run_stream(
283278
):
284279
yield event
285280

281+
async def stream_response(self, stream: AsyncIterator[EventT], accept: str | None = None) -> Response:
282+
"""Stream a response to the client.
283+
284+
Args:
285+
stream: The stream of events to encode.
286+
accept: The accept header value for encoding format.
287+
"""
288+
return StreamingResponse(
289+
self.encode_stream(
290+
stream,
291+
accept=accept,
292+
),
293+
headers=self.response_headers,
294+
)
295+
286296
@classmethod
287297
async def dispatch_request(
288298
cls,
@@ -334,22 +344,18 @@ async def dispatch_request(
334344

335345
adapter = cls(agent=agent, request=request_data)
336346

337-
return StreamingResponse(
338-
adapter.encode_stream(
339-
adapter.run_stream(
340-
message_history=message_history,
341-
deferred_tool_results=deferred_tool_results,
342-
deps=deps,
343-
output_type=output_type,
344-
model=model,
345-
model_settings=model_settings,
346-
usage_limits=usage_limits,
347-
usage=usage,
348-
infer_name=infer_name,
349-
toolsets=toolsets,
350-
on_complete=on_complete,
351-
),
352-
accept=request.headers.get('accept'),
353-
),
354-
headers=adapter.response_headers,
347+
run_stream = adapter.run_stream(
348+
message_history=message_history,
349+
deferred_tool_results=deferred_tool_results,
350+
deps=deps,
351+
output_type=output_type,
352+
model=model,
353+
model_settings=model_settings,
354+
usage_limits=usage_limits,
355+
usage=usage,
356+
infer_name=infer_name,
357+
toolsets=toolsets,
358+
on_complete=on_complete,
355359
)
360+
361+
return await adapter.stream_response(run_stream, accept=request.headers.get('accept'))

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def dump_messages(self, messages: Sequence[ModelMessage]) -> list[Message]:
105105
# TODO (DouweM): bring in from https://github.com/pydantic/pydantic-ai/pull/3068
106106
raise NotImplementedError
107107

108-
@cached_property
108+
@property
109109
def event_stream(self) -> BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT]:
110110
"""Create an event stream for the adapter."""
111111
return AGUIEventStream(self.request)

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 67 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
from ...messages import (
1414
BuiltinToolCallPart,
1515
BuiltinToolReturnPart,
16-
FinalResultEvent,
17-
FunctionToolCallEvent,
1816
FunctionToolResultEvent,
17+
ModelResponsePart,
1918
TextPart,
2019
TextPartDelta,
2120
ThinkingPart,
@@ -79,8 +78,7 @@ class AGUIEventStream(BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT]):
7978
def __init__(self, request: RunAgentInput) -> None:
8079
"""Initialize AG-UI event stream state."""
8180
super().__init__(request)
82-
self.part_end: BaseEvent | None = None
83-
self.thinking: bool = False
81+
self.thinking_text = False
8482
self.builtin_tool_call_ids: dict[str, str] = {}
8583

8684
def encode_event(self, event: BaseEvent, accept: str | None = None) -> str:
@@ -105,104 +103,99 @@ async def before_stream(self) -> AsyncIterator[BaseEvent]:
105103

106104
async def after_stream(self) -> AsyncIterator[BaseEvent]:
107105
"""Handle an AgentRunResultEvent, cleaning up any pending state."""
108-
# Emit any pending part end event
109-
if self.part_end: # pragma: no branch
110-
yield self.part_end
111-
self.part_end = None
112-
113-
# End thinking mode if still active
114-
if self.thinking:
115-
yield ThinkingEndEvent(
116-
type=EventType.THINKING_END,
117-
)
118-
self.thinking = False
119-
120-
# Emit finish event
121106
yield RunFinishedEvent(
122107
thread_id=self.request.thread_id,
123108
run_id=self.request.run_id,
124109
)
125110

126111
async def on_error(self, error: Exception) -> AsyncIterator[BaseEvent]:
127112
"""Handle errors during streaming."""
128-
# Try to get code from exception if it has one, otherwise use class name
129-
code = getattr(error, 'code', error.__class__.__name__)
130-
yield RunErrorEvent(message=str(error), code=code)
113+
yield RunErrorEvent(message=str(error))
131114

132-
# Granular handlers implementation
133-
134-
async def handle_text_start(self, part: TextPart) -> AsyncIterator[BaseEvent]:
115+
async def handle_text_start(
116+
self, part: TextPart, previous_part: ModelResponsePart | None = None
117+
) -> AsyncIterator[BaseEvent]:
135118
"""Handle a TextPart at start."""
136-
if self.part_end:
137-
yield self.part_end
138-
self.part_end = None
139-
140-
if self.thinking:
141-
yield ThinkingEndEvent(type=EventType.THINKING_END)
142-
self.thinking = False
119+
if isinstance(previous_part, TextPart):
120+
message_id = previous_part.message_id
121+
else:
122+
message_id = self.new_message_id()
123+
yield TextMessageStartEvent(message_id=message_id)
143124

144-
message_id = self.new_message_id()
145-
yield TextMessageStartEvent(message_id=message_id)
146125
if part.content: # pragma: no branch
147126
yield TextMessageContentEvent(message_id=message_id, delta=part.content)
148-
self.part_end = TextMessageEndEvent(message_id=message_id)
149127

150128
async def handle_text_delta(self, delta: TextPartDelta) -> AsyncIterator[BaseEvent]:
151129
"""Handle a TextPartDelta."""
152130
if delta.content_delta: # pragma: no branch
153131
yield TextMessageContentEvent(message_id=self.message_id, delta=delta.content_delta)
154132

155-
async def handle_thinking_start(self, part: ThinkingPart) -> AsyncIterator[BaseEvent]:
156-
"""Handle a ThinkingPart at start."""
157-
if self.part_end:
158-
yield self.part_end
159-
self.part_end = None
133+
async def handle_text_end(
134+
self, part: TextPart, next_part: ModelResponsePart | None = None
135+
) -> AsyncIterator[BaseEvent]:
136+
"""Handle a TextPart at end."""
137+
if not isinstance(next_part, TextPart):
138+
yield TextMessageEndEvent(message_id=self.message_id)
160139

161-
if not self.thinking:
140+
async def handle_thinking_start(
141+
self, part: ThinkingPart, previous_part: ModelResponsePart | None = None
142+
) -> AsyncIterator[BaseEvent]:
143+
"""Handle a ThinkingPart at start."""
144+
if not isinstance(previous_part, ThinkingPart):
162145
yield ThinkingStartEvent(type=EventType.THINKING_START)
163-
self.thinking = True
164146

165147
if part.content:
166148
yield ThinkingTextMessageStartEvent(type=EventType.THINKING_TEXT_MESSAGE_START)
167149
yield ThinkingTextMessageContentEvent(type=EventType.THINKING_TEXT_MESSAGE_CONTENT, delta=part.content)
168-
self.part_end = ThinkingTextMessageEndEvent(type=EventType.THINKING_TEXT_MESSAGE_END)
150+
self.thinking_text = True
169151

170152
async def handle_thinking_delta(self, delta: ThinkingPartDelta) -> AsyncIterator[BaseEvent]:
171153
"""Handle a ThinkingPartDelta."""
172-
if delta.content_delta: # pragma: no branch
173-
if not isinstance(self.part_end, ThinkingTextMessageEndEvent):
174-
yield ThinkingTextMessageStartEvent(type=EventType.THINKING_TEXT_MESSAGE_START)
175-
self.part_end = ThinkingTextMessageEndEvent(type=EventType.THINKING_TEXT_MESSAGE_END)
154+
if not delta.content_delta:
155+
return
176156

177-
yield ThinkingTextMessageContentEvent(
178-
type=EventType.THINKING_TEXT_MESSAGE_CONTENT, delta=delta.content_delta
179-
)
157+
if not self.thinking_text:
158+
yield ThinkingTextMessageStartEvent(type=EventType.THINKING_TEXT_MESSAGE_START)
159+
self.thinking_text = True
180160

181-
async def handle_tool_call_start(self, part: ToolCallPart | BuiltinToolCallPart) -> AsyncIterator[BaseEvent]:
182-
"""Handle a ToolCallPart or BuiltinToolCallPart at start."""
183-
if self.part_end:
184-
yield self.part_end
185-
self.part_end = None
161+
yield ThinkingTextMessageContentEvent(type=EventType.THINKING_TEXT_MESSAGE_CONTENT, delta=delta.content_delta)
186162

187-
if self.thinking:
163+
async def handle_thinking_end(
164+
self, part: ThinkingPart, next_part: ModelResponsePart | None = None
165+
) -> AsyncIterator[BaseEvent]:
166+
"""Handle a ThinkingPart at end."""
167+
if self.thinking_text:
168+
yield ThinkingTextMessageEndEvent(type=EventType.THINKING_TEXT_MESSAGE_END)
169+
self.thinking_text = False
170+
171+
if not isinstance(next_part, ThinkingPart):
188172
yield ThinkingEndEvent(type=EventType.THINKING_END)
189-
self.thinking = False
190173

174+
async def handle_tool_call_start(self, part: ToolCallPart | BuiltinToolCallPart) -> AsyncIterator[BaseEvent]:
175+
"""Handle a ToolCallPart or BuiltinToolCallPart at start."""
176+
async for e in self._handle_tool_call_start(part):
177+
yield e
178+
179+
async def handle_builtin_tool_call_start(self, part: BuiltinToolCallPart) -> AsyncIterator[BaseEvent]:
180+
"""Handle a BuiltinToolCallPart at start."""
191181
tool_call_id = part.tool_call_id
192-
if isinstance(part, BuiltinToolCallPart):
193-
builtin_tool_call_id = '|'.join([BUILTIN_TOOL_CALL_ID_PREFIX, part.provider_name or '', tool_call_id])
194-
self.builtin_tool_call_ids[tool_call_id] = builtin_tool_call_id
195-
tool_call_id = builtin_tool_call_id
182+
builtin_tool_call_id = '|'.join([BUILTIN_TOOL_CALL_ID_PREFIX, part.provider_name or '', tool_call_id])
183+
self.builtin_tool_call_ids[tool_call_id] = builtin_tool_call_id
184+
tool_call_id = builtin_tool_call_id
196185

186+
async for e in self._handle_tool_call_start(part, tool_call_id):
187+
yield e
188+
189+
async def _handle_tool_call_start(
190+
self, part: ToolCallPart | BuiltinToolCallPart, tool_call_id: str | None = None
191+
) -> AsyncIterator[BaseEvent]:
192+
"""Handle a ToolCallPart or BuiltinToolCallPart at start."""
193+
tool_call_id = tool_call_id or part.tool_call_id
197194
message_id = self.message_id or self.new_message_id()
195+
198196
yield ToolCallStartEvent(tool_call_id=tool_call_id, tool_call_name=part.tool_name, parent_message_id=message_id)
199197
if part.args:
200198
yield ToolCallArgsEvent(tool_call_id=tool_call_id, delta=part.args_as_json_str())
201-
self.part_end = ToolCallEndEvent(tool_call_id=tool_call_id)
202-
203-
def handle_builtin_tool_call_start(self, part: BuiltinToolCallPart) -> AsyncIterator[BaseEvent]:
204-
"""Handle a BuiltinToolCallPart at start."""
205-
return self.handle_tool_call_start(part)
206199

207200
async def handle_tool_call_delta(self, delta: ToolCallPartDelta) -> AsyncIterator[BaseEvent]:
208201
"""Handle a ToolCallPartDelta."""
@@ -215,13 +208,16 @@ async def handle_tool_call_delta(self, delta: ToolCallPartDelta) -> AsyncIterato
215208
delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
216209
)
217210

211+
async def handle_tool_call_end(self, part: ToolCallPart) -> AsyncIterator[BaseEvent]:
212+
"""Handle a ToolCallPart at end."""
213+
yield ToolCallEndEvent(tool_call_id=part.tool_call_id)
214+
215+
async def handle_builtin_tool_call_end(self, part: BuiltinToolCallPart) -> AsyncIterator[BaseEvent]:
216+
"""Handle a BuiltinToolCallPart at end."""
217+
yield ToolCallEndEvent(tool_call_id=self.builtin_tool_call_ids[part.tool_call_id])
218+
218219
async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> AsyncIterator[BaseEvent]:
219220
"""Handle a BuiltinToolReturnPart."""
220-
# Emit any pending part_end event (e.g., TOOL_CALL_END) before the result
221-
if self.part_end:
222-
yield self.part_end
223-
self.part_end = None
224-
225221
tool_call_id = self.builtin_tool_call_ids[part.tool_call_id]
226222
yield ToolCallResultEvent(
227223
message_id=self.new_message_id(),
@@ -231,26 +227,13 @@ async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> Async
231227
content=part.model_response_str(),
232228
)
233229

234-
async def handle_function_tool_call(self, event: FunctionToolCallEvent) -> AsyncIterator[BaseEvent]:
235-
"""Handle a FunctionToolCallEvent.
236-
237-
This event is emitted when a function tool is called, but no AG-UI events
238-
are needed at this stage since tool calls are handled in PartStartEvent.
239-
"""
240-
return
241-
yield # Make this an async generator
242-
243230
async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> AsyncIterator[BaseEvent]:
244231
"""Handle a FunctionToolResultEvent, emitting tool result events."""
245232
result = event.result
246233
if not isinstance(result, ToolReturnPart):
234+
# TODO (DouweM): Stream RetryPromptParts to the frontend as well?
247235
return
248236

249-
# Emit any pending part_end event (e.g., TOOL_CALL_END) before the result
250-
if self.part_end:
251-
yield self.part_end
252-
self.part_end = None
253-
254237
yield ToolCallResultEvent(
255238
message_id=self.new_message_id(),
256239
type=EventType.TOOL_CALL_RESULT,
@@ -271,11 +254,4 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A
271254
if isinstance(item, BaseEvent): # pragma: no branch
272255
yield item
273256

274-
async def handle_final_result(self, event: FinalResultEvent) -> AsyncIterator[BaseEvent]:
275-
"""Handle a FinalResultEvent.
276-
277-
This event is emitted when the agent produces a final result, but no AG-UI events
278-
are needed at this stage.
279-
"""
280-
return
281-
yield # Make this an async generator
257+
# TODO (DouweM): Stream ToolCallResultEvent.content as user parts?

0 commit comments

Comments
 (0)