Skip to content
6 changes: 6 additions & 0 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,20 @@ def __init__(
self,
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
use_service_thread: bool = False,
require_confirmation: bool = True,
):
"""Initialize agent configuration.

Args:
state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
predict_state_config: Configuration for predictive state updates
use_service_thread: Whether the agent thread is service-managed
require_confirmation: Whether predictive updates require confirmation
"""
self.state_schema = self._normalize_state_schema(state_schema)
self.predict_state_config = predict_state_config or {}
self.use_service_thread = use_service_thread
self.require_confirmation = require_confirmation

@staticmethod
Expand Down Expand Up @@ -86,6 +89,7 @@ def __init__(
predict_state_config: dict[str, dict[str, str]] | None = None,
require_confirmation: bool = True,
orchestrators: list[Orchestrator] | None = None,
use_service_thread: bool = False,
confirmation_strategy: ConfirmationStrategy | None = None,
):
"""Initialize the AG-UI compatible agent wrapper.
Expand All @@ -101,6 +105,7 @@ def __init__(
Set to False for agentic generative UI that updates automatically.
orchestrators: Custom orchestrators (auto-configured if None).
Orchestrators are checked in order; first match handles the request.
use_service_thread: Whether the agent thread is service-managed.
confirmation_strategy: Strategy for generating confirmation messages.
Defaults to DefaultConfirmationStrategy if None.
"""
Expand All @@ -111,6 +116,7 @@ def __init__(
self.config = AgentConfig(
state_schema=state_schema,
predict_state_config=predict_state_config,
use_service_thread=use_service_thread,
require_confirmation=require_confirmation,
)

Expand Down
144 changes: 122 additions & 22 deletions python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Sequence
from typing import TYPE_CHECKING, Any

from ag_ui.core import (
Expand Down Expand Up @@ -53,11 +53,18 @@
merge_tools,
register_additional_client_tools,
)
from ._utils import convert_agui_tools_to_agent_framework, generate_event_id, get_role_value
from ._utils import (
convert_agui_tools_to_agent_framework,
generate_event_id,
get_conversation_id_from_update,
get_role_value,
)

if TYPE_CHECKING:
from ._agent import AgentConfig
from ._confirmation_strategies import ConfirmationStrategy
from ._events import AgentFrameworkEventBridge
from ._orchestration._state_manager import StateManager


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -92,6 +99,8 @@ def __init__(
self._last_message = None
self._run_id: str | None = None
self._thread_id: str | None = None
self._supplied_run_id: str | None = None
self._supplied_thread_id: str | None = None

@property
def messages(self):
Expand Down Expand Up @@ -125,26 +134,66 @@ def last_message(self):
self._last_message = self.messages[-1]
return self._last_message

@property
def supplied_run_id(self) -> str | None:
"""Get the supplied run ID, if any."""
if self._supplied_run_id is None:
self._supplied_run_id = self.input_data.get("run_id") or self.input_data.get("runId")
return self._supplied_run_id

@property
def run_id(self) -> str:
"""Get or generate run ID."""
"""Get supplied run ID or generate a new run ID."""
if self._run_id:
return self._run_id

if self.supplied_run_id:
self._run_id = self.supplied_run_id

if self._run_id is None:
self._run_id = self.input_data.get("run_id") or self.input_data.get("runId") or str(uuid.uuid4())
# This should never be None after the if block above, but satisfy type checkers
if self._run_id is None: # pragma: no cover
raise RuntimeError("Failed to initialize run_id")
self._run_id = str(uuid.uuid4())

return self._run_id

@property
def supplied_thread_id(self) -> str | None:
"""Get the supplied thread ID, if any."""
if self._supplied_thread_id is None:
self._supplied_thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId")
return self._supplied_thread_id

@property
def thread_id(self) -> str:
"""Get or generate thread ID."""
"""Get supplied thread ID or generate a new thread ID."""
if self._thread_id:
return self._thread_id

if self.supplied_thread_id:
self._thread_id = self.supplied_thread_id

if self._thread_id is None:
self._thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") or str(uuid.uuid4())
# This should never be None after the if block above, but satisfy type checkers
if self._thread_id is None: # pragma: no cover
raise RuntimeError("Failed to initialize thread_id")
self._thread_id = str(uuid.uuid4())

return self._thread_id

def update_run_id(self, new_run_id: str) -> None:
"""Update the run ID in the context.

Args:
new_run_id: The new run ID to set
"""
self._supplied_run_id = new_run_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the naming _supplied_run_id is slightly misleading after mutation - consider _resolved_run_id if you agree it's clearer.

self._run_id = new_run_id

def update_thread_id(self, new_thread_id: str) -> None:
"""Update the thread ID in the context.

Args:
new_thread_id: The new thread ID to set
"""
self._supplied_thread_id = new_thread_id
self._thread_id = new_thread_id


class Orchestrator(ABC):
"""Base orchestrator for agent execution flows."""
Expand Down Expand Up @@ -297,6 +346,28 @@ def can_handle(self, context: ExecutionContext) -> bool:
"""
return True

def _create_initial_events(
self, event_bridge: "AgentFrameworkEventBridge", state_manager: "StateManager"
) -> Sequence[BaseEvent]:
"""Generate initial events for the run.

Args:
event_bridge: Event bridge for creating events
Returns:
Initial AG-UI events
"""
events: list[BaseEvent] = [event_bridge.create_run_started_event()]

predict_event = state_manager.predict_state_event()
if predict_event:
events.append(predict_event)

snapshot_event = state_manager.initial_snapshot_event(event_bridge)
if snapshot_event:
events.append(snapshot_event)

return events

async def run(
self,
context: ExecutionContext,
Expand Down Expand Up @@ -342,17 +413,11 @@ async def run(
approval_tool_name=approval_tool_name,
)

yield event_bridge.create_run_started_event()

predict_event = state_manager.predict_state_event()
if predict_event:
yield predict_event

snapshot_event = state_manager.initial_snapshot_event(event_bridge)
if snapshot_event:
yield snapshot_event
if context.config.use_service_thread:
thread = AgentThread(service_thread_id=context.supplied_thread_id)
else:
thread = AgentThread()

thread = AgentThread()
thread.metadata = { # type: ignore[attr-defined]
"ag_ui_thread_id": context.thread_id,
"ag_ui_run_id": context.run_id,
Expand All @@ -363,6 +428,8 @@ async def run(
provider_messages = context.messages or []
snapshot_messages = context.snapshot_messages
if not provider_messages:
for event in self._create_initial_events(event_bridge, state_manager):
yield event
logger.warning("No messages provided in AG-UI input")
yield event_bridge.create_run_finished_event()
return
Expand Down Expand Up @@ -554,13 +621,41 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap
confirmation_message = strategy.on_state_rejected()

message_id = generate_event_id()
for event in self._create_initial_events(event_bridge, state_manager):
yield event
yield TextMessageStartEvent(message_id=message_id, role="assistant")
yield TextMessageContentEvent(message_id=message_id, delta=confirmation_message)
yield TextMessageEndEvent(message_id=message_id)
yield event_bridge.create_run_finished_event()
return

should_recreate_event_bridge = False
async for update in context.agent.run_stream(messages_to_run, **run_kwargs):
conv_id = get_conversation_id_from_update(update)
if conv_id and conv_id != context.thread_id:
context.update_thread_id(conv_id)
should_recreate_event_bridge = True

if hasattr(update, "response_id") and update.response_id and update.response_id != context.run_id:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove the hasattr check since response_id is a typed attribute on AgentRunResponseUpdate? The check could simplify to if update.response_id and update.response_id != context.run_id:

context.update_run_id(update.response_id)
should_recreate_event_bridge = True

if should_recreate_event_bridge:
event_bridge = AgentFrameworkEventBridge(
run_id=context.run_id,
thread_id=context.thread_id,
predict_state_config=context.config.predict_state_config,
current_state=current_state,
skip_text_content=skip_text_content,
require_confirmation=context.config.require_confirmation,
approval_tool_name=approval_tool_name,
)
should_recreate_event_bridge = False

if update_count == 0:
for event in self._create_initial_events(event_bridge, state_manager):
yield event

update_count += 1
logger.info(f"[STREAM] Received update #{update_count} from agent")
if all_updates is not None:
Expand Down Expand Up @@ -672,6 +767,11 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap
yield TextMessageEndEvent(message_id=message_id)
logger.info(f"Emitted conversational message with length={len(response_dict['message'])}")

if all_updates is not None and len(all_updates) == 0:
logger.info("No updates received from agent - emitting initial events")
for event in self._create_initial_events(event_bridge, state_manager):
yield event

logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}")
if event_bridge.current_message_id:
logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}")
Expand Down
16 changes: 15 additions & 1 deletion python/packages/ag-ui/agent_framework_ag_ui/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from datetime import date, datetime
from typing import Any

from agent_framework import AIFunction, Role, ToolProtocol
from agent_framework import AgentRunResponseUpdate, AIFunction, ChatResponseUpdate, Role, ToolProtocol

# Role mapping constants
AGUI_TO_FRAMEWORK_ROLE: dict[str, Role] = {
Expand Down Expand Up @@ -259,3 +259,17 @@ def convert_tools_to_agui_format(
continue

return results if results else None


def get_conversation_id_from_update(update: AgentRunResponseUpdate) -> str | None:
"""Extract conversation ID from AgentRunResponseUpdate metadata.

Args:
update: AgentRunResponseUpdate instance
Returns:
Conversation ID if present, else None

"""
if isinstance(update.raw_representation, ChatResponseUpdate):
return update.raw_representation.conversation_id
return None
54 changes: 54 additions & 0 deletions python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,60 @@ async def stream_fn(
assert "written" in full_text.lower() or "document" in full_text.lower()


async def test_agent_with_use_service_thread_is_false():
"""Test that when use_service_thread is False, the AgentThread used to run the agent is NOT set to the service thread ID."""
from agent_framework.ag_ui import AgentFrameworkAgent

request_service_thread_id: str | None = None

async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
nonlocal request_service_thread_id
thread = kwargs.get("thread")
request_service_thread_id = thread.service_thread_id if thread else None
yield ChatResponseUpdate(
contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
)

agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=False)

input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"}

events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
assert request_service_thread_id is None # type: ignore[attr-defined] (service_thread_id should be set)


async def test_agent_with_use_service_thread_is_true():
"""Test that when use_service_thread is True, the AgentThread used to run the agent is set to the service thread ID."""
from agent_framework.ag_ui import AgentFrameworkAgent

request_service_thread_id: str | None = None

async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
nonlocal request_service_thread_id
thread = kwargs.get("thread")
request_service_thread_id = thread.service_thread_id if thread else None
yield ChatResponseUpdate(
contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
)

agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=True)

input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"}

events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
assert request_service_thread_id == "conv_123456" # type: ignore[attr-defined] (service_thread_id should be set)


async def test_function_approval_mode_executes_tool():
"""Test that function approval with approval_mode='always_require' sends the correct messages."""
from agent_framework import FunctionResultContent, ai_function
Expand Down
Loading
Loading