Skip to content

Commit 014d05b

Browse files
Add test for streaming finish reason fallback
1 parent e38aaf0 commit 014d05b

File tree

1 file changed

+104
-3
lines changed

1 file changed

+104
-3
lines changed

typescript-sdk/integrations/adk-middleware/python/tests/test_adk_agent.py

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from unittest.mock import Mock, MagicMock, AsyncMock, patch
99

1010

11-
from ag_ui_adk import ADKAgent, SessionManager, EventTranslator
11+
from ag_ui_adk import ADKAgent, SessionManager
12+
from ag_ui_adk.event_translator import EventTranslator
1213
from ag_ui.core import (
1314
RunAgentInput, EventType, UserMessage, Context,
14-
RunStartedEvent, RunFinishedEvent, TextMessageChunkEvent, SystemMessage
15+
RunStartedEvent, RunFinishedEvent, TextMessageChunkEvent, SystemMessage,
16+
TextMessageContentEvent
1517
)
1618
from google.adk.agents import Agent
1719

@@ -201,7 +203,8 @@ async def run_async(self, *args, **kwargs):
201203
# Confirm branch selection
202204
assert len(streaming_calls) == 1
203205
assert lro_calls == []
204-
206+
207+
@pytest.mark.asyncio
205208
async def test_partial_final_chunk_uses_streaming_translation(self, adk_agent, sample_input):
206209
"""Ensure partial chunks marked as final still use streaming translation."""
207210

@@ -251,6 +254,104 @@ async def run_async(self, *args, **kwargs):
251254
assert translate_calls == 1
252255
assert lro_calls == 0
253256

257+
@pytest.mark.asyncio
258+
async def test_streaming_finish_reason_fallback(self, adk_agent, sample_input):
259+
"""Ensure streaming translator handles final responses missing finish_reason."""
260+
261+
text_part = SimpleNamespace(text="Hello from stream", function_call=None)
262+
streaming_event = SimpleNamespace(
263+
id="event-stream",
264+
author="assistant",
265+
content=SimpleNamespace(parts=[text_part]),
266+
partial=False,
267+
turn_complete=True,
268+
usage_metadata={"tokens": 9},
269+
finish_reason=None,
270+
actions=None,
271+
custom_data=None,
272+
long_running_tool_ids=[],
273+
)
274+
streaming_event.is_final_response = lambda: True
275+
streaming_event.get_function_calls = Mock(return_value=[])
276+
streaming_event.get_function_responses = Mock(return_value=[])
277+
278+
function_call = SimpleNamespace(id="tool-1", name="long_tool", args={"foo": "bar"})
279+
function_part = SimpleNamespace(text=None, function_call=function_call)
280+
lro_event = SimpleNamespace(
281+
id="event-lro",
282+
author="assistant",
283+
content=SimpleNamespace(parts=[function_part]),
284+
partial=False,
285+
turn_complete=True,
286+
usage_metadata={"tokens": 1},
287+
finish_reason="STOP",
288+
actions=None,
289+
custom_data=None,
290+
long_running_tool_ids=[function_call.id],
291+
)
292+
lro_event.is_final_response = lambda: True
293+
lro_event.get_function_calls = Mock(return_value=[])
294+
lro_event.get_function_responses = Mock(return_value=[])
295+
296+
events_to_yield = [streaming_event, lro_event]
297+
298+
class DummyRunner:
299+
async def run_async(self, *args, **kwargs):
300+
for event in events_to_yield:
301+
yield event
302+
303+
captured_stream_events = []
304+
captured_lro_events = []
305+
306+
original_translate = EventTranslator.translate
307+
original_translate_lro = EventTranslator.translate_lro_function_calls
308+
309+
async def translate_spy(self, adk_event, thread_id, run_id):
310+
translate_spy.call_count += 1
311+
translate_spy.adk_events.append(adk_event)
312+
async for event in original_translate(self, adk_event, thread_id, run_id):
313+
captured_stream_events.append(event)
314+
yield event
315+
316+
translate_spy.call_count = 0
317+
translate_spy.adk_events = []
318+
319+
async def translate_lro_spy(self, adk_event):
320+
translate_lro_spy.call_count += 1
321+
translate_lro_spy.adk_events.append(adk_event)
322+
async for event in original_translate_lro(self, adk_event):
323+
captured_lro_events.append(event)
324+
yield event
325+
326+
translate_lro_spy.call_count = 0
327+
translate_lro_spy.adk_events = []
328+
329+
dummy_runner = DummyRunner()
330+
331+
with patch.object(EventTranslator, "translate", translate_spy), \
332+
patch.object(EventTranslator, "translate_lro_function_calls", translate_lro_spy), \
333+
patch.object(adk_agent, "_create_runner", return_value=dummy_runner):
334+
335+
emitted_events = []
336+
async for event in adk_agent.run(sample_input):
337+
emitted_events.append(event)
338+
339+
# Assert streaming translator was used for the first event
340+
assert translate_spy.call_count == 1
341+
assert translate_spy.adk_events[0] is streaming_event
342+
343+
# Confirm streaming content flowed through as expected
344+
text_events = [event for event in emitted_events if isinstance(event, TextMessageContentEvent)]
345+
assert text_events and text_events[0].delta == "Hello from stream"
346+
assert any(isinstance(event, TextMessageContentEvent) for event in captured_stream_events)
347+
348+
# Long-running translation should be invoked only for the STOP event
349+
assert translate_lro_spy.call_count == 1
350+
assert translate_lro_spy.adk_events[0] is lro_event
351+
352+
# Ensure we produced a tool call event to guard against regressions
353+
assert any(event.type == EventType.TOOL_CALL_END for event in captured_lro_events)
354+
254355
@pytest.mark.asyncio
255356
async def test_session_management(self, adk_agent):
256357
"""Test session lifecycle management."""

0 commit comments

Comments
 (0)