diff --git a/typescript-sdk/apps/dojo/e2e/featurePages/ToolBaseGenUIPage.ts b/typescript-sdk/apps/dojo/e2e/featurePages/ToolBaseGenUIPage.ts index c836a49a3..6208e3072 100644 --- a/typescript-sdk/apps/dojo/e2e/featurePages/ToolBaseGenUIPage.ts +++ b/typescript-sdk/apps/dojo/e2e/featurePages/ToolBaseGenUIPage.ts @@ -158,4 +158,4 @@ export class ToolBaseGenUIPage { expect(foundMatch).toBe(true); } -} +} \ No newline at end of file diff --git a/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/__init__.py b/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/__init__.py index 0552f8059..3729083f5 100644 --- a/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/__init__.py +++ b/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/__init__.py @@ -5,6 +5,12 @@ This middleware enables Google ADK agents to be used with the AG-UI protocol. """ +from __future__ import annotations + +import logging +import os +from typing import Dict, Iterable + from .adk_agent import ADKAgent from .event_translator import EventTranslator from .session_manager import SessionManager @@ -12,4 +18,51 @@ __all__ = ['ADKAgent', 'add_adk_fastapi_endpoint', 'create_adk_app', 'EventTranslator', 'SessionManager'] -__version__ = "0.1.0" \ No newline at end of file +__version__ = "0.1.0" + + +def _configure_logging_from_env() -> None: + """Configure component loggers based on environment variables.""" + + root_level = os.getenv('LOG_ROOT_LEVEL') + if root_level: + try: + level = getattr(logging, root_level.upper()) + except AttributeError: + logging.getLogger(__name__).warning( + "Invalid LOG_ROOT_LEVEL value '%s'", root_level + ) + else: + logging.basicConfig(level=level, force=True) + + component_levels: Dict[str, Iterable[str]] = { + 'LOG_ADK_AGENT': ('ag_ui_adk.adk_agent',), + 'LOG_EVENT_TRANSLATOR': ( + 'ag_ui_adk.event_translator', + 'event_translator', + ), + 'LOG_ENDPOINT': ('ag_ui_adk.endpoint', 'endpoint'), + 'LOG_SESSION_MANAGER': ( + 'ag_ui_adk.session_manager', + 'session_manager', + ), + } + + for env_var, logger_names in component_levels.items(): + level_name = os.getenv(env_var) + if not level_name: + continue + + try: + level = getattr(logging, level_name.upper()) + except AttributeError: + logging.getLogger(__name__).warning( + "Invalid value '%s' for %s", level_name, env_var + ) + continue + + for logger_name in logger_names: + logging.getLogger(logger_name).setLevel(level) + + +_configure_logging_from_env() diff --git a/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py b/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py index 724864a61..b18d6fcf7 100644 --- a/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py +++ b/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py @@ -916,19 +916,48 @@ async def _run_adk_in_background( final_response = adk_event.is_final_response() has_content = adk_event.content and hasattr(adk_event.content, 'parts') and adk_event.content.parts - if not final_response or (not adk_event.usage_metadata and has_content): - # Translate and emit events + # Check if this is a streaming chunk that needs regular processing + is_streaming_chunk = ( + getattr(adk_event, 'partial', False) or # Explicitly marked as partial + (not getattr(adk_event, 'turn_complete', True)) or # Live streaming not complete + (not final_response) # Not marked as final by is_final_response() + ) + + # Prefer LRO routing when a long-running tool call is present + has_lro_function_call = False + try: + lro_ids = set(getattr(adk_event, 'long_running_tool_ids', []) or []) + if lro_ids and adk_event.content and getattr(adk_event.content, 'parts', None): + for part in adk_event.content.parts: + func = getattr(part, 'function_call', None) + func_id = getattr(func, 'id', None) if func else None + if func_id and func_id in lro_ids: + has_lro_function_call = True + break + except Exception: + # Be conservative: if detection fails, do not block streaming path + has_lro_function_call = False + + # Process as streaming if it's a chunk OR if it has content but no finish_reason, + # but only when there is no LRO function call present (LRO takes precedence) + if (not has_lro_function_call) and (is_streaming_chunk or (has_content and not getattr(adk_event, 'finish_reason', None))): + # Regular translation path async for ag_ui_event in event_translator.translate( adk_event, input.thread_id, input.run_id ): - + logger.debug(f"Emitting event to queue: {type(ag_ui_event).__name__} (thread {input.thread_id}, queue size before: {event_queue.qsize()})") await event_queue.put(ag_ui_event) logger.debug(f"Event queued: {type(ag_ui_event).__name__} (thread {input.thread_id}, queue size after: {event_queue.qsize()})") else: - # LongRunning Tool events are usually emmitted in final response + # LongRunning Tool events are usually emitted in final response + # Ensure any active streaming text message is closed BEFORE tool calls + async for end_event in event_translator.force_close_streaming_message(): + await event_queue.put(end_event) + logger.debug(f"Event queued (forced close): {type(end_event).__name__} (thread {input.thread_id}, queue size after: {event_queue.qsize()})") + async for ag_ui_event in event_translator.translate_lro_function_calls( adk_event ): @@ -994,4 +1023,4 @@ async def close(self): self._session_lookup_cache.clear() # Stop session manager cleanup task - await self._session_manager.stop_cleanup_task() \ No newline at end of file + await self._session_manager.stop_cleanup_task() diff --git a/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py b/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py index efb674e17..4dd01199c 100644 --- a/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py +++ b/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py @@ -87,16 +87,24 @@ async def translate( if hasattr(adk_event, 'get_function_calls'): function_calls = adk_event.get_function_calls() if function_calls: - logger.debug(f"ADK function calls detected: {len(function_calls)} calls") - - # CRITICAL FIX: End any active text message stream before starting tool calls - # Per AG-UI protocol: TEXT_MESSAGE_END must be sent before TOOL_CALL_START - async for event in self.force_close_streaming_message(): - yield event - - # NOW ACTUALLY YIELD THE EVENTS - async for event in self._translate_function_calls(function_calls): - yield event + # Filter out long-running tool calls; those are handled by translate_lro_function_calls + try: + lro_ids = set(getattr(adk_event, 'long_running_tool_ids', []) or []) + except Exception: + lro_ids = set() + + non_lro_calls = [fc for fc in function_calls if getattr(fc, 'id', None) not in lro_ids] + + if non_lro_calls: + logger.debug(f"ADK function calls detected (non-LRO): {len(non_lro_calls)} of {len(function_calls)} total") + # CRITICAL FIX: End any active text message stream before starting tool calls + # Per AG-UI protocol: TEXT_MESSAGE_END must be sent before TOOL_CALL_START + async for event in self.force_close_streaming_message(): + yield event + + # Yield only non-LRO function call events + async for event in self._translate_function_calls(non_lro_calls): + yield event # Handle function responses and yield the tool response event # this is essential for scenerios when user has to render function response at frontend @@ -164,12 +172,17 @@ async def _translate_text_content( elif hasattr(adk_event, 'is_final_response'): is_final_response = adk_event.is_final_response - # Handle None values: if is_final_response=True, it means streaming should end - should_send_end = is_final_response and not is_partial - + # Handle None values: if a turn is complete or a final chunk arrives, end streaming + has_finish_reason = bool(getattr(adk_event, 'finish_reason', None)) + should_send_end = ( + (turn_complete and not is_partial) + or (is_final_response and not is_partial) + or (has_finish_reason and self._is_streaming) + ) + logger.info(f"๐Ÿ“ฅ Text event - partial={is_partial}, turn_complete={turn_complete}, " - f"is_final_response={is_final_response}, should_send_end={should_send_end}, " - f"currently_streaming={self._is_streaming}") + f"is_final_response={is_final_response}, has_finish_reason={has_finish_reason}, " + f"should_send_end={should_send_end}, currently_streaming={self._is_streaming}") if is_final_response: @@ -464,4 +477,4 @@ def reset(self): self._streaming_message_id = None self._is_streaming = False self.long_running_tool_ids.clear() - logger.debug("Reset EventTranslator state (including streaming state)") \ No newline at end of file + logger.debug("Reset EventTranslator state (including streaming state)") diff --git a/typescript-sdk/integrations/adk-middleware/python/tests/conftest.py b/typescript-sdk/integrations/adk-middleware/python/tests/conftest.py new file mode 100644 index 000000000..b64927fa8 --- /dev/null +++ b/typescript-sdk/integrations/adk-middleware/python/tests/conftest.py @@ -0,0 +1,20 @@ +"""Shared pytest fixtures for ADK middleware tests.""" + +from __future__ import annotations + +import pytest + +from ag_ui.core import SystemMessage as CoreSystemMessage + +import ag_ui_adk.adk_agent as adk_agent_module + + +@pytest.fixture(autouse=True) +def restore_system_message_class(): + """Ensure every test starts and ends with the real SystemMessage type.""" + + adk_agent_module.SystemMessage = CoreSystemMessage + try: + yield + finally: + adk_agent_module.SystemMessage = CoreSystemMessage diff --git a/typescript-sdk/integrations/adk-middleware/python/tests/test_adk_agent.py b/typescript-sdk/integrations/adk-middleware/python/tests/test_adk_agent.py index 52e95ab68..92c586ea9 100644 --- a/typescript-sdk/integrations/adk-middleware/python/tests/test_adk_agent.py +++ b/typescript-sdk/integrations/adk-middleware/python/tests/test_adk_agent.py @@ -4,13 +4,16 @@ import pytest import asyncio +from types import SimpleNamespace from unittest.mock import Mock, MagicMock, AsyncMock, patch from ag_ui_adk import ADKAgent, SessionManager +from ag_ui_adk.event_translator import EventTranslator from ag_ui.core import ( RunAgentInput, EventType, UserMessage, Context, - RunStartedEvent, RunFinishedEvent, TextMessageChunkEvent, SystemMessage + RunStartedEvent, RunFinishedEvent, TextMessageChunkEvent, SystemMessage, + TextMessageContentEvent ) from google.adk.agents import Agent @@ -137,6 +140,218 @@ async def mock_run_async(*args, **kwargs): assert events[0].type == EventType.RUN_STARTED assert events[-1].type == EventType.RUN_FINISHED + @pytest.mark.asyncio + async def test_turn_complete_falls_back_to_streaming_translator( + self, + adk_agent, + sample_input, + ): + """Ensure turn_complete=False triggers streaming translation path.""" + + streaming_calls = [] + lro_calls = [] + + async def fake_translate(self, adk_event, thread_id, run_id): + streaming_calls.append((adk_event, thread_id, run_id)) + yield TextMessageChunkEvent( + message_id=adk_event.id, + role="assistant", + delta="streamed chunk", + ) + + async def fake_translate_lro(self, adk_event): + lro_calls.append(adk_event) + if False: # pragma: no cover - required to keep async generator signature + yield None + + mock_event = Mock() + mock_event.id = "event_stream" + mock_event.author = "assistant" + mock_event.partial = False + mock_event.turn_complete = False + mock_event.finish_reason = "STOP" + mock_event.usage_metadata = {"tokens": 5} + mock_event.is_final_response = Mock(return_value=True) + mock_event.content = Mock() + mock_event.content.parts = [Mock(text="Final response chunk")] + mock_event.actions = None + mock_event.get_function_calls = Mock(return_value=[]) + mock_event.get_function_responses = Mock(return_value=[]) + mock_event.custom_data = None + + class DummyRunner: + async def run_async(self, *args, **kwargs): + yield mock_event + + with patch.object(adk_agent, '_create_runner', return_value=DummyRunner()), \ + patch.object(EventTranslator, 'translate', new=fake_translate), \ + patch.object(EventTranslator, 'translate_lro_function_calls', new=fake_translate_lro): + + events = [] + async for event in adk_agent.run(sample_input): + events.append(event) + + # Verify run lifecycle events emitted + assert events[0].type == EventType.RUN_STARTED + assert events[-1].type == EventType.RUN_FINISHED + + # Ensure streaming translator branch handled the event + chunk_events = [event for event in events if isinstance(event, TextMessageChunkEvent)] + assert chunk_events, "Expected translated chunk event" + assert chunk_events[0].delta == "streamed chunk" + + # Confirm branch selection + assert len(streaming_calls) == 1 + assert lro_calls == [] + + @pytest.mark.asyncio + async def test_partial_final_chunk_uses_streaming_translation(self, adk_agent, sample_input): + """Ensure partial chunks marked as final still use streaming translation.""" + + translate_calls = 0 + lro_calls = 0 + + async def fake_translate(self, adk_event, thread_id, run_id): + nonlocal translate_calls + translate_calls += 1 + yield TextMessageChunkEvent( + type=EventType.TEXT_MESSAGE_CHUNK, + message_id=adk_event.id, + delta="chunk" + ) + + async def fake_translate_lro(self, adk_event): + nonlocal lro_calls + lro_calls += 1 + if False: + yield # pragma: no cover - keeps this an async generator + + adk_event = SimpleNamespace( + id="event-final-chunk", + author="assistant", + content=SimpleNamespace(parts=[SimpleNamespace(text="hello")]), + partial=True, + turn_complete=True, + usage_metadata={"tokens": 1}, + finish_reason="STOP", + actions=None, + custom_data=None, + get_function_calls=lambda: [], + get_function_responses=lambda: [], + is_final_response=lambda: True + ) + + class FakeRunner: + async def run_async(self, *args, **kwargs): + yield adk_event + + with patch("ag_ui_adk.adk_agent.EventTranslator.translate", new=fake_translate), \ + patch("ag_ui_adk.adk_agent.EventTranslator.translate_lro_function_calls", new=fake_translate_lro), \ + patch.object(adk_agent, "_create_runner", return_value=FakeRunner()): + events = [event async for event in adk_agent.run(sample_input)] + + assert any(isinstance(event, TextMessageChunkEvent) for event in events) + assert translate_calls == 1 + assert lro_calls == 0 + + @pytest.mark.asyncio + async def test_streaming_finish_reason_fallback(self, adk_agent, sample_input): + """Ensure streaming translator handles final responses missing finish_reason.""" + + text_part = SimpleNamespace(text="Hello from stream", function_call=None) + streaming_event = SimpleNamespace( + id="event-stream", + author="assistant", + content=SimpleNamespace(parts=[text_part]), + partial=False, + turn_complete=True, + usage_metadata={"tokens": 9}, + finish_reason=None, + actions=None, + custom_data=None, + long_running_tool_ids=[], + ) + streaming_event.is_final_response = lambda: False + streaming_event.get_function_calls = Mock(return_value=[]) + streaming_event.get_function_responses = Mock(return_value=[]) + + function_call = SimpleNamespace(id="tool-1", name="long_tool", args={"foo": "bar"}) + function_part = SimpleNamespace(text=None, function_call=function_call) + lro_event = SimpleNamespace( + id="event-lro", + author="assistant", + content=SimpleNamespace(parts=[function_part]), + partial=False, + turn_complete=True, + usage_metadata={"tokens": 1}, + finish_reason="STOP", + actions=None, + custom_data=None, + long_running_tool_ids=[function_call.id], + ) + lro_event.is_final_response = lambda: True + lro_event.get_function_calls = Mock(return_value=[]) + lro_event.get_function_responses = Mock(return_value=[]) + + events_to_yield = [streaming_event, lro_event] + + class DummyRunner: + async def run_async(self, *args, **kwargs): + for event in events_to_yield: + yield event + + captured_stream_events = [] + captured_lro_events = [] + + original_translate = EventTranslator.translate + original_translate_lro = EventTranslator.translate_lro_function_calls + + async def translate_spy(self, adk_event, thread_id, run_id): + translate_spy.call_count += 1 + translate_spy.adk_events.append(adk_event) + async for event in original_translate(self, adk_event, thread_id, run_id): + captured_stream_events.append(event) + yield event + + translate_spy.call_count = 0 + translate_spy.adk_events = [] + + async def translate_lro_spy(self, adk_event): + translate_lro_spy.call_count += 1 + translate_lro_spy.adk_events.append(adk_event) + async for event in original_translate_lro(self, adk_event): + captured_lro_events.append(event) + yield event + + translate_lro_spy.call_count = 0 + translate_lro_spy.adk_events = [] + + dummy_runner = DummyRunner() + + with patch.object(EventTranslator, "translate", translate_spy), \ + patch.object(EventTranslator, "translate_lro_function_calls", translate_lro_spy), \ + patch.object(adk_agent, "_create_runner", return_value=dummy_runner): + + emitted_events = [] + async for event in adk_agent.run(sample_input): + emitted_events.append(event) + + # Assert streaming translator was used for the first event + assert translate_spy.call_count == 1 + assert translate_spy.adk_events[0] is streaming_event + + # Confirm streaming content flowed through as expected + text_events = [event for event in emitted_events if isinstance(event, TextMessageContentEvent)] + assert text_events and text_events[0].delta == "Hello from stream" + assert any(isinstance(event, TextMessageContentEvent) for event in captured_stream_events) + + # Long-running translation should be invoked only for the STOP event + assert translate_lro_spy.call_count == 1 + assert translate_lro_spy.adk_events[0] is lro_event + + # Ensure we produced a tool call event to guard against regressions + assert any(event.type == EventType.TOOL_CALL_END for event in captured_lro_events) + @pytest.mark.asyncio async def test_session_management(self, adk_agent): """Test session lifecycle management.""" diff --git a/typescript-sdk/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py b/typescript-sdk/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py index 8475cc8cb..b1e13c871 100644 --- a/typescript-sdk/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py +++ b/typescript-sdk/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py @@ -90,20 +90,21 @@ async def test_translate_event_with_empty_parts(self, translator, mock_adk_event @pytest.mark.asyncio async def test_translate_function_calls_detection(self, translator, mock_adk_event): - """Test function calls detection and logging.""" - # Mock event with function calls + """Test that function calls produce ToolCall events.""" mock_function_call = MagicMock() mock_function_call.name = "test_function" + mock_function_call.id = "call_123" + mock_function_call.args = {"param": "value"} mock_adk_event.get_function_calls = MagicMock(return_value=[mock_function_call]) - with patch('ag_ui_adk.event_translator.logger') as mock_logger: - events = [] - async for event in translator.translate(mock_adk_event, "thread_1", "run_1"): - events.append(event) + events = [] + async for event in translator.translate(mock_adk_event, "thread_1", "run_1"): + events.append(event) - # Should log function calls detection (along with the ADK Event debug log) - debug_calls = [str(call) for call in mock_logger.debug.call_args_list] - assert any("ADK function calls detected: 1 calls" in call for call in debug_calls) + type_names = [str(event.type).split('.')[-1] for event in events] + assert type_names == ["TOOL_CALL_START", "TOOL_CALL_ARGS", "TOOL_CALL_END"] + ids = [getattr(event, 'tool_call_id', None) for event in events] + assert ids == ["call_123", "call_123", "call_123"] @pytest.mark.asyncio async def test_translate_function_responses_handling(self, translator, mock_adk_event): @@ -223,9 +224,14 @@ async def test_translate_text_content_partial_streaming(self, translator, mock_a async for event in translator.translate(mock_adk_event_with_content, "thread_1", "run_1"): events.append(event) - assert len(events) == 3 # START, CONTENT , END + # The translator keeps streaming open; forcing a close should yield END + async for event in translator.force_close_streaming_message(): + events.append(event) + + assert len(events) == 3 # START, CONTENT, END (forced close) assert isinstance(events[0], TextMessageStartEvent) assert isinstance(events[1], TextMessageContentEvent) + assert isinstance(events[2], TextMessageEndEvent) @pytest.mark.asyncio async def test_translate_text_content_final_response_callable(self, translator, mock_adk_event_with_content): @@ -752,8 +758,8 @@ async def test_partial_streaming_continuation(self, translator, mock_adk_event_w async for event in translator.translate(mock_adk_event_with_content, "thread_1", "run_1"): events1.append(event) - assert len(events1) == 3 # START, CONTENT , END - assert translator._is_streaming is False + assert len(events1) == 2 # START, CONTENT (stream remains open) + assert translator._is_streaming is True message_id = events1[0].message_id # Second partial event (should continue streaming) @@ -764,9 +770,11 @@ async def test_partial_streaming_continuation(self, translator, mock_adk_event_w async for event in translator.translate(mock_adk_event_with_content, "thread_1", "run_1"): events2.append(event) - assert len(events2) == 3 # Will start from begining (START , CONTENT , END) - assert isinstance(events2[1], TextMessageContentEvent) - assert events2[0].message_id != message_id # Not the same message ID Because its a new streaming + assert len(events2) == 1 # Additional CONTENT chunk + assert isinstance(events2[0], TextMessageContentEvent) + assert events2[0].message_id == message_id # Same stream continues + assert translator._is_streaming is True + assert translator._streaming_message_id == message_id # Final event (should end streaming - requires is_final_response=True) mock_adk_event_with_content.partial = False @@ -777,8 +785,10 @@ async def test_partial_streaming_continuation(self, translator, mock_adk_event_w async for event in translator.translate(mock_adk_event_with_content, "thread_1", "run_1"): events3.append(event) - assert len(events3) == 0 # No more message (turn Complete) + assert len(events3) == 1 # Final END to close the stream + assert isinstance(events3[0], TextMessageEndEvent) + assert events3[0].message_id == message_id # Should reset streaming state assert translator._is_streaming is False - assert translator._streaming_message_id is None \ No newline at end of file + assert translator._streaming_message_id is None diff --git a/typescript-sdk/integrations/adk-middleware/python/tests/test_integration_mixed_partials.py b/typescript-sdk/integrations/adk-middleware/python/tests/test_integration_mixed_partials.py new file mode 100644 index 000000000..7a7052a2c --- /dev/null +++ b/typescript-sdk/integrations/adk-middleware/python/tests/test_integration_mixed_partials.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +"""Integration test: mixed partials with non-LRO calls before final LRO. + +Scenario: +- Stream text in partial chunks +- Mid-stream, a non-LRO function call appears (should close text and emit tool events) +- Finally, an LRO function call arrives (should close any open text and emit LRO tool events) + +Asserts order, deduplication, and correct tool ids. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, Mock, patch + +from ag_ui.core import ( + RunAgentInput, UserMessage +) +from ag_ui_adk import ADKAgent + + +@pytest.fixture +def adk_agent_instance(): + from google.adk.agents import Agent + mock_agent = Mock(spec=Agent) + mock_agent.name = "test_agent" + return ADKAgent(adk_agent=mock_agent, app_name="test_app", user_id="test_user") + + +@pytest.mark.asyncio +async def test_mixed_partials_non_lro_then_lro(adk_agent_instance): + # Helper to create partial text events + def mk_partial(text): + e = MagicMock() + e.author = "assistant" + e.content = MagicMock(); e.content.parts = [MagicMock(text=text)] + e.partial = True + e.turn_complete = False + e.is_final_response = lambda: False + # No function responses in these partials + e.get_function_responses = lambda: [] + e.get_function_calls = lambda: [] + return e + + # First partial text only + evt1 = mk_partial("Hello") + + # Second partial: text + non-LRO function call + normal_id = "normal-999" + normal_func = MagicMock(); normal_func.id = normal_id; normal_func.name = "regular_tool"; normal_func.args = {"b": 2} + evt2 = mk_partial(" world") + evt2.get_function_calls = lambda: [normal_func] + evt2.long_running_tool_ids = [] + + # Final: LRO function call + lro_id = "lro-777" + lro_func = MagicMock(); lro_func.id = lro_id; lro_func.name = "long_running_tool"; lro_func.args = {"v": 1} + lro_part = MagicMock(); lro_part.function_call = lro_func + + evt3 = MagicMock() + evt3.author = "assistant" + evt3.content = MagicMock(); evt3.content.parts = [lro_part] + evt3.partial = False + evt3.turn_complete = True + evt3.is_final_response = lambda: True + evt3.get_function_calls = lambda: [] + evt3.get_function_responses = lambda: [] + evt3.long_running_tool_ids = [lro_id] + + async def mock_run_async(*args, **kwargs): + yield evt1 + yield evt2 + yield evt3 + + mock_runner = AsyncMock(); mock_runner.run_async = mock_run_async + + sample_input = RunAgentInput( + thread_id="thread_mixed", + run_id="run_mixed", + messages=[UserMessage(id="u1", role="user", content="go")], + tools=[], context=[], state={}, forwarded_props={}, + ) + + with patch.object(adk_agent_instance, "_create_runner", return_value=mock_runner): + events = [] + async for e in adk_agent_instance.run(sample_input): + events.append(e) + + types = [str(ev.type).split(".")[-1] for ev in events] + + # Expect at least one START and 2 CONTENTs from streaming + assert types.count("TEXT_MESSAGE_START") == 1 + assert types.count("TEXT_MESSAGE_CONTENT") >= 2 + + # Non-LRO tool call should appear exactly once + normal_starts = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_START") and getattr(ev, "tool_call_id", None) == normal_id] + normal_args = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_ARGS") and getattr(ev, "tool_call_id", None) == normal_id] + normal_ends = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_END") and getattr(ev, "tool_call_id", None) == normal_id] + assert len(normal_starts) == len(normal_args) == len(normal_ends) == 1 + + # Ensure a TEXT_MESSAGE_END precedes the normal tool start + text_ends = [i for i, t in enumerate(types) if t == "TEXT_MESSAGE_END"] + assert len(text_ends) >= 1 + assert text_ends[-1] < normal_starts[0], "TEXT_MESSAGE_END must precede first non-LRO TOOL_CALL_START" + + # LRO tool call should appear exactly once and after the non-LRO + lro_starts = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_START") and getattr(ev, "tool_call_id", None) == lro_id] + lro_args = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_ARGS") and getattr(ev, "tool_call_id", None) == lro_id] + lro_ends = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_END") and getattr(ev, "tool_call_id", None) == lro_id] + assert len(lro_starts) == len(lro_args) == len(lro_ends) == 1 + assert lro_starts[0] > normal_starts[0] + diff --git a/typescript-sdk/integrations/adk-middleware/python/tests/test_lro_filtering.py b/typescript-sdk/integrations/adk-middleware/python/tests/test_lro_filtering.py new file mode 100644 index 000000000..0a0a8a0f2 --- /dev/null +++ b/typescript-sdk/integrations/adk-middleware/python/tests/test_lro_filtering.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +"""Tests for LRO-aware routing and translator filtering. + +These tests verify that: +- EventTranslator.translate skips long-running tool calls and only emits non-LRO calls +- translate_lro_function_calls emits events only for long-running tool calls +""" + +import asyncio +from unittest.mock import MagicMock + +from ag_ui.core import EventType +from ag_ui_adk import EventTranslator + + +async def test_translate_skips_lro_function_calls(): + """Ensure non-LRO tool calls are emitted and LRO calls are skipped in translate.""" + translator = EventTranslator() + + # Prepare mock ADK event + adk_event = MagicMock() + adk_event.author = "assistant" + adk_event.content = MagicMock() + adk_event.content.parts = [] # no text + + # Two function calls, one is long-running + lro_id = "tool-call-lro-1" + normal_id = "tool-call-normal-2" + + lro_call = MagicMock() + lro_call.id = lro_id + lro_call.name = "long_running_tool" + lro_call.args = {"x": 1} + + normal_call = MagicMock() + normal_call.id = normal_id + normal_call.name = "regular_tool" + normal_call.args = {"y": 2} + + adk_event.get_function_calls = lambda: [lro_call, normal_call] + # Mark the long-running call id on the event + adk_event.long_running_tool_ids = [lro_id] + + events = [] + async for e in translator.translate(adk_event, "thread", "run"): + events.append(e) + + # We expect only the non-LRO tool call events to be emitted + # Sequence: TOOL_CALL_START(normal), TOOL_CALL_ARGS(normal), TOOL_CALL_END(normal) + event_types = [str(ev.type).split('.')[-1] for ev in events] + assert event_types.count("TOOL_CALL_START") == 1 + assert event_types.count("TOOL_CALL_ARGS") == 1 + assert event_types.count("TOOL_CALL_END") == 1 + + # Ensure the emitted tool_call_id is the normal one + ids = set(getattr(ev, 'tool_call_id', None) for ev in events) + assert normal_id in ids + assert lro_id not in ids + + +async def test_translate_lro_function_calls_only_emits_lro(): + """Ensure translate_lro_function_calls emits only for long-running calls.""" + translator = EventTranslator() + + # Prepare mock ADK event with content parts containing function calls + lro_id = "tool-call-lro-3" + normal_id = "tool-call-normal-4" + + lro_call = MagicMock() + lro_call.id = lro_id + lro_call.name = "long_running_tool" + lro_call.args = {"a": 123} + + normal_call = MagicMock() + normal_call.id = normal_id + normal_call.name = "regular_tool" + normal_call.args = {"b": 456} + + # Build parts with both calls + lro_part = MagicMock() + lro_part.function_call = lro_call + normal_part = MagicMock() + normal_part.function_call = normal_call + + adk_event = MagicMock() + adk_event.content = MagicMock() + adk_event.content.parts = [lro_part, normal_part] + adk_event.long_running_tool_ids = [lro_id] + + events = [] + async for e in translator.translate_lro_function_calls(adk_event): + events.append(e) + + # Expect only the LRO call events + # Sequence: TOOL_CALL_START(lro), TOOL_CALL_ARGS(lro), TOOL_CALL_END(lro) + event_types = [str(ev.type).split('.')[-1] for ev in events] + assert event_types == ["TOOL_CALL_START", "TOOL_CALL_ARGS", "TOOL_CALL_END"] + for ev in events: + assert getattr(ev, 'tool_call_id', None) == lro_id + + +if __name__ == "__main__": + asyncio.run(test_translate_skips_lro_function_calls()) + asyncio.run(test_translate_lro_function_calls_only_emits_lro()) + print("\nโœ… LRO filtering tests ran to completion") + diff --git a/typescript-sdk/integrations/adk-middleware/python/tests/test_streaming.py b/typescript-sdk/integrations/adk-middleware/python/tests/test_streaming.py index f5942f887..2967db53a 100644 --- a/typescript-sdk/integrations/adk-middleware/python/tests/test_streaming.py +++ b/typescript-sdk/integrations/adk-middleware/python/tests/test_streaming.py @@ -100,6 +100,72 @@ async def test_streaming_behavior(): print(f" Got: {event_type_strings}") return False +async def test_partial_with_finish_reason(): + """Test the specific scenario: partial=True, is_final_response=False, but finish_reason=STOP. + + This is the bug we fixed - Gemini returns partial=True even on the final chunk with finish_reason. + The fix checks for finish_reason as a fallback to properly close the streaming message. + """ + print("\n๐Ÿงช Testing Partial Event with finish_reason (Bug Fix Scenario)") + print("=================================================================") + + translator = EventTranslator() + + # First event: start streaming + first_event = MagicMock() + first_event.content = MagicMock() + first_event.content.parts = [MagicMock(text="Hello")] + first_event.author = "assistant" + first_event.partial = True + first_event.turn_complete = None + first_event.finish_reason = None + first_event.is_final_response = lambda: False + first_event.get_function_calls = lambda: [] + first_event.get_function_responses = lambda: [] + + # Second event: final chunk with finish_reason BUT still partial=True (the bug scenario!) + final_event = MagicMock() + final_event.content = MagicMock() + final_event.content.parts = [MagicMock(text=" world")] + final_event.author = "assistant" + final_event.partial = True # Still marked as partial! + final_event.turn_complete = None + final_event.finish_reason = "STOP" # But has finish_reason! + final_event.is_final_response = lambda: False # And is_final_response returns False! + final_event.get_function_calls = lambda: [] + final_event.get_function_responses = lambda: [] + + print("\n๐Ÿ“ก Event 1: partial=True, finish_reason=None, is_final_response=False") + print("๐Ÿ“ก Event 2: partial=True, finish_reason=STOP, is_final_response=False โš ๏ธ") + + all_events = [] + + # Process first event + async for ag_ui_event in translator.translate(first_event, "test_thread", "test_run"): + all_events.append(ag_ui_event) + + # Process final event + async for ag_ui_event in translator.translate(final_event, "test_thread", "test_run"): + all_events.append(ag_ui_event) + + event_types = [str(event.type).split('.')[-1] for event in all_events] + + print(f"\n๐Ÿ“Š Generated Events: {event_types}") + + # Expected: START, CONTENT (Hello), CONTENT (world), END + # The fix ensures that finish_reason triggers END even when partial=True and is_final_response=False + expected = ["TEXT_MESSAGE_START", "TEXT_MESSAGE_CONTENT", "TEXT_MESSAGE_CONTENT", "TEXT_MESSAGE_END"] + + if event_types == expected: + print("โœ… Bug fix verified! finish_reason properly triggers TEXT_MESSAGE_END") + print(" Even when partial=True and is_final_response=False") + return True + else: + print(f"โŒ Bug fix failed!") + print(f" Expected: {expected}") + print(f" Got: {event_types}") + return False + async def test_non_streaming(): """Test that complete messages still work.""" print("\n๐Ÿงช Testing Non-Streaming (Complete Messages)") @@ -135,9 +201,10 @@ async def test_non_streaming(): if __name__ == "__main__": async def run_tests(): test1 = await test_streaming_behavior() - test2 = await test_non_streaming() + test2 = await test_partial_with_finish_reason() + test3 = await test_non_streaming() - if test1 and test2: + if test1 and test2 and test3: print("\n๐ŸŽ‰ All streaming tests passed!") print("๐Ÿ’ก Ready for real ADK integration with proper streaming") else: