Skip to content

Commit da80f5d

Browse files
authored
Fix AG-UI parallel tool calls (#2301)
1 parent 7eb4491 commit da80f5d

File tree

3 files changed

+449
-207
lines changed

3 files changed

+449
-207
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 60 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -291,12 +291,12 @@ async def run(
291291
if isinstance(deps, StateHandler):
292292
deps.state = run_input.state
293293

294-
history = _History.from_ag_ui(run_input.messages)
294+
messages = _messages_from_ag_ui(run_input.messages)
295295

296296
async with self.agent.iter(
297297
user_prompt=None,
298298
output_type=[output_type or self.agent.output_type, DeferredToolCalls],
299-
message_history=history.messages,
299+
message_history=messages,
300300
model=model,
301301
deps=deps,
302302
model_settings=model_settings,
@@ -305,7 +305,7 @@ async def run(
305305
infer_name=infer_name,
306306
toolsets=toolsets,
307307
) as run:
308-
async for event in self._agent_stream(run, history):
308+
async for event in self._agent_stream(run):
309309
yield encoder.encode(event)
310310
except _RunError as e:
311311
yield encoder.encode(
@@ -327,20 +327,18 @@ async def run(
327327
async def _agent_stream(
328328
self,
329329
run: AgentRun[AgentDepsT, Any],
330-
history: _History,
331330
) -> AsyncGenerator[BaseEvent, None]:
332331
"""Run the agent streaming responses using AG-UI protocol events.
333332
334333
Args:
335334
run: The agent run to process.
336-
history: The history of messages and tool calls to use for the run.
337335
338336
Yields:
339337
AG-UI Server-Sent Events (SSE).
340338
"""
341339
async for node in run:
340+
stream_ctx = _RequestStreamContext()
342341
if isinstance(node, ModelRequestNode):
343-
stream_ctx = _RequestStreamContext()
344342
async with node.stream(run.ctx) as request_stream:
345343
async for agent_event in request_stream:
346344
async for msg in self._handle_model_request_event(stream_ctx, agent_event):
@@ -352,8 +350,8 @@ async def _agent_stream(
352350
elif isinstance(node, CallToolsNode):
353351
async with node.stream(run.ctx) as handle_stream:
354352
async for event in handle_stream:
355-
if isinstance(event, FunctionToolResultEvent) and isinstance(event.result, ToolReturnPart):
356-
async for msg in self._handle_tool_result_event(event.result, history.prompt_message_id):
353+
if isinstance(event, FunctionToolResultEvent):
354+
async for msg in self._handle_tool_result_event(stream_ctx, event):
357355
yield msg
358356

359357
async def _handle_model_request_event(
@@ -391,9 +389,11 @@ async def _handle_model_request_event(
391389
delta=part.content,
392390
)
393391
elif isinstance(part, ToolCallPart): # pragma: no branch
392+
message_id = stream_ctx.message_id or stream_ctx.new_message_id()
394393
yield ToolCallStartEvent(
395394
tool_call_id=part.tool_call_id,
396395
tool_call_name=part.tool_name,
396+
parent_message_id=message_id,
397397
)
398398
stream_ctx.part_end = ToolCallEndEvent(
399399
tool_call_id=part.tool_call_id,
@@ -403,11 +403,9 @@ async def _handle_model_request_event(
403403
yield ThinkingTextMessageStartEvent(
404404
type=EventType.THINKING_TEXT_MESSAGE_START,
405405
)
406-
# Always send the content even if it's empty, as it may be
407-
# used to indicate the start of thinking.
408406
yield ThinkingTextMessageContentEvent(
409407
type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
410-
delta=part.content or '',
408+
delta=part.content,
411409
)
412410
stream_ctx.part_end = ThinkingTextMessageEndEvent(
413411
type=EventType.THINKING_TEXT_MESSAGE_END,
@@ -435,20 +433,25 @@ async def _handle_model_request_event(
435433

436434
async def _handle_tool_result_event(
437435
self,
438-
result: ToolReturnPart,
439-
prompt_message_id: str,
436+
stream_ctx: _RequestStreamContext,
437+
event: FunctionToolResultEvent,
440438
) -> AsyncGenerator[BaseEvent, None]:
441439
"""Convert a tool call result to AG-UI events.
442440
443441
Args:
444-
result: The tool call result to process.
445-
prompt_message_id: The message ID of the prompt that initiated the tool call.
442+
stream_ctx: The request stream context to manage state.
443+
event: The tool call result event to process.
446444
447445
Yields:
448446
AG-UI Server-Sent Events (SSE).
449447
"""
448+
result = event.result
449+
if not isinstance(result, ToolReturnPart):
450+
return
451+
452+
message_id = stream_ctx.new_message_id()
450453
yield ToolCallResultEvent(
451-
message_id=prompt_message_id,
454+
message_id=message_id,
452455
type=EventType.TOOL_CALL_RESULT,
453456
role='tool',
454457
tool_call_id=result.tool_call_id,
@@ -468,75 +471,55 @@ async def _handle_tool_result_event(
468471
yield item
469472

470473

471-
@dataclass
472-
class _History:
473-
"""A simple history representation for AG-UI protocol."""
474-
475-
prompt_message_id: str # The ID of the last user message.
476-
messages: list[ModelMessage]
477-
478-
@classmethod
479-
def from_ag_ui(cls, messages: list[Message]) -> _History:
480-
"""Convert a AG-UI history to a Pydantic AI one.
481-
482-
Args:
483-
messages: List of AG-UI messages to convert.
484-
485-
Returns:
486-
List of Pydantic AI model messages.
487-
"""
488-
prompt_message_id = ''
489-
result: list[ModelMessage] = []
490-
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
491-
for msg in messages:
492-
if isinstance(msg, UserMessage):
493-
prompt_message_id = msg.id
494-
result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
495-
elif isinstance(msg, AssistantMessage):
496-
if msg.tool_calls:
497-
for tool_call in msg.tool_calls:
498-
tool_calls[tool_call.id] = tool_call.function.name
499-
500-
result.append(
501-
ModelResponse(
502-
parts=[
503-
ToolCallPart(
504-
tool_name=tool_call.function.name,
505-
tool_call_id=tool_call.id,
506-
args=tool_call.function.arguments,
507-
)
508-
for tool_call in msg.tool_calls
509-
]
510-
)
511-
)
512-
513-
if msg.content:
514-
result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
515-
elif isinstance(msg, SystemMessage):
516-
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
517-
elif isinstance(msg, ToolMessage):
518-
tool_name = tool_calls.get(msg.tool_call_id)
519-
if tool_name is None: # pragma: no cover
520-
raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
    """Translate an AG-UI protocol message history into Pydantic AI model messages.

    Args:
        messages: The AG-UI conversation history to convert.

    Returns:
        The equivalent list of Pydantic AI model messages.

    Raises:
        _ToolCallNotFoundError: If a tool result references a tool call ID that
            was never seen in a preceding assistant message.
    """
    converted: list[ModelMessage] = []
    # Maps tool call IDs to tool names so later tool results can be matched up.
    tool_call_names: dict[str, str] = {}
    for message in messages:
        if isinstance(message, UserMessage):
            converted.append(ModelRequest(parts=[UserPromptPart(content=message.content)]))
        elif isinstance(message, AssistantMessage):
            if message.tool_calls:
                call_parts: list[ToolCallPart] = []
                for call in message.tool_calls:
                    tool_call_names[call.id] = call.function.name
                    call_parts.append(
                        ToolCallPart(
                            tool_name=call.function.name,
                            tool_call_id=call.id,
                            args=call.function.arguments,
                        )
                    )
                converted.append(ModelResponse(parts=call_parts))

            if message.content:
                converted.append(ModelResponse(parts=[TextPart(content=message.content)]))
        elif isinstance(message, SystemMessage):
            converted.append(ModelRequest(parts=[SystemPromptPart(content=message.content)]))
        elif isinstance(message, ToolMessage):
            tool_name = tool_call_names.get(message.tool_call_id)
            if tool_name is None:  # pragma: no cover
                raise _ToolCallNotFoundError(tool_call_id=message.tool_call_id)

            converted.append(
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name=tool_name,
                            content=message.content,
                            tool_call_id=message.tool_call_id,
                        )
                    ]
                )
            )
        elif isinstance(message, DeveloperMessage):  # pragma: no branch
            converted.append(ModelRequest(parts=[SystemPromptPart(content=message.content)]))

    return converted
540523

541524

542525
@runtime_checkable

tests/conftest.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def IsFloat(*args: Any, **kwargs: Any) -> float: ...
5050
def IsInt(*args: Any, **kwargs: Any) -> int: ...
5151
def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
5252
def IsStr(*args: Any, **kwargs: Any) -> str: ...
53+
def IsSameStr(*args: Any, **kwargs: Any) -> str: ...
5354
else:
5455
from dirty_equals import IsDatetime, IsFloat, IsInstance, IsInt, IsNow as _IsNow, IsStr
5556

@@ -59,6 +60,44 @@ def IsNow(*args: Any, **kwargs: Any):
5960
kwargs['delta'] = 10
6061
return _IsNow(*args, **kwargs)
6162

63+
class IsSameStr(IsStr):
    """A string matcher that locks onto the first value it is compared against.

    The first comparison behaves exactly like `IsStr`; every subsequent
    comparison succeeds only when the value equals the one seen first. This
    lets a test assert that an ID generated during a run is reused
    consistently across events.

    Example:
    ```python {test="skip"}
    assert events == [
        {
            'type': 'RUN_STARTED',
            'threadId': (thread_id := IsSameStr()),
            'runId': (run_id := IsSameStr()),
        },
        {'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'},
        {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': 'success '},
        {
            'type': 'TEXT_MESSAGE_CONTENT',
            'messageId': message_id,
            'delta': '(no tool calls)',
        },
        {'type': 'TEXT_MESSAGE_END', 'messageId': message_id},
        {
            'type': 'RUN_FINISHED',
            'threadId': thread_id,
            'runId': run_id,
        },
    ]
    ```
    """

    # Value captured on the first comparison; None means "not yet compared".
    _first_other: str | None = None

    def equals(self, other: Any) -> bool:
        # After the first match, only accept the exact same string again.
        if self._first_other is not None:
            return other == self._first_other
        self._first_other = other
        return super().equals(other)
100+
62101

63102
class TestEnv:
64103
__test__ = False

0 commit comments

Comments
 (0)