|
8 | 8 | from unittest.mock import Mock, MagicMock, AsyncMock, patch
|
9 | 9 |
|
10 | 10 |
|
11 |
| -from ag_ui_adk import ADKAgent, SessionManager, EventTranslator |
| 11 | +from ag_ui_adk import ADKAgent, SessionManager |
| 12 | +from ag_ui_adk.event_translator import EventTranslator |
12 | 13 | from ag_ui.core import (
|
13 | 14 | RunAgentInput, EventType, UserMessage, Context,
|
14 |
| - RunStartedEvent, RunFinishedEvent, TextMessageChunkEvent, SystemMessage |
| 15 | + RunStartedEvent, RunFinishedEvent, TextMessageChunkEvent, SystemMessage, |
| 16 | + TextMessageContentEvent |
15 | 17 | )
|
16 | 18 | from google.adk.agents import Agent
|
17 | 19 |
|
@@ -201,7 +203,8 @@ async def run_async(self, *args, **kwargs):
|
201 | 203 | # Confirm branch selection
|
202 | 204 | assert len(streaming_calls) == 1
|
203 | 205 | assert lro_calls == []
|
204 |
| - |
| 206 | + |
| 207 | + @pytest.mark.asyncio |
205 | 208 | async def test_partial_final_chunk_uses_streaming_translation(self, adk_agent, sample_input):
|
206 | 209 | """Ensure partial chunks marked as final still use streaming translation."""
|
207 | 210 |
|
@@ -251,6 +254,104 @@ async def run_async(self, *args, **kwargs):
|
251 | 254 | assert translate_calls == 1
|
252 | 255 | assert lro_calls == 0
|
253 | 256 |
|
| 257 | + @pytest.mark.asyncio |
| 258 | + async def test_streaming_finish_reason_fallback(self, adk_agent, sample_input): |
| 259 | + """Ensure streaming translator handles final responses missing finish_reason.""" |
| 260 | + |
| 261 | + text_part = SimpleNamespace(text="Hello from stream", function_call=None) |
| 262 | + streaming_event = SimpleNamespace( |
| 263 | + id="event-stream", |
| 264 | + author="assistant", |
| 265 | + content=SimpleNamespace(parts=[text_part]), |
| 266 | + partial=False, |
| 267 | + turn_complete=True, |
| 268 | + usage_metadata={"tokens": 9}, |
| 269 | + finish_reason=None, |
| 270 | + actions=None, |
| 271 | + custom_data=None, |
| 272 | + long_running_tool_ids=[], |
| 273 | + ) |
| 274 | + streaming_event.is_final_response = lambda: True |
| 275 | + streaming_event.get_function_calls = Mock(return_value=[]) |
| 276 | + streaming_event.get_function_responses = Mock(return_value=[]) |
| 277 | + |
| 278 | + function_call = SimpleNamespace(id="tool-1", name="long_tool", args={"foo": "bar"}) |
| 279 | + function_part = SimpleNamespace(text=None, function_call=function_call) |
| 280 | + lro_event = SimpleNamespace( |
| 281 | + id="event-lro", |
| 282 | + author="assistant", |
| 283 | + content=SimpleNamespace(parts=[function_part]), |
| 284 | + partial=False, |
| 285 | + turn_complete=True, |
| 286 | + usage_metadata={"tokens": 1}, |
| 287 | + finish_reason="STOP", |
| 288 | + actions=None, |
| 289 | + custom_data=None, |
| 290 | + long_running_tool_ids=[function_call.id], |
| 291 | + ) |
| 292 | + lro_event.is_final_response = lambda: True |
| 293 | + lro_event.get_function_calls = Mock(return_value=[]) |
| 294 | + lro_event.get_function_responses = Mock(return_value=[]) |
| 295 | + |
| 296 | + events_to_yield = [streaming_event, lro_event] |
| 297 | + |
| 298 | + class DummyRunner: |
| 299 | + async def run_async(self, *args, **kwargs): |
| 300 | + for event in events_to_yield: |
| 301 | + yield event |
| 302 | + |
| 303 | + captured_stream_events = [] |
| 304 | + captured_lro_events = [] |
| 305 | + |
| 306 | + original_translate = EventTranslator.translate |
| 307 | + original_translate_lro = EventTranslator.translate_lro_function_calls |
| 308 | + |
| 309 | + async def translate_spy(self, adk_event, thread_id, run_id): |
| 310 | + translate_spy.call_count += 1 |
| 311 | + translate_spy.adk_events.append(adk_event) |
| 312 | + async for event in original_translate(self, adk_event, thread_id, run_id): |
| 313 | + captured_stream_events.append(event) |
| 314 | + yield event |
| 315 | + |
| 316 | + translate_spy.call_count = 0 |
| 317 | + translate_spy.adk_events = [] |
| 318 | + |
| 319 | + async def translate_lro_spy(self, adk_event): |
| 320 | + translate_lro_spy.call_count += 1 |
| 321 | + translate_lro_spy.adk_events.append(adk_event) |
| 322 | + async for event in original_translate_lro(self, adk_event): |
| 323 | + captured_lro_events.append(event) |
| 324 | + yield event |
| 325 | + |
| 326 | + translate_lro_spy.call_count = 0 |
| 327 | + translate_lro_spy.adk_events = [] |
| 328 | + |
| 329 | + dummy_runner = DummyRunner() |
| 330 | + |
| 331 | + with patch.object(EventTranslator, "translate", translate_spy), \ |
| 332 | + patch.object(EventTranslator, "translate_lro_function_calls", translate_lro_spy), \ |
| 333 | + patch.object(adk_agent, "_create_runner", return_value=dummy_runner): |
| 334 | + |
| 335 | + emitted_events = [] |
| 336 | + async for event in adk_agent.run(sample_input): |
| 337 | + emitted_events.append(event) |
| 338 | + |
| 339 | + # Assert streaming translator was used for the first event |
| 340 | + assert translate_spy.call_count == 1 |
| 341 | + assert translate_spy.adk_events[0] is streaming_event |
| 342 | + |
| 343 | + # Confirm streaming content flowed through as expected |
| 344 | + text_events = [event for event in emitted_events if isinstance(event, TextMessageContentEvent)] |
| 345 | + assert text_events and text_events[0].delta == "Hello from stream" |
| 346 | + assert any(isinstance(event, TextMessageContentEvent) for event in captured_stream_events) |
| 347 | + |
| 348 | + # Long-running translation should be invoked only for the STOP event |
| 349 | + assert translate_lro_spy.call_count == 1 |
| 350 | + assert translate_lro_spy.adk_events[0] is lro_event |
| 351 | + |
| 352 | + # Ensure we produced a tool call event to guard against regressions |
| 353 | + assert any(event.type == EventType.TOOL_CALL_END for event in captured_lro_events) |
| 354 | + |
254 | 355 | @pytest.mark.asyncio
|
255 | 356 | async def test_session_management(self, adk_agent):
|
256 | 357 | """Test session lifecycle management."""
|
|
0 commit comments