|
7 | 7 |
|
8 | 8 | from ag_ui.core import ( |
9 | 9 | RunAgentInput, BaseEvent, EventType, Tool as AGUITool, |
10 | | - UserMessage, ToolMessage, RunStartedEvent, RunFinishedEvent, RunErrorEvent |
| 10 | + UserMessage, ToolMessage, RunStartedEvent, RunFinishedEvent, RunErrorEvent, |
| 11 | + AssistantMessage, ToolCall, FunctionCall, |
11 | 12 | ) |
12 | 13 |
|
13 | 14 | from ag_ui_adk import ADKAgent |
@@ -530,6 +531,116 @@ async def mock_start_new_execution(input_data, *, tool_results=None, message_bat |
530 | 531 | assert len(tool_messages) == 1 |
531 | 532 | assert getattr(tool_messages[0], 'id', None) == "tool_1" |
532 | 533 |
|
| 534 | + @pytest.mark.asyncio |
| 535 | + async def test_run_skips_assistant_history_before_tool_result(self, ag_ui_adk): |
| 536 | + """Assistant tool call history should not trigger a new execution before tool results arrive.""" |
| 537 | + assistant_call = AssistantMessage( |
| 538 | + id="assistant_tool", |
| 539 | + role="assistant", |
| 540 | + content=None, |
| 541 | + tool_calls=[ |
| 542 | + ToolCall( |
| 543 | + id="call_1", |
| 544 | + function=FunctionCall(name="test_tool", arguments="{}"), |
| 545 | + ) |
| 546 | + ], |
| 547 | + ) |
| 548 | + |
| 549 | + tool_result = ToolMessage( |
| 550 | + id="tool_result", |
| 551 | + role="tool", |
| 552 | + content='{"result": "value"}', |
| 553 | + tool_call_id="call_1", |
| 554 | + ) |
| 555 | + |
| 556 | + input_data = RunAgentInput( |
| 557 | + thread_id="thread_assistant_tool", |
| 558 | + run_id="run_assistant_tool", |
| 559 | + messages=[ |
| 560 | + UserMessage(id="user_initial", role="user", content="Initial question"), |
| 561 | + assistant_call, |
| 562 | + tool_result, |
| 563 | + ], |
| 564 | + tools=[], |
| 565 | + context=[], |
| 566 | + state={}, |
| 567 | + forwarded_props={}, |
| 568 | + ) |
| 569 | + |
| 570 | + # Mark the initial user message as already processed so only the assistant call and tool result are unseen |
| 571 | + app_name = ag_ui_adk._get_app_name(input_data) |
| 572 | + ag_ui_adk._session_manager.mark_messages_processed(app_name, input_data.thread_id, ["user_initial"]) |
| 573 | + |
| 574 | + start_calls = [] |
| 575 | + |
| 576 | + async def mock_start_new_execution(input_data, *, tool_results=None, message_batch=None): |
| 577 | + start_calls.append((tool_results, message_batch)) |
| 578 | + |
| 579 | + call_id = None |
| 580 | + if tool_results: |
| 581 | + call_id = tool_results[0]['message'].tool_call_id |
| 582 | + elif message_batch: |
| 583 | + for message in message_batch: |
| 584 | + tool_calls = getattr(message, "tool_calls", None) |
| 585 | + if tool_calls: |
| 586 | + call_id = tool_calls[0].id |
| 587 | + break |
| 588 | + |
| 589 | + if call_id: |
| 590 | + await ag_ui_adk._add_pending_tool_call_with_context( |
| 591 | + input_data.thread_id, |
| 592 | + call_id, |
| 593 | + ag_ui_adk._get_app_name(input_data), |
| 594 | + ag_ui_adk._get_user_id(input_data), |
| 595 | + ) |
| 596 | + |
| 597 | + yield RunStartedEvent( |
| 598 | + type=EventType.RUN_STARTED, |
| 599 | + thread_id=input_data.thread_id, |
| 600 | + run_id=input_data.run_id, |
| 601 | + ) |
| 602 | + yield RunFinishedEvent( |
| 603 | + type=EventType.RUN_FINISHED, |
| 604 | + thread_id=input_data.thread_id, |
| 605 | + run_id=input_data.run_id, |
| 606 | + ) |
| 607 | + |
| 608 | + with patch.object( |
| 609 | + ag_ui_adk, |
| 610 | + '_start_new_execution', |
| 611 | + side_effect=mock_start_new_execution, |
| 612 | + ) as start_mock, patch.object( |
| 613 | + ag_ui_adk, |
| 614 | + '_handle_tool_result_submission', |
| 615 | + wraps=ag_ui_adk._handle_tool_result_submission, |
| 616 | + ), patch.object( |
| 617 | + ag_ui_adk, |
| 618 | + '_add_pending_tool_call_with_context', |
| 619 | + new_callable=AsyncMock, |
| 620 | + ) as pending_mock: |
| 621 | + events = [] |
| 622 | + async for event in ag_ui_adk.run(input_data): |
| 623 | + events.append(event) |
| 624 | + |
| 625 | + assert [event.type for event in events] == [ |
| 626 | + EventType.RUN_STARTED, |
| 627 | + EventType.RUN_FINISHED, |
| 628 | + ] |
| 629 | + |
| 630 | + assert start_mock.call_count == 1 |
| 631 | + assert len(start_calls) == 1 |
| 632 | + first_tool_results, first_batch = start_calls[0] |
| 633 | + assert first_tool_results is not None |
| 634 | + assert first_batch is None |
| 635 | + assert first_tool_results[0]['message'].id == "tool_result" |
| 636 | + |
| 637 | + assert pending_mock.await_count == 1 |
| 638 | + pending_call = pending_mock.await_args_list[0] |
| 639 | + assert pending_call.args[1] == "call_1" |
| 640 | + |
| 641 | + processed_ids = ag_ui_adk._session_manager.get_processed_message_ids(app_name, input_data.thread_id) |
| 642 | + assert "assistant_tool" in processed_ids |
| 643 | + |
533 | 644 | @pytest.mark.asyncio |
534 | 645 | async def test_run_preserves_order_for_user_then_tool(self, ag_ui_adk): |
535 | 646 | """Verify user updates are handled before subsequent tool messages.""" |
|
0 commit comments