diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index 184da0239e..bcb4a7cf82 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -20,6 +20,11 @@ TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent, + ThinkingEndEvent, + ThinkingStartEvent, + ThinkingTextMessageContentEvent, + ThinkingTextMessageEndEvent, + ThinkingTextMessageStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent, @@ -31,6 +36,7 @@ FunctionCallContent, FunctionResultContent, TextContent, + TextReasoningContent, ) from ._utils import generate_event_id @@ -89,6 +95,10 @@ def __init__( self.tool_calls_ended: set[str] = set() # Track which tool calls have had ToolCallEndEvent emitted self.accumulated_text_content: str = "" # Track accumulated text for final MessagesSnapshotEvent + # For thinking/reasoning content (extended thinking from Anthropic, reasoning from OpenAI) + self._thinking_started: bool = False + self._thinking_text_started: bool = False + async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[BaseEvent]: """ Convert an AgentRunResponseUpdate to AG-UI events. 
@@ -104,7 +114,10 @@ async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[Ba
         logger.info(f"Processing AgentRunUpdate with {len(update.contents)} content items")
         for idx, content in enumerate(update.contents):
             logger.info(f" Content {idx}: type={type(content).__name__}")
-            if isinstance(content, TextContent):
+            if isinstance(content, TextReasoningContent):
+                # Handle reasoning/thinking content first (before regular text)
+                events.extend(self._handle_text_reasoning_content(content))
+            elif isinstance(content, TextContent):
                 events.extend(self._handle_text_content(content))
             elif isinstance(content, FunctionCallContent):
                 events.extend(self._handle_function_call_content(content))
@@ -675,3 +688,78 @@ def create_state_delta_event(self, delta: list[dict[str, Any]]) -> StateDeltaEve
         return StateDeltaEvent(
             delta=delta,
         )
+
+    def _handle_text_reasoning_content(self, content: TextReasoningContent) -> list[BaseEvent]:
+        """Handle TextReasoningContent by emitting AG-UI thinking events.
+
+        Supports both Anthropic extended thinking (text field) and OpenAI reasoning_details
+        (protected_data field with JSON-encoded content).
+
+        Args:
+            content: Reasoning content; ``text`` is preferred, otherwise the first
+                ``protected_data`` entry's string ``content`` value is used.
+
+        Returns:
+            Thinking start events (only if not already open) followed by one content
+            event, or an empty list when no displayable reasoning text is found.
+        """
+        events: list[BaseEvent] = []
+
+        # Prefer text field, fallback to extracting from protected_data (OpenAI format)
+        text = content.text
+        if not text and content.protected_data:
+            try:
+                data = json.loads(content.protected_data)
+            except (json.JSONDecodeError, TypeError):
+                # protected_data could not be parsed as JSON; treat it as carrying no text
+                data = None
+            # OpenAI reasoning_details format: [{"type": "text", "content": "..."}]
+            first = data[0] if isinstance(data, list) and data else None
+            # Entries that are not dicts, or whose "content" is not a string, are ignored
+            # instead of raising AttributeError or producing a non-string delta.
+            if isinstance(first, dict) and isinstance(first.get("content"), str):
+                text = first["content"]
+
+        if not text:
+            return events
+
+        if not self._thinking_started:
+            events.append(ThinkingStartEvent())
+            self._thinking_started = True
+            logger.info("Emitting ThinkingStartEvent")
+
+        if not self._thinking_text_started:
+            events.append(ThinkingTextMessageStartEvent())
+            self._thinking_text_started = True
+            logger.info("Emitting ThinkingTextMessageStartEvent")
+
+        events.append(ThinkingTextMessageContentEvent(delta=text))
+        logger.info(f"Emitting ThinkingTextMessageContentEvent with delta_length={len(text)}")
+        return events
+
+    def _end_thinking_if_needed(self) -> list[BaseEvent]:
+        """End thinking events if they were started.
+
+        Called when transitioning from reasoning content to regular content.
+        """
+        events: list[BaseEvent] = []
+
+        if self._thinking_text_started:
+            events.append(ThinkingTextMessageEndEvent())
+            self._thinking_text_started = False
+            logger.info("Emitting ThinkingTextMessageEndEvent")
+
+        if self._thinking_started:
+            events.append(ThinkingEndEvent())
+            self._thinking_started = False
+            logger.info("Emitting ThinkingEndEvent")
+
+        return events
+
+    def finalize_thinking(self) -> list[BaseEvent]:
+        """Finalize any open thinking events at the end of a stream.
+
+        This should be called by the orchestrator at stream end, similar to
+        how TextMessageEndEvent is handled in the orchestrator.
+ """ + return self._end_thinking_if_needed() diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 6bdff552b6..9a6f2190f6 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -448,6 +448,12 @@ async def run( events = await event_bridge.from_agent_run_update(update) logger.info(f"[STREAM] Update #{update_count} produced {len(events)} events") for event in events: + # End thinking before text or tool call events (like TextMessageEndEvent pattern) + if event.type in ("TEXT_MESSAGE_START", "TOOL_CALL_START"): + thinking_end_events = event_bridge.finalize_thinking() + for te in thinking_end_events: + logger.info(f"[STREAM] Emitting thinking end event before {event.type}: {type(te).__name__}") + yield te logger.info(f"[STREAM] Yielding event: {type(event).__name__}") yield event @@ -503,6 +509,12 @@ async def run( yield TextMessageEndEvent(message_id=message_id) logger.info(f"Emitted conversational message with length={len(response_dict['message'])}") + # Finalize any open thinking events at stream end + thinking_end_events = event_bridge.finalize_thinking() + for event in thinking_end_events: + logger.info(f"[FINALIZE] Emitting thinking end event: {type(event).__name__}") + yield event + logger.info(f"[FINALIZE] Checking for unclosed message. 
current_message_id={event_bridge.current_message_id}") if event_bridge.current_message_id: logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") diff --git a/python/packages/ag-ui/tests/test_events_comprehensive.py b/python/packages/ag-ui/tests/test_events_comprehensive.py index a51d1f382a..248f276c10 100644 --- a/python/packages/ag-ui/tests/test_events_comprehensive.py +++ b/python/packages/ag-ui/tests/test_events_comprehensive.py @@ -688,3 +688,213 @@ async def test_state_delta_count_logging(): # State delta count should have incremented (one per unique state update) assert bridge.state_delta_count >= 1 + + +# ============================================================================= +# TextReasoningContent Tests +# ============================================================================= + + +async def test_text_reasoning_content_with_text(): + """Test TextReasoningContent with text field (Anthropic extended thinking format).""" + from agent_framework import TextReasoningContent + + from agent_framework_ag_ui._events import AgentFrameworkEventBridge + + bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + + update = AgentRunResponseUpdate( + contents=[TextReasoningContent(text="Let me think about this...")] + ) + events = await bridge.from_agent_run_update(update) + + # Should emit: ThinkingStartEvent, ThinkingTextMessageStartEvent, ThinkingTextMessageContentEvent + assert len(events) == 3 + assert events[0].type == "THINKING_START" + assert events[1].type == "THINKING_TEXT_MESSAGE_START" + assert events[2].type == "THINKING_TEXT_MESSAGE_CONTENT" + assert events[2].delta == "Let me think about this..." 
+ + +async def test_text_reasoning_content_with_protected_data(): + """Test TextReasoningContent with protected_data field (OpenAI reasoning_details format).""" + from agent_framework import TextReasoningContent + + from agent_framework_ag_ui._events import AgentFrameworkEventBridge + + bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + + # OpenAI reasoning_details format + protected_data = json.dumps([{"type": "text", "content": "Reasoning step 1"}]) + update = AgentRunResponseUpdate( + contents=[TextReasoningContent(text=None, protected_data=protected_data)] + ) + events = await bridge.from_agent_run_update(update) + + assert len(events) == 3 + assert events[0].type == "THINKING_START" + assert events[1].type == "THINKING_TEXT_MESSAGE_START" + assert events[2].type == "THINKING_TEXT_MESSAGE_CONTENT" + assert events[2].delta == "Reasoning step 1" + + +async def test_text_reasoning_streaming(): + """Test streaming TextReasoningContent with multiple chunks.""" + from agent_framework import TextReasoningContent + + from agent_framework_ag_ui._events import AgentFrameworkEventBridge + + bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + + update1 = AgentRunResponseUpdate( + contents=[TextReasoningContent(text="First ")] + ) + update2 = AgentRunResponseUpdate( + contents=[TextReasoningContent(text="second ")] + ) + update3 = AgentRunResponseUpdate( + contents=[TextReasoningContent(text="third")] + ) + + events1 = await bridge.from_agent_run_update(update1) + events2 = await bridge.from_agent_run_update(update2) + events3 = await bridge.from_agent_run_update(update3) + + # First chunk: START events + content + assert len(events1) == 3 + assert events1[0].type == "THINKING_START" + assert events1[1].type == "THINKING_TEXT_MESSAGE_START" + assert events1[2].type == "THINKING_TEXT_MESSAGE_CONTENT" + assert events1[2].delta == "First " + + # Subsequent chunks: just content (no duplicate START events) + assert 
len(events2) == 1 + assert events2[0].type == "THINKING_TEXT_MESSAGE_CONTENT" + assert events2[0].delta == "second " + + assert len(events3) == 1 + assert events3[0].type == "THINKING_TEXT_MESSAGE_CONTENT" + assert events3[0].delta == "third" + + +async def test_text_reasoning_then_text(): + """Test transition from reasoning content to regular text content. + + Note: END events are emitted by orchestrator (like TextMessageEndEvent), + not by from_agent_run_update. This test verifies the bridge behavior, + and the orchestrator is responsible for calling finalize_thinking() + before yielding TEXT_MESSAGE_START events. + """ + from agent_framework import TextReasoningContent + + from agent_framework_ag_ui._events import AgentFrameworkEventBridge + + bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + + # First: reasoning content + reasoning_update = AgentRunResponseUpdate( + contents=[TextReasoningContent(text="Thinking...")] + ) + reasoning_events = await bridge.from_agent_run_update(reasoning_update) + + # Then: regular text content (bridge does NOT auto-emit END events) + text_update = AgentRunResponseUpdate( + contents=[TextContent(text="Here is the answer")] + ) + text_events = await bridge.from_agent_run_update(text_update) + + # Reasoning events: START events + content + assert len(reasoning_events) == 3 + assert reasoning_events[0].type == "THINKING_START" + + # Text events: just text events (no thinking END - that's orchestrator's job) + assert len(text_events) == 2 + assert text_events[0].type == "TEXT_MESSAGE_START" + assert text_events[1].type == "TEXT_MESSAGE_CONTENT" + assert text_events[1].delta == "Here is the answer" + + # Thinking is still "open" from bridge's perspective + assert bridge._thinking_started is True + + # Orchestrator would call finalize_thinking() before TEXT_MESSAGE_START + end_events = bridge.finalize_thinking() + assert len(end_events) == 2 + assert end_events[0].type == "THINKING_TEXT_MESSAGE_END" + 
assert end_events[1].type == "THINKING_END" + + +async def test_text_reasoning_content_empty_text(): + """Test TextReasoningContent with empty text returns no events.""" + from agent_framework import TextReasoningContent + + from agent_framework_ag_ui._events import AgentFrameworkEventBridge + + bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + + update = AgentRunResponseUpdate( + contents=[TextReasoningContent(text="")] + ) + events = await bridge.from_agent_run_update(update) + + # Empty text should not emit any events + assert len(events) == 0 + + +async def test_text_reasoning_content_none_text(): + """Test TextReasoningContent with None text and no protected_data returns no events.""" + from agent_framework import TextReasoningContent + + from agent_framework_ag_ui._events import AgentFrameworkEventBridge + + bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + + update = AgentRunResponseUpdate( + contents=[TextReasoningContent(text=None)] + ) + events = await bridge.from_agent_run_update(update) + + # None text with no protected_data should not emit any events + assert len(events) == 0 + + +async def test_finalize_thinking(): + """Test finalize_thinking() closes open thinking events.""" + from agent_framework import TextReasoningContent + + from agent_framework_ag_ui._events import AgentFrameworkEventBridge + + bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + + # Start thinking + update = AgentRunResponseUpdate( + contents=[TextReasoningContent(text="Thinking...")] + ) + await bridge.from_agent_run_update(update) + + # Verify thinking is open + assert bridge._thinking_started is True + assert bridge._thinking_text_started is True + + # Finalize thinking + end_events = bridge.finalize_thinking() + + # Should emit END events + assert len(end_events) == 2 + assert end_events[0].type == "THINKING_TEXT_MESSAGE_END" + assert end_events[1].type == "THINKING_END" + + # Verify 
state is reset + assert bridge._thinking_started is False + assert bridge._thinking_text_started is False + + +async def test_finalize_thinking_when_not_started(): + """Test finalize_thinking() when no thinking was started returns empty list.""" + from agent_framework_ag_ui._events import AgentFrameworkEventBridge + + bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + + # Finalize without starting thinking + end_events = bridge.finalize_thinking() + + assert len(end_events) == 0 diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index 305757356d..9ad2925611 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -276,6 +276,10 @@ def _parse_response_update_from_openai( contents.append(text_content) if reasoning_details := getattr(choice.delta, "reasoning_details", None): contents.append(TextReasoningContent(None, protected_data=json.dumps(reasoning_details))) + # Handle custom reasoning field for OpenAI-compatible APIs (e.g., Kimi, DeepSeek) + if reasoning_field := getattr(self, "reasoning_field", None): + if reasoning_content := getattr(choice.delta, reasoning_field, None): + contents.append(TextReasoningContent(text=reasoning_content, raw_representation=choice)) return ChatResponseUpdate( created_at=datetime.fromtimestamp(chunk.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), contents=contents, @@ -520,6 +524,7 @@ def __init__( base_url: str | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, + reasoning_field: str | None = None, ) -> None: """Initialize an OpenAI Chat completion client. @@ -541,6 +546,9 @@ def __init__( env_file_path: Use the environment settings file as a fallback to environment variables. env_file_encoding: The encoding of the environment settings file. 
+ reasoning_field: The field name for reasoning content in OpenAI-compatible APIs. + For example, Kimi and DeepSeek use "reasoning_content". If not set, + only the standard "reasoning_details" field (used by o1 models) is processed. Examples: .. code-block:: python @@ -588,4 +596,5 @@ def __init__( default_headers=default_headers, client=async_client, instruction_role=instruction_role, + reasoning_field=reasoning_field, ) diff --git a/python/packages/core/tests/openai/test_openai_chat_client.py b/python/packages/core/tests/openai/test_openai_chat_client.py index 18854799fd..ca3d795d95 100644 --- a/python/packages/core/tests/openai/test_openai_chat_client.py +++ b/python/packages/core/tests/openai/test_openai_chat_client.py @@ -894,3 +894,170 @@ def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: assert result["type"] == "file" assert "filename" not in result["file"] # None filename should be omitted + + +# ============================================================================= +# reasoning_field Tests +# ============================================================================= + + +async def test_streaming_reasoning_content(): + """Test that reasoning_field parameter enables custom reasoning content extraction.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework import TextReasoningContent + + from agent_framework.openai import OpenAIChatClient + + client = OpenAIChatClient( + model_id="kimi-k2-reasoning", + api_key="test_key", + base_url="https://api.moonshot.cn/v1", + reasoning_field="reasoning_content", + ) + + # Create mock chunk with reasoning_content + mock_chunk = MagicMock() + mock_chunk.id = "chunk_1" + mock_chunk.model = "kimi-k2-reasoning" + mock_chunk.created = 1704067200 + mock_chunk.usage = None + mock_chunk.system_fingerprint = "test_fp" + + mock_choice = MagicMock() + mock_choice.index = 0 + mock_choice.finish_reason = None + mock_choice.logprobs = None + + mock_delta = MagicMock() + 
mock_delta.content = None + mock_delta.tool_calls = None + mock_delta.refusal = None + mock_delta.reasoning_content = "Let me think step by step..." + mock_delta.reasoning_details = None # Explicitly set to None to avoid MagicMock + + mock_choice.delta = mock_delta + mock_chunk.choices = [mock_choice] + + # Parse the chunk + result = client._parse_response_update_from_openai(mock_chunk) + + # Should have TextReasoningContent with text field + reasoning_contents = [c for c in result.contents if isinstance(c, TextReasoningContent)] + assert len(reasoning_contents) == 1 + assert reasoning_contents[0].text == "Let me think step by step..." + + +async def test_streaming_without_reasoning_field(): + """Test that without reasoning_field, custom reasoning content is not extracted.""" + from unittest.mock import MagicMock + + from agent_framework import TextReasoningContent + + from agent_framework.openai import OpenAIChatClient + + client = OpenAIChatClient( + model_id="gpt-4", + api_key="test_key", + # Note: reasoning_field not set + ) + + # Create mock chunk with reasoning_content + mock_chunk = MagicMock() + mock_chunk.id = "chunk_1" + mock_chunk.model = "gpt-4" + mock_chunk.created = 1704067200 + mock_chunk.usage = None + mock_chunk.system_fingerprint = "test_fp" + + mock_choice = MagicMock() + mock_choice.index = 0 + mock_choice.finish_reason = None + mock_choice.logprobs = None + + mock_delta = MagicMock() + mock_delta.content = None + mock_delta.tool_calls = None + mock_delta.refusal = None + mock_delta.reasoning_content = "This should be ignored" + mock_delta.reasoning_details = None # Explicitly set to None + + mock_choice.delta = mock_delta + mock_chunk.choices = [mock_choice] + + # Parse the chunk + result = client._parse_response_update_from_openai(mock_chunk) + + # Should NOT have TextReasoningContent since reasoning_field is not set + reasoning_contents = [c for c in result.contents if isinstance(c, TextReasoningContent)] + assert len(reasoning_contents) == 0 
+ + +async def test_reasoning_field_with_regular_text(): + """Test that reasoning_field works alongside regular text content.""" + from unittest.mock import MagicMock + + from agent_framework import TextContent, TextReasoningContent + + from agent_framework.openai import OpenAIChatClient + + client = OpenAIChatClient( + model_id="deepseek-reasoner", + api_key="test_key", + base_url="https://api.deepseek.com", + reasoning_field="reasoning_content", + ) + + # Create mock chunk with both reasoning and regular content + mock_chunk = MagicMock() + mock_chunk.id = "chunk_1" + mock_chunk.model = "deepseek-reasoner" + mock_chunk.created = 1704067200 + mock_chunk.usage = None + mock_chunk.system_fingerprint = "test_fp" + + mock_choice = MagicMock() + mock_choice.index = 0 + mock_choice.finish_reason = None + mock_choice.logprobs = None + + mock_delta = MagicMock() + mock_delta.content = "Final answer: 42" + mock_delta.tool_calls = None + mock_delta.refusal = None + mock_delta.reasoning_content = "Calculating..." + mock_delta.reasoning_details = None + + mock_choice.delta = mock_delta + mock_chunk.choices = [mock_choice] + + # Parse the chunk + result = client._parse_response_update_from_openai(mock_chunk) + + # Should have both TextContent and TextReasoningContent + text_contents = [c for c in result.contents if isinstance(c, TextContent)] + reasoning_contents = [c for c in result.contents if isinstance(c, TextReasoningContent)] + + assert len(text_contents) == 1 + assert text_contents[0].text == "Final answer: 42" + assert len(reasoning_contents) == 1 + assert reasoning_contents[0].text == "Calculating..." 
+ + +async def test_reasoning_field_attribute_stored(): + """Test that reasoning_field is properly stored as an instance attribute.""" + from agent_framework.openai import OpenAIChatClient + + client_with_reasoning = OpenAIChatClient( + model_id="kimi-k2-reasoning", + api_key="test_key", + reasoning_field="reasoning_content", + ) + + client_without_reasoning = OpenAIChatClient( + model_id="gpt-4", + api_key="test_key", + ) + + assert getattr(client_with_reasoning, "reasoning_field", None) == "reasoning_content" + assert getattr(client_without_reasoning, "reasoning_field", None) is None