|
19 | 19 | from pydantic_ai import ( |
20 | 20 | BuiltinToolCallPart, |
21 | 21 | BuiltinToolReturnPart, |
| 22 | + FunctionToolCallEvent, |
| 23 | + FunctionToolResultEvent, |
22 | 24 | ModelMessage, |
23 | 25 | ModelRequest, |
24 | 26 | ModelResponse, |
|
29 | 31 | TextPart, |
30 | 32 | TextPartDelta, |
31 | 33 | ToolCallPart, |
| 34 | + ToolCallPartDelta, |
32 | 35 | ToolReturn, |
33 | 36 | ToolReturnPart, |
34 | 37 | UserPromptPart, |
@@ -1661,6 +1664,194 @@ async def event_generator(): |
1661 | 1664 | ) |
1662 | 1665 |
|
1663 | 1666 |
|
| 1667 | +async def test_event_stream_multiple_responses_with_tool_calls(): |
| 1668 | + async def event_generator(): |
| 1669 | + yield PartStartEvent(index=0, part=TextPart(content='Hello')) |
| 1670 | + yield PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' world')) |
| 1671 | + yield PartEndEvent(index=0, part=TextPart(content='Hello world'), next_part_kind='tool-call') |
| 1672 | + |
| 1673 | + yield PartStartEvent( |
| 1674 | + index=1, |
| 1675 | + part=ToolCallPart(tool_name='tool_call_1', args='{}', tool_call_id='tool_call_1'), |
| 1676 | + previous_part_kind='text', |
| 1677 | + ) |
| 1678 | + yield PartDeltaEvent( |
| 1679 | + index=1, delta=ToolCallPartDelta(args_delta='{"query": "Hello world"}', tool_call_id='tool_call_1') |
| 1680 | + ) |
| 1681 | + yield PartEndEvent( |
| 1682 | + index=1, |
| 1683 | + part=ToolCallPart(tool_name='tool_call_1', args='{"query": "Hello world"}', tool_call_id='tool_call_1'), |
| 1684 | + next_part_kind='tool-call', |
| 1685 | + ) |
| 1686 | + |
| 1687 | + yield PartStartEvent( |
| 1688 | + index=2, |
| 1689 | + part=ToolCallPart(tool_name='tool_call_2', args='{}', tool_call_id='tool_call_2'), |
| 1690 | + previous_part_kind='tool-call', |
| 1691 | + ) |
| 1692 | + yield PartDeltaEvent( |
| 1693 | + index=2, delta=ToolCallPartDelta(args_delta='{"query": "Goodbye world"}', tool_call_id='tool_call_2') |
| 1694 | + ) |
| 1695 | + yield PartEndEvent( |
| 1696 | + index=2, |
| 1697 | + part=ToolCallPart(tool_name='tool_call_2', args='{"query": "Hello world"}', tool_call_id='tool_call_2'), |
| 1698 | + next_part_kind=None, |
| 1699 | + ) |
| 1700 | + |
| 1701 | + yield FunctionToolCallEvent( |
| 1702 | + part=ToolCallPart(tool_name='tool_call_1', args='{"query": "Hello world"}', tool_call_id='tool_call_1') |
| 1703 | + ) |
| 1704 | + yield FunctionToolCallEvent( |
| 1705 | + part=ToolCallPart(tool_name='tool_call_2', args='{"query": "Goodbye world"}', tool_call_id='tool_call_2') |
| 1706 | + ) |
| 1707 | + |
| 1708 | + yield FunctionToolResultEvent( |
| 1709 | + result=ToolReturnPart(tool_name='tool_call_1', content='Hi!', tool_call_id='tool_call_1') |
| 1710 | + ) |
| 1711 | + yield FunctionToolResultEvent( |
| 1712 | + result=ToolReturnPart(tool_name='tool_call_2', content='Bye!', tool_call_id='tool_call_2') |
| 1713 | + ) |
| 1714 | + |
| 1715 | + yield PartStartEvent( |
| 1716 | + index=0, |
| 1717 | + part=ToolCallPart(tool_name='tool_call_3', args='{}', tool_call_id='tool_call_3'), |
| 1718 | + previous_part_kind=None, |
| 1719 | + ) |
| 1720 | + yield PartDeltaEvent( |
| 1721 | + index=0, delta=ToolCallPartDelta(args_delta='{"query": "Hello world"}', tool_call_id='tool_call_3') |
| 1722 | + ) |
| 1723 | + yield PartEndEvent( |
| 1724 | + index=0, |
| 1725 | + part=ToolCallPart(tool_name='tool_call_3', args='{"query": "Hello world"}', tool_call_id='tool_call_3'), |
| 1726 | + next_part_kind='tool-call', |
| 1727 | + ) |
| 1728 | + |
| 1729 | + yield PartStartEvent( |
| 1730 | + index=1, |
| 1731 | + part=ToolCallPart(tool_name='tool_call_4', args='{}', tool_call_id='tool_call_4'), |
| 1732 | + previous_part_kind='tool-call', |
| 1733 | + ) |
| 1734 | + yield PartDeltaEvent( |
| 1735 | + index=1, delta=ToolCallPartDelta(args_delta='{"query": "Goodbye world"}', tool_call_id='tool_call_4') |
| 1736 | + ) |
| 1737 | + yield PartEndEvent( |
| 1738 | + index=1, |
| 1739 | + part=ToolCallPart(tool_name='tool_call_4', args='{"query": "Goodbye world"}', tool_call_id='tool_call_4'), |
| 1740 | + next_part_kind=None, |
| 1741 | + ) |
| 1742 | + |
| 1743 | + yield FunctionToolCallEvent( |
| 1744 | + part=ToolCallPart(tool_name='tool_call_3', args='{"query": "Hello world"}', tool_call_id='tool_call_3') |
| 1745 | + ) |
| 1746 | + yield FunctionToolCallEvent( |
| 1747 | + part=ToolCallPart(tool_name='tool_call_4', args='{"query": "Goodbye world"}', tool_call_id='tool_call_4') |
| 1748 | + ) |
| 1749 | + |
| 1750 | + yield FunctionToolResultEvent( |
| 1751 | + result=ToolReturnPart(tool_name='tool_call_3', content='Hi!', tool_call_id='tool_call_3') |
| 1752 | + ) |
| 1753 | + yield FunctionToolResultEvent( |
| 1754 | + result=ToolReturnPart(tool_name='tool_call_4', content='Bye!', tool_call_id='tool_call_4') |
| 1755 | + ) |
| 1756 | + |
| 1757 | + run_input = create_input( |
| 1758 | + UserMessage( |
| 1759 | + id='msg_1', |
| 1760 | + content='Tell me about Hello World', |
| 1761 | + ), |
| 1762 | + ) |
| 1763 | + event_stream = AGUIEventStream(run_input=run_input) |
| 1764 | + events = [ |
| 1765 | + json.loads(event.removeprefix('data: ')) |
| 1766 | + async for event in event_stream.encode_stream(event_stream.transform_stream(event_generator())) |
| 1767 | + ] |
| 1768 | + |
| 1769 | + assert events == snapshot( |
| 1770 | + [ |
| 1771 | + { |
| 1772 | + 'type': 'RUN_STARTED', |
| 1773 | + 'threadId': (thread_id := IsSameStr()), |
| 1774 | + 'runId': (run_id := IsSameStr()), |
| 1775 | + }, |
| 1776 | + {'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'}, |
| 1777 | + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': 'Hello'}, |
| 1778 | + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': ' world'}, |
| 1779 | + {'type': 'TEXT_MESSAGE_END', 'messageId': message_id}, |
| 1780 | + { |
| 1781 | + 'type': 'TOOL_CALL_START', |
| 1782 | + 'toolCallId': 'tool_call_1', |
| 1783 | + 'toolCallName': 'tool_call_1', |
| 1784 | + 'parentMessageId': message_id, |
| 1785 | + }, |
| 1786 | + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_1', 'delta': '{}'}, |
| 1787 | + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_1', 'delta': '{"query": "Hello world"}'}, |
| 1788 | + {'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_1'}, |
| 1789 | + { |
| 1790 | + 'type': 'TOOL_CALL_START', |
| 1791 | + 'toolCallId': 'tool_call_2', |
| 1792 | + 'toolCallName': 'tool_call_2', |
| 1793 | + 'parentMessageId': message_id, |
| 1794 | + }, |
| 1795 | + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_2', 'delta': '{}'}, |
| 1796 | + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_2', 'delta': '{"query": "Goodbye world"}'}, |
| 1797 | + {'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_2'}, |
| 1798 | + { |
| 1799 | + 'type': 'TOOL_CALL_RESULT', |
| 1800 | + 'messageId': IsStr(), |
| 1801 | + 'toolCallId': 'tool_call_1', |
| 1802 | + 'content': 'Hi!', |
| 1803 | + 'role': 'tool', |
| 1804 | + }, |
| 1805 | + { |
| 1806 | + 'type': 'TOOL_CALL_RESULT', |
| 1807 | + 'messageId': (result_message_id := IsSameStr()), |
| 1808 | + 'toolCallId': 'tool_call_2', |
| 1809 | + 'content': 'Bye!', |
| 1810 | + 'role': 'tool', |
| 1811 | + }, |
| 1812 | + { |
| 1813 | + 'type': 'TOOL_CALL_START', |
| 1814 | + 'toolCallId': 'tool_call_3', |
| 1815 | + 'toolCallName': 'tool_call_3', |
| 1816 | + 'parentMessageId': (new_message_id := IsSameStr()), |
| 1817 | + }, |
| 1818 | + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_3', 'delta': '{}'}, |
| 1819 | + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_3', 'delta': '{"query": "Hello world"}'}, |
| 1820 | + {'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_3'}, |
| 1821 | + { |
| 1822 | + 'type': 'TOOL_CALL_START', |
| 1823 | + 'toolCallId': 'tool_call_4', |
| 1824 | + 'toolCallName': 'tool_call_4', |
| 1825 | + 'parentMessageId': new_message_id, |
| 1826 | + }, |
| 1827 | + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_4', 'delta': '{}'}, |
| 1828 | + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_4', 'delta': '{"query": "Goodbye world"}'}, |
| 1829 | + {'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_4'}, |
| 1830 | + { |
| 1831 | + 'type': 'TOOL_CALL_RESULT', |
| 1832 | + 'messageId': IsStr(), |
| 1833 | + 'toolCallId': 'tool_call_3', |
| 1834 | + 'content': 'Hi!', |
| 1835 | + 'role': 'tool', |
| 1836 | + }, |
| 1837 | + { |
| 1838 | + 'type': 'TOOL_CALL_RESULT', |
| 1839 | + 'messageId': IsStr(), |
| 1840 | + 'toolCallId': 'tool_call_4', |
| 1841 | + 'content': 'Bye!', |
| 1842 | + 'role': 'tool', |
| 1843 | + }, |
| 1844 | + { |
| 1845 | + 'type': 'RUN_FINISHED', |
| 1846 | + 'threadId': thread_id, |
| 1847 | + 'runId': run_id, |
| 1848 | + }, |
| 1849 | + ] |
| 1850 | + ) |
| 1851 | + |
| 1852 | + assert result_message_id != new_message_id |
| 1853 | + |
| 1854 | + |
1664 | 1855 | async def test_handle_ag_ui_request(): |
1665 | 1856 | agent = Agent(model=TestModel()) |
1666 | 1857 | run_input = create_input( |
|
0 commit comments