diff --git a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py index c18c1f1c98..391cf06f2f 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py @@ -404,7 +404,7 @@ async def before_request(self) -> AsyncIterator[EventT]: Override this to inject custom events at the start of the request. """ - return + return # pragma: lax no cover yield # Make this an async generator async def after_request(self) -> AsyncIterator[EventT]: @@ -412,7 +412,7 @@ async def after_request(self) -> AsyncIterator[EventT]: Override this to inject custom events at the end of the request. """ - return + return # pragma: lax no cover yield # Make this an async generator async def before_response(self) -> AsyncIterator[EventT]: @@ -420,7 +420,7 @@ async def before_response(self) -> AsyncIterator[EventT]: Override this to inject custom events at the start of the response. """ - return + return # pragma: no cover yield # Make this an async generator async def after_response(self) -> AsyncIterator[EventT]: @@ -428,7 +428,7 @@ async def after_response(self) -> AsyncIterator[EventT]: Override this to inject custom events at the end of the response. """ - return + return # pragma: lax no cover yield # Make this an async generator async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[EventT]: diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py index 0a6f354abf..2b37d36351 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py @@ -92,6 +92,13 @@ async def before_stream(self) -> AsyncIterator[BaseEvent]: run_id=self.run_input.run_id, ) + async def before_response(self) -> AsyncIterator[BaseEvent]: + # Prevent parts from a subsequent response being tied to parts from an earlier response. + # See https://github.com/pydantic/pydantic-ai/issues/3316 + self.new_message_id() + return + yield # Make this an async generator + async def after_stream(self) -> AsyncIterator[BaseEvent]: if not self._error: yield RunFinishedEvent( @@ -167,9 +174,11 @@ async def _handle_tool_call_start( self, part: ToolCallPart | BuiltinToolCallPart, tool_call_id: str | None = None ) -> AsyncIterator[BaseEvent]: tool_call_id = tool_call_id or part.tool_call_id - message_id = self.message_id or self.new_message_id() + parent_message_id = self.message_id - yield ToolCallStartEvent(tool_call_id=tool_call_id, tool_call_name=part.tool_name, parent_message_id=message_id) + yield ToolCallStartEvent( + tool_call_id=tool_call_id, tool_call_name=part.tool_name, parent_message_id=parent_message_id + ) if part.args: yield ToolCallArgsEvent(tool_call_id=tool_call_id, delta=part.args_as_json_str()) diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 0ca9dcc3aa..05071d2259 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -19,6 +19,8 @@ from pydantic_ai import ( BuiltinToolCallPart, BuiltinToolReturnPart, + FunctionToolCallEvent, + FunctionToolResultEvent, ModelMessage, ModelRequest, ModelResponse, @@ -29,6 +31,7 @@ TextPart, TextPartDelta, ToolCallPart, + ToolCallPartDelta, ToolReturn, ToolReturnPart, UserPromptPart, @@ -1661,6 +1664,194 @@ async def event_generator(): ) +async def test_event_stream_multiple_responses_with_tool_calls(): + async def event_generator(): + yield PartStartEvent(index=0, part=TextPart(content='Hello')) + yield PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' world')) + yield PartEndEvent(index=0, part=TextPart(content='Hello world'), next_part_kind='tool-call') + + yield PartStartEvent( + index=1, + part=ToolCallPart(tool_name='tool_call_1', args='{}', tool_call_id='tool_call_1'), + previous_part_kind='text', + ) + yield PartDeltaEvent( + index=1, delta=ToolCallPartDelta(args_delta='{"query": "Hello world"}', tool_call_id='tool_call_1') + ) + yield PartEndEvent( + index=1, + part=ToolCallPart(tool_name='tool_call_1', args='{"query": "Hello world"}', tool_call_id='tool_call_1'), + next_part_kind='tool-call', + ) + + yield PartStartEvent( + index=2, + part=ToolCallPart(tool_name='tool_call_2', args='{}', tool_call_id='tool_call_2'), + previous_part_kind='tool-call', + ) + yield PartDeltaEvent( + index=2, delta=ToolCallPartDelta(args_delta='{"query": "Goodbye world"}', tool_call_id='tool_call_2') + ) + yield PartEndEvent( + index=2, + part=ToolCallPart(tool_name='tool_call_2', args='{"query": "Hello world"}', tool_call_id='tool_call_2'), + next_part_kind=None, + ) + + yield FunctionToolCallEvent( + part=ToolCallPart(tool_name='tool_call_1', args='{"query": "Hello world"}', tool_call_id='tool_call_1') + ) + yield FunctionToolCallEvent( + part=ToolCallPart(tool_name='tool_call_2', args='{"query": "Goodbye world"}', tool_call_id='tool_call_2') + ) + + yield FunctionToolResultEvent( + result=ToolReturnPart(tool_name='tool_call_1', content='Hi!', tool_call_id='tool_call_1') + ) + yield FunctionToolResultEvent( + result=ToolReturnPart(tool_name='tool_call_2', content='Bye!', tool_call_id='tool_call_2') + ) + + yield PartStartEvent( + index=0, + part=ToolCallPart(tool_name='tool_call_3', args='{}', tool_call_id='tool_call_3'), + previous_part_kind=None, + ) + yield PartDeltaEvent( + index=0, delta=ToolCallPartDelta(args_delta='{"query": "Hello world"}', tool_call_id='tool_call_3') + ) + yield PartEndEvent( + index=0, + part=ToolCallPart(tool_name='tool_call_3', args='{"query": "Hello world"}', tool_call_id='tool_call_3'), + next_part_kind='tool-call', + ) + + yield PartStartEvent( + index=1, + part=ToolCallPart(tool_name='tool_call_4', args='{}', tool_call_id='tool_call_4'), + previous_part_kind='tool-call', + ) + yield PartDeltaEvent( + index=1, delta=ToolCallPartDelta(args_delta='{"query": "Goodbye world"}', tool_call_id='tool_call_4') + ) + yield PartEndEvent( + index=1, + part=ToolCallPart(tool_name='tool_call_4', args='{"query": "Goodbye world"}', tool_call_id='tool_call_4'), + next_part_kind=None, + ) + + yield FunctionToolCallEvent( + part=ToolCallPart(tool_name='tool_call_3', args='{"query": "Hello world"}', tool_call_id='tool_call_3') + ) + yield FunctionToolCallEvent( + part=ToolCallPart(tool_name='tool_call_4', args='{"query": "Goodbye world"}', tool_call_id='tool_call_4') + ) + + yield FunctionToolResultEvent( + result=ToolReturnPart(tool_name='tool_call_3', content='Hi!', tool_call_id='tool_call_3') + ) + yield FunctionToolResultEvent( + result=ToolReturnPart(tool_name='tool_call_4', content='Bye!', tool_call_id='tool_call_4') + ) + + run_input = create_input( + UserMessage( + id='msg_1', + content='Tell me about Hello World', + ), + ) + event_stream = AGUIEventStream(run_input=run_input) + events = [ + json.loads(event.removeprefix('data: ')) + async for event in event_stream.encode_stream(event_stream.transform_stream(event_generator())) + ] + + assert events == snapshot( + [ + { + 'type': 'RUN_STARTED', + 'threadId': (thread_id := IsSameStr()), + 'runId': (run_id := IsSameStr()), + }, + {'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': 'Hello'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': ' world'}, + {'type': 'TEXT_MESSAGE_END', 'messageId': message_id}, + { + 'type': 'TOOL_CALL_START', + 'toolCallId': 'tool_call_1', + 'toolCallName': 'tool_call_1', + 'parentMessageId': message_id, + }, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_1', 'delta': '{}'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_1', 'delta': '{"query": "Hello world"}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_1'}, + { + 'type': 'TOOL_CALL_START', + 'toolCallId': 'tool_call_2', + 'toolCallName': 'tool_call_2', + 'parentMessageId': message_id, + }, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_2', 'delta': '{}'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_2', 'delta': '{"query": "Goodbye world"}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_2'}, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': IsStr(), + 'toolCallId': 'tool_call_1', + 'content': 'Hi!', + 'role': 'tool', + }, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': (result_message_id := IsSameStr()), + 'toolCallId': 'tool_call_2', + 'content': 'Bye!', + 'role': 'tool', + }, + { + 'type': 'TOOL_CALL_START', + 'toolCallId': 'tool_call_3', + 'toolCallName': 'tool_call_3', + 'parentMessageId': (new_message_id := IsSameStr()), + }, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_3', 'delta': '{}'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_3', 'delta': '{"query": "Hello world"}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_3'}, + { + 'type': 'TOOL_CALL_START', + 'toolCallId': 'tool_call_4', + 'toolCallName': 'tool_call_4', + 'parentMessageId': new_message_id, + }, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_4', 'delta': '{}'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_4', 'delta': '{"query": "Goodbye world"}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_4'}, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': IsStr(), + 'toolCallId': 'tool_call_3', + 'content': 'Hi!', + 'role': 'tool', + }, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': IsStr(), + 'toolCallId': 'tool_call_4', + 'content': 'Bye!', + 'role': 'tool', + }, + { + 'type': 'RUN_FINISHED', + 'threadId': thread_id, + 'runId': run_id, + }, + ] + ) + + assert result_message_id != new_message_id + + async def test_handle_ag_ui_request(): agent = Agent(model=TestModel()) run_input = create_input(