-
Notifications
You must be signed in to change notification settings - Fork 1k
Python: (AG-UI) Support service-managed thread on AG-UI #3136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
howlowck
wants to merge
10
commits into
microsoft:main
Choose a base branch
from
howlowck:feature/ag-ui-proper-conversation-id-run-id
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
0a432e5
added service thread support
howlowck abc5a11
set service_thread_id to only supplied_thread_id
howlowck e99bf5f
uses raw_representation to extract the conversation_id
howlowck fb8543c
removed accidental edit
howlowck 48066fd
updated test to use raw_representation
howlowck 1fecf5e
resolves copilot review feedback
howlowck 11761fe
revert back StubAgent, since not used
howlowck c9ed539
removed relative module import
howlowck 5f9eec8
merged main
howlowck c82feed
Merge branch 'main' into feature/ag-ui-proper-conversation-id-run-id
moonbox3 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
| import logging | ||
| import uuid | ||
| from abc import ABC, abstractmethod | ||
| from collections.abc import AsyncGenerator | ||
| from collections.abc import AsyncGenerator, Sequence | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from ag_ui.core import ( | ||
|
|
@@ -53,11 +53,18 @@ | |
| merge_tools, | ||
| register_additional_client_tools, | ||
| ) | ||
| from ._utils import convert_agui_tools_to_agent_framework, generate_event_id, get_role_value | ||
| from ._utils import ( | ||
| convert_agui_tools_to_agent_framework, | ||
| generate_event_id, | ||
| get_conversation_id_from_update, | ||
| get_role_value, | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ._agent import AgentConfig | ||
| from ._confirmation_strategies import ConfirmationStrategy | ||
| from ._events import AgentFrameworkEventBridge | ||
| from ._orchestration._state_manager import StateManager | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -92,6 +99,8 @@ def __init__( | |
| self._last_message = None | ||
| self._run_id: str | None = None | ||
| self._thread_id: str | None = None | ||
| self._supplied_run_id: str | None = None | ||
| self._supplied_thread_id: str | None = None | ||
|
|
||
| @property | ||
| def messages(self): | ||
|
|
@@ -125,26 +134,66 @@ def last_message(self): | |
| self._last_message = self.messages[-1] | ||
| return self._last_message | ||
|
|
||
| @property | ||
| def supplied_run_id(self) -> str | None: | ||
| """Get the supplied run ID, if any.""" | ||
| if self._supplied_run_id is None: | ||
| self._supplied_run_id = self.input_data.get("run_id") or self.input_data.get("runId") | ||
| return self._supplied_run_id | ||
|
|
||
| @property | ||
| def run_id(self) -> str: | ||
| """Get or generate run ID.""" | ||
| """Get supplied run ID or generate a new run ID.""" | ||
| if self._run_id: | ||
| return self._run_id | ||
|
|
||
| if self.supplied_run_id: | ||
| self._run_id = self.supplied_run_id | ||
|
|
||
| if self._run_id is None: | ||
| self._run_id = self.input_data.get("run_id") or self.input_data.get("runId") or str(uuid.uuid4()) | ||
| # This should never be None after the if block above, but satisfy type checkers | ||
| if self._run_id is None: # pragma: no cover | ||
| raise RuntimeError("Failed to initialize run_id") | ||
| self._run_id = str(uuid.uuid4()) | ||
|
|
||
| return self._run_id | ||
|
|
||
| @property | ||
| def supplied_thread_id(self) -> str | None: | ||
| """Get the supplied thread ID, if any.""" | ||
| if self._supplied_thread_id is None: | ||
| self._supplied_thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") | ||
| return self._supplied_thread_id | ||
|
|
||
| @property | ||
| def thread_id(self) -> str: | ||
| """Get or generate thread ID.""" | ||
| """Get supplied thread ID or generate a new thread ID.""" | ||
| if self._thread_id: | ||
| return self._thread_id | ||
|
|
||
| if self.supplied_thread_id: | ||
| self._thread_id = self.supplied_thread_id | ||
|
|
||
| if self._thread_id is None: | ||
| self._thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") or str(uuid.uuid4()) | ||
| # This should never be None after the if block above, but satisfy type checkers | ||
| if self._thread_id is None: # pragma: no cover | ||
| raise RuntimeError("Failed to initialize thread_id") | ||
| self._thread_id = str(uuid.uuid4()) | ||
|
|
||
| return self._thread_id | ||
|
|
||
| def update_run_id(self, new_run_id: str) -> None: | ||
| """Update the run ID in the context. | ||
|
|
||
| Args: | ||
| new_run_id: The new run ID to set | ||
| """ | ||
| self._supplied_run_id = new_run_id | ||
| self._run_id = new_run_id | ||
|
|
||
| def update_thread_id(self, new_thread_id: str) -> None: | ||
| """Update the thread ID in the context. | ||
|
|
||
| Args: | ||
| new_thread_id: The new thread ID to set | ||
| """ | ||
| self._supplied_thread_id = new_thread_id | ||
| self._thread_id = new_thread_id | ||
|
|
||
|
|
||
| class Orchestrator(ABC): | ||
| """Base orchestrator for agent execution flows.""" | ||
|
|
@@ -297,6 +346,28 @@ def can_handle(self, context: ExecutionContext) -> bool: | |
| """ | ||
| return True | ||
|
|
||
| def _create_initial_events( | ||
| self, event_bridge: "AgentFrameworkEventBridge", state_manager: "StateManager" | ||
| ) -> Sequence[BaseEvent]: | ||
| """Generate initial events for the run. | ||
|
|
||
| Args: | ||
| event_bridge: Event bridge for creating events | ||
| Returns: | ||
| Initial AG-UI events | ||
| """ | ||
| events: list[BaseEvent] = [event_bridge.create_run_started_event()] | ||
|
|
||
| predict_event = state_manager.predict_state_event() | ||
| if predict_event: | ||
| events.append(predict_event) | ||
|
|
||
| snapshot_event = state_manager.initial_snapshot_event(event_bridge) | ||
| if snapshot_event: | ||
| events.append(snapshot_event) | ||
|
|
||
| return events | ||
|
|
||
| async def run( | ||
| self, | ||
| context: ExecutionContext, | ||
|
|
@@ -342,17 +413,11 @@ async def run( | |
| approval_tool_name=approval_tool_name, | ||
| ) | ||
|
|
||
| yield event_bridge.create_run_started_event() | ||
|
|
||
| predict_event = state_manager.predict_state_event() | ||
| if predict_event: | ||
| yield predict_event | ||
|
|
||
| snapshot_event = state_manager.initial_snapshot_event(event_bridge) | ||
| if snapshot_event: | ||
| yield snapshot_event | ||
| if context.config.use_service_thread: | ||
| thread = AgentThread(service_thread_id=context.supplied_thread_id) | ||
| else: | ||
| thread = AgentThread() | ||
|
|
||
| thread = AgentThread() | ||
| thread.metadata = { # type: ignore[attr-defined] | ||
| "ag_ui_thread_id": context.thread_id, | ||
| "ag_ui_run_id": context.run_id, | ||
|
|
@@ -363,6 +428,8 @@ async def run( | |
| provider_messages = context.messages or [] | ||
| snapshot_messages = context.snapshot_messages | ||
| if not provider_messages: | ||
| for event in self._create_initial_events(event_bridge, state_manager): | ||
| yield event | ||
| logger.warning("No messages provided in AG-UI input") | ||
| yield event_bridge.create_run_finished_event() | ||
| return | ||
|
|
@@ -554,13 +621,41 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap | |
| confirmation_message = strategy.on_state_rejected() | ||
|
|
||
| message_id = generate_event_id() | ||
| for event in self._create_initial_events(event_bridge, state_manager): | ||
| yield event | ||
| yield TextMessageStartEvent(message_id=message_id, role="assistant") | ||
| yield TextMessageContentEvent(message_id=message_id, delta=confirmation_message) | ||
| yield TextMessageEndEvent(message_id=message_id) | ||
| yield event_bridge.create_run_finished_event() | ||
| return | ||
|
|
||
| should_recreate_event_bridge = False | ||
| async for update in context.agent.run_stream(messages_to_run, **run_kwargs): | ||
| conv_id = get_conversation_id_from_update(update) | ||
| if conv_id and conv_id != context.thread_id: | ||
| context.update_thread_id(conv_id) | ||
| should_recreate_event_bridge = True | ||
|
|
||
| if hasattr(update, "response_id") and update.response_id and update.response_id != context.run_id: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we remove the |
||
| context.update_run_id(update.response_id) | ||
| should_recreate_event_bridge = True | ||
|
|
||
| if should_recreate_event_bridge: | ||
| event_bridge = AgentFrameworkEventBridge( | ||
| run_id=context.run_id, | ||
| thread_id=context.thread_id, | ||
| predict_state_config=context.config.predict_state_config, | ||
| current_state=current_state, | ||
| skip_text_content=skip_text_content, | ||
| require_confirmation=context.config.require_confirmation, | ||
| approval_tool_name=approval_tool_name, | ||
| ) | ||
| should_recreate_event_bridge = False | ||
|
|
||
| if update_count == 0: | ||
| for event in self._create_initial_events(event_bridge, state_manager): | ||
| yield event | ||
|
|
||
| update_count += 1 | ||
| logger.info(f"[STREAM] Received update #{update_count} from agent") | ||
| if all_updates is not None: | ||
|
|
@@ -672,6 +767,11 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap | |
| yield TextMessageEndEvent(message_id=message_id) | ||
| logger.info(f"Emitted conversational message with length={len(response_dict['message'])}") | ||
|
|
||
| if all_updates is not None and len(all_updates) == 0: | ||
| logger.info("No updates received from agent - emitting initial events") | ||
| for event in self._create_initial_events(event_bridge, state_manager): | ||
| yield event | ||
|
|
||
| logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}") | ||
| if event_bridge.current_message_id: | ||
| logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: the naming
_supplied_run_idis slightly misleading after mutation - consider_resolved_run_idif you agree it's clearer.