|
11 | 11 |
|
12 | 12 | from ag_ui.core import ( |
13 | 13 | BaseEvent, |
| 14 | + MessagesSnapshotEvent, |
14 | 15 | RunErrorEvent, |
15 | 16 | TextMessageContentEvent, |
16 | 17 | TextMessageEndEvent, |
@@ -588,32 +589,37 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: |
588 | 589 | # We should NOT add to thread.on_new_messages() as that would cause duplication. |
589 | 590 | # Instead, we pass messages directly to the agent via messages_to_run. |
590 | 591 |
|
591 | | - # Inject current state as system message context if we have state |
| 592 | + # Inject current state as system message context if we have state and this is a new user turn |
592 | 593 | messages_to_run: list[Any] = [] |
593 | 594 |
|
| 595 | + # Check if the last message is from the user (new turn) vs assistant/tool (mid-execution) |
| 596 | + is_new_user_turn = False |
| 597 | + if provider_messages: |
| 598 | + last_msg = provider_messages[-1] |
| 599 | + is_new_user_turn = last_msg.role.value == "user" |
| 600 | + |
| 601 | + # Check if conversation has tool calls (indicates mid-execution) |
594 | 602 | conversation_has_tool_calls = False |
595 | | - logger.debug(f"Checking {len(provider_messages)} provider messages for tool calls") |
596 | | - for i, msg in enumerate(provider_messages): |
597 | | - logger.debug( |
598 | | - f" Message {i}: role={msg.role.value}, contents={len(msg.contents) if hasattr(msg, 'contents') and msg.contents else 0}" |
599 | | - ) |
600 | 603 | for msg in provider_messages: |
601 | 604 | if msg.role.value == "assistant" and hasattr(msg, "contents") and msg.contents: |
602 | 605 | if any(isinstance(content, FunctionCallContent) for content in msg.contents): |
603 | 606 | conversation_has_tool_calls = True |
604 | 607 | break |
605 | | - if current_state and context.config.state_schema and not conversation_has_tool_calls: |
| 608 | + |
| 609 | + # Only inject state context on new user turns AND when conversation doesn't have tool calls |
| 610 | + # (tool calls indicate we're mid-execution, so state context was already injected) |
| 611 | + if current_state and context.config.state_schema and is_new_user_turn and not conversation_has_tool_calls: |
606 | 612 | state_json = json.dumps(current_state, indent=2) |
607 | 613 | state_context_msg = ChatMessage( |
608 | 614 | role="system", |
609 | 615 | contents=[ |
610 | 616 | TextContent( |
611 | 617 | text=f"""Current state of the application: |
612 | | -{state_json} |
| 618 | + {state_json} |
613 | 619 |
|
614 | | -When modifying state, you MUST include ALL existing data plus your changes. |
615 | | -For example, if adding a new ingredient, include all existing ingredients PLUS the new one. |
616 | | -Never replace existing data - always append or merge.""" |
| 620 | + When modifying state, you MUST include ALL existing data plus your changes. |
| 621 | + For example, if adding one new item to a list, include ALL existing items PLUS the one new item. |
| 622 | + Never replace existing data - always preserve and append or merge.""" |
617 | 623 | ) |
618 | 624 | ], |
619 | 625 | ) |
@@ -714,12 +720,19 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: |
714 | 720 |
|
715 | 721 | # Collect all updates to get the final structured output |
716 | 722 | all_updates: list[Any] = [] |
| 723 | + update_count = 0 |
717 | 724 | async for update in context.agent.run_stream(messages_to_run, thread=thread, tools=tools_param): |
| 725 | + update_count += 1 |
| 726 | + logger.info(f"[STREAM] Received update #{update_count} from agent") |
718 | 727 | all_updates.append(update) |
719 | 728 | events = await event_bridge.from_agent_run_update(update) |
| 729 | + logger.info(f"[STREAM] Update #{update_count} produced {len(events)} events") |
720 | 730 | for event in events: |
| 731 | + logger.info(f"[STREAM] Yielding event: {type(event).__name__}") |
721 | 732 | yield event |
722 | 733 |
|
| 734 | + logger.info(f"[STREAM] Agent stream completed. Total updates: {update_count}") |
| 735 | + |
723 | 736 | # After agent completes, check if we should stop (waiting for user to confirm changes) |
724 | 737 | if event_bridge.should_stop_after_confirm: |
725 | 738 | logger.info("Stopping run after confirm_changes - waiting for user response") |
@@ -793,9 +806,56 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: |
793 | 806 | yield TextMessageEndEvent(message_id=message_id) |
794 | 807 | logger.info(f"Emitted conversational message: {response_dict['message'][:100]}...") |
795 | 808 |
|
| 809 | + logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}") |
796 | 810 | if event_bridge.current_message_id: |
| 811 | + logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") |
797 | 812 | yield event_bridge.create_message_end_event(event_bridge.current_message_id) |
798 | 813 |
|
| 814 | + # Emit MessagesSnapshotEvent to persist the final assistant text message |
| 815 | + from ._message_adapters import agui_messages_to_snapshot_format |
| 816 | + |
| 817 | + # Build the final assistant message with accumulated text content |
| 818 | + assistant_text_message = { |
| 819 | + "id": event_bridge.current_message_id, |
| 820 | + "role": "assistant", |
| 821 | + "content": event_bridge.accumulated_text_content, |
| 822 | + } |
| 823 | + |
| 824 | + # Convert input messages to snapshot format (normalize content structure) |
| 825 | + # event_bridge.input_messages are already in AG-UI format, just need normalization |
| 826 | + converted_input_messages = agui_messages_to_snapshot_format(event_bridge.input_messages) |
| 827 | + |
| 828 | + # Build complete messages array |
| 829 | + # Include: input messages + any pending tool calls/results + final text message |
| 830 | + all_messages = converted_input_messages.copy() |
| 831 | + |
| 832 | + # Add assistant message with tool calls if any |
| 833 | + if event_bridge.pending_tool_calls: |
| 834 | + tool_call_message = { |
| 835 | + "id": generate_event_id(), |
| 836 | + "role": "assistant", |
| 837 | + "tool_calls": event_bridge.pending_tool_calls.copy(), |
| 838 | + } |
| 839 | + all_messages.append(tool_call_message) |
| 840 | + |
| 841 | + # Add tool results if any |
| 842 | + all_messages.extend(event_bridge.tool_results.copy()) |
| 843 | + |
| 844 | + # Add final text message |
| 845 | + all_messages.append(assistant_text_message) |
| 846 | + |
| 847 | + messages_snapshot = MessagesSnapshotEvent( |
| 848 | + messages=all_messages, # type: ignore[arg-type] |
| 849 | + ) |
| 850 | + logger.info( |
| 851 | + f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(all_messages)} messages " |
| 852 | + f"(text content length: {len(event_bridge.accumulated_text_content)})" |
| 853 | + ) |
| 854 | + yield messages_snapshot |
| 855 | + else: |
| 856 | + logger.info("[FINALIZE] No current_message_id - skipping TextMessageEndEvent") |
| 857 | + |
| 858 | + logger.info("[FINALIZE] Emitting RUN_FINISHED event") |
799 | 859 | yield event_bridge.create_run_finished_event() |
800 | 860 | logger.info(f"Completed agent run for thread_id={context.thread_id}, run_id={context.run_id}") |
801 | 861 |
|
|
0 commit comments