Skip to content

Commit 5fa446a

Browse files
authored
Ensure AG-UI ToolCallStartEvent doesn't use a parent_message_id from a previous request/response (#3325)
1 parent 59faf42 commit 5fa446a

File tree

3 files changed

+206
-6
lines changed

3 files changed

+206
-6
lines changed

pydantic_ai_slim/pydantic_ai/ui/_event_stream.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,31 +404,31 @@ async def before_request(self) -> AsyncIterator[EventT]:
404404
405405
Override this to inject custom events at the start of the request.
406406
"""
407-
return
407+
return # pragma: lax no cover
408408
yield # Make this an async generator
409409

410410
async def after_request(self) -> AsyncIterator[EventT]:
411411
"""Yield events after a model request is processed.
412412
413413
Override this to inject custom events at the end of the request.
414414
"""
415-
return
415+
return # pragma: lax no cover
416416
yield # Make this an async generator
417417

418418
async def before_response(self) -> AsyncIterator[EventT]:
419419
"""Yield events before a model response is processed.
420420
421421
Override this to inject custom events at the start of the response.
422422
"""
423-
return
423+
return # pragma: no cover
424424
yield # Make this an async generator
425425

426426
async def after_response(self) -> AsyncIterator[EventT]:
427427
"""Yield events after a model response is processed.
428428
429429
Override this to inject custom events at the end of the response.
430430
"""
431-
return
431+
return # pragma: lax no cover
432432
yield # Make this an async generator
433433

434434
async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[EventT]:

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ async def before_stream(self) -> AsyncIterator[BaseEvent]:
9292
run_id=self.run_input.run_id,
9393
)
9494

95+
async def before_response(self) -> AsyncIterator[BaseEvent]:
96+
# Prevent parts from a subsequent response being tied to parts from an earlier response.
97+
# See https://github.com/pydantic/pydantic-ai/issues/3316
98+
self.new_message_id()
99+
return
100+
yield # Make this an async generator
101+
95102
async def after_stream(self) -> AsyncIterator[BaseEvent]:
96103
if not self._error:
97104
yield RunFinishedEvent(
@@ -167,9 +174,11 @@ async def _handle_tool_call_start(
167174
self, part: ToolCallPart | BuiltinToolCallPart, tool_call_id: str | None = None
168175
) -> AsyncIterator[BaseEvent]:
169176
tool_call_id = tool_call_id or part.tool_call_id
170-
message_id = self.message_id or self.new_message_id()
177+
parent_message_id = self.message_id
171178

172-
yield ToolCallStartEvent(tool_call_id=tool_call_id, tool_call_name=part.tool_name, parent_message_id=message_id)
179+
yield ToolCallStartEvent(
180+
tool_call_id=tool_call_id, tool_call_name=part.tool_name, parent_message_id=parent_message_id
181+
)
173182
if part.args:
174183
yield ToolCallArgsEvent(tool_call_id=tool_call_id, delta=part.args_as_json_str())
175184

tests/test_ag_ui.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from pydantic_ai import (
2020
BuiltinToolCallPart,
2121
BuiltinToolReturnPart,
22+
FunctionToolCallEvent,
23+
FunctionToolResultEvent,
2224
ModelMessage,
2325
ModelRequest,
2426
ModelResponse,
@@ -29,6 +31,7 @@
2931
TextPart,
3032
TextPartDelta,
3133
ToolCallPart,
34+
ToolCallPartDelta,
3235
ToolReturn,
3336
ToolReturnPart,
3437
UserPromptPart,
@@ -1661,6 +1664,194 @@ async def event_generator():
16611664
)
16621665

16631666

1667+
async def test_event_stream_multiple_responses_with_tool_calls():
1668+
async def event_generator():
1669+
yield PartStartEvent(index=0, part=TextPart(content='Hello'))
1670+
yield PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' world'))
1671+
yield PartEndEvent(index=0, part=TextPart(content='Hello world'), next_part_kind='tool-call')
1672+
1673+
yield PartStartEvent(
1674+
index=1,
1675+
part=ToolCallPart(tool_name='tool_call_1', args='{}', tool_call_id='tool_call_1'),
1676+
previous_part_kind='text',
1677+
)
1678+
yield PartDeltaEvent(
1679+
index=1, delta=ToolCallPartDelta(args_delta='{"query": "Hello world"}', tool_call_id='tool_call_1')
1680+
)
1681+
yield PartEndEvent(
1682+
index=1,
1683+
part=ToolCallPart(tool_name='tool_call_1', args='{"query": "Hello world"}', tool_call_id='tool_call_1'),
1684+
next_part_kind='tool-call',
1685+
)
1686+
1687+
yield PartStartEvent(
1688+
index=2,
1689+
part=ToolCallPart(tool_name='tool_call_2', args='{}', tool_call_id='tool_call_2'),
1690+
previous_part_kind='tool-call',
1691+
)
1692+
yield PartDeltaEvent(
1693+
index=2, delta=ToolCallPartDelta(args_delta='{"query": "Goodbye world"}', tool_call_id='tool_call_2')
1694+
)
1695+
yield PartEndEvent(
1696+
index=2,
1697+
part=ToolCallPart(tool_name='tool_call_2', args='{"query": "Hello world"}', tool_call_id='tool_call_2'),
1698+
next_part_kind=None,
1699+
)
1700+
1701+
yield FunctionToolCallEvent(
1702+
part=ToolCallPart(tool_name='tool_call_1', args='{"query": "Hello world"}', tool_call_id='tool_call_1')
1703+
)
1704+
yield FunctionToolCallEvent(
1705+
part=ToolCallPart(tool_name='tool_call_2', args='{"query": "Goodbye world"}', tool_call_id='tool_call_2')
1706+
)
1707+
1708+
yield FunctionToolResultEvent(
1709+
result=ToolReturnPart(tool_name='tool_call_1', content='Hi!', tool_call_id='tool_call_1')
1710+
)
1711+
yield FunctionToolResultEvent(
1712+
result=ToolReturnPart(tool_name='tool_call_2', content='Bye!', tool_call_id='tool_call_2')
1713+
)
1714+
1715+
yield PartStartEvent(
1716+
index=0,
1717+
part=ToolCallPart(tool_name='tool_call_3', args='{}', tool_call_id='tool_call_3'),
1718+
previous_part_kind=None,
1719+
)
1720+
yield PartDeltaEvent(
1721+
index=0, delta=ToolCallPartDelta(args_delta='{"query": "Hello world"}', tool_call_id='tool_call_3')
1722+
)
1723+
yield PartEndEvent(
1724+
index=0,
1725+
part=ToolCallPart(tool_name='tool_call_3', args='{"query": "Hello world"}', tool_call_id='tool_call_3'),
1726+
next_part_kind='tool-call',
1727+
)
1728+
1729+
yield PartStartEvent(
1730+
index=1,
1731+
part=ToolCallPart(tool_name='tool_call_4', args='{}', tool_call_id='tool_call_4'),
1732+
previous_part_kind='tool-call',
1733+
)
1734+
yield PartDeltaEvent(
1735+
index=1, delta=ToolCallPartDelta(args_delta='{"query": "Goodbye world"}', tool_call_id='tool_call_4')
1736+
)
1737+
yield PartEndEvent(
1738+
index=1,
1739+
part=ToolCallPart(tool_name='tool_call_4', args='{"query": "Goodbye world"}', tool_call_id='tool_call_4'),
1740+
next_part_kind=None,
1741+
)
1742+
1743+
yield FunctionToolCallEvent(
1744+
part=ToolCallPart(tool_name='tool_call_3', args='{"query": "Hello world"}', tool_call_id='tool_call_3')
1745+
)
1746+
yield FunctionToolCallEvent(
1747+
part=ToolCallPart(tool_name='tool_call_4', args='{"query": "Goodbye world"}', tool_call_id='tool_call_4')
1748+
)
1749+
1750+
yield FunctionToolResultEvent(
1751+
result=ToolReturnPart(tool_name='tool_call_3', content='Hi!', tool_call_id='tool_call_3')
1752+
)
1753+
yield FunctionToolResultEvent(
1754+
result=ToolReturnPart(tool_name='tool_call_4', content='Bye!', tool_call_id='tool_call_4')
1755+
)
1756+
1757+
run_input = create_input(
1758+
UserMessage(
1759+
id='msg_1',
1760+
content='Tell me about Hello World',
1761+
),
1762+
)
1763+
event_stream = AGUIEventStream(run_input=run_input)
1764+
events = [
1765+
json.loads(event.removeprefix('data: '))
1766+
async for event in event_stream.encode_stream(event_stream.transform_stream(event_generator()))
1767+
]
1768+
1769+
assert events == snapshot(
1770+
[
1771+
{
1772+
'type': 'RUN_STARTED',
1773+
'threadId': (thread_id := IsSameStr()),
1774+
'runId': (run_id := IsSameStr()),
1775+
},
1776+
{'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'},
1777+
{'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': 'Hello'},
1778+
{'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': ' world'},
1779+
{'type': 'TEXT_MESSAGE_END', 'messageId': message_id},
1780+
{
1781+
'type': 'TOOL_CALL_START',
1782+
'toolCallId': 'tool_call_1',
1783+
'toolCallName': 'tool_call_1',
1784+
'parentMessageId': message_id,
1785+
},
1786+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_1', 'delta': '{}'},
1787+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_1', 'delta': '{"query": "Hello world"}'},
1788+
{'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_1'},
1789+
{
1790+
'type': 'TOOL_CALL_START',
1791+
'toolCallId': 'tool_call_2',
1792+
'toolCallName': 'tool_call_2',
1793+
'parentMessageId': message_id,
1794+
},
1795+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_2', 'delta': '{}'},
1796+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_2', 'delta': '{"query": "Goodbye world"}'},
1797+
{'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_2'},
1798+
{
1799+
'type': 'TOOL_CALL_RESULT',
1800+
'messageId': IsStr(),
1801+
'toolCallId': 'tool_call_1',
1802+
'content': 'Hi!',
1803+
'role': 'tool',
1804+
},
1805+
{
1806+
'type': 'TOOL_CALL_RESULT',
1807+
'messageId': (result_message_id := IsSameStr()),
1808+
'toolCallId': 'tool_call_2',
1809+
'content': 'Bye!',
1810+
'role': 'tool',
1811+
},
1812+
{
1813+
'type': 'TOOL_CALL_START',
1814+
'toolCallId': 'tool_call_3',
1815+
'toolCallName': 'tool_call_3',
1816+
'parentMessageId': (new_message_id := IsSameStr()),
1817+
},
1818+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_3', 'delta': '{}'},
1819+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_3', 'delta': '{"query": "Hello world"}'},
1820+
{'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_3'},
1821+
{
1822+
'type': 'TOOL_CALL_START',
1823+
'toolCallId': 'tool_call_4',
1824+
'toolCallName': 'tool_call_4',
1825+
'parentMessageId': new_message_id,
1826+
},
1827+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_4', 'delta': '{}'},
1828+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_4', 'delta': '{"query": "Goodbye world"}'},
1829+
{'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_4'},
1830+
{
1831+
'type': 'TOOL_CALL_RESULT',
1832+
'messageId': IsStr(),
1833+
'toolCallId': 'tool_call_3',
1834+
'content': 'Hi!',
1835+
'role': 'tool',
1836+
},
1837+
{
1838+
'type': 'TOOL_CALL_RESULT',
1839+
'messageId': IsStr(),
1840+
'toolCallId': 'tool_call_4',
1841+
'content': 'Bye!',
1842+
'role': 'tool',
1843+
},
1844+
{
1845+
'type': 'RUN_FINISHED',
1846+
'threadId': thread_id,
1847+
'runId': run_id,
1848+
},
1849+
]
1850+
)
1851+
1852+
assert result_message_id != new_message_id
1853+
1854+
16641855
async def test_handle_ag_ui_request():
16651856
agent = Agent(model=TestModel())
16661857
run_input = create_input(

0 commit comments

Comments
 (0)