Skip to content

Commit 769ff0f

Browse files
ADK middleware: prefer LRO routing and harden translator; add tests
- Prefer LRO routing in ADKAgent when long‑running tool call IDs are present in event.content.parts (prevents misrouting into streaming path and tool loops; preserves HITL pause) - Force‑close any active streaming text before emitting LRO tool events (guarantees TEXT_MESSAGE_END precedes TOOL_CALL_START) - Harden EventTranslator.translate to filter out long‑running tool calls from the general path; only emit non‑LRO calls (avoids duplicate tool events) - Add tests: * test_lro_filtering.py (translator‑level filtering + LRO‑only emission) * test_integration_mixed_partials.py (streaming → non‑LRO → final LRO: order, no duplicates, correct IDs)
1 parent 0479cb6 commit 769ff0f

File tree

4 files changed

+261
-16
lines changed

4 files changed

+261
-16
lines changed

typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/adk_agent.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -923,9 +923,24 @@ async def _run_adk_in_background(
923923
(not final_response) # Not marked as final by is_final_response()
924924
)
925925

926-
# Process as streaming if it's a chunk OR if it has content but no finish_reason
927-
# This ensures we capture all content, regardless of usage_metadata presence
928-
if is_streaming_chunk or (has_content and not getattr(adk_event, 'finish_reason', None)):
926+
# Prefer LRO routing when a long-running tool call is present
927+
has_lro_function_call = False
928+
try:
929+
lro_ids = set(getattr(adk_event, 'long_running_tool_ids', []) or [])
930+
if lro_ids and adk_event.content and getattr(adk_event.content, 'parts', None):
931+
for part in adk_event.content.parts:
932+
func = getattr(part, 'function_call', None)
933+
func_id = getattr(func, 'id', None) if func else None
934+
if func_id and func_id in lro_ids:
935+
has_lro_function_call = True
936+
break
937+
except Exception:
938+
# Be conservative: if detection fails, do not block streaming path
939+
has_lro_function_call = False
940+
941+
# Process as streaming if it's a chunk OR if it has content but no finish_reason,
942+
# but only when there is no LRO function call present (LRO takes precedence)
943+
if (not has_lro_function_call) and (is_streaming_chunk or (has_content and not getattr(adk_event, 'finish_reason', None))):
929944
# Regular translation path
930945
async for ag_ui_event in event_translator.translate(
931946
adk_event,
@@ -937,7 +952,12 @@ async def _run_adk_in_background(
937952
await event_queue.put(ag_ui_event)
938953
logger.debug(f"Event queued: {type(ag_ui_event).__name__} (thread {input.thread_id}, queue size after: {event_queue.qsize()})")
939954
else:
940-
# LongRunning Tool events are usually emmitted in final response
955+
# LongRunning Tool events are usually emitted in final response
956+
# Ensure any active streaming text message is closed BEFORE tool calls
957+
async for end_event in event_translator.force_close_streaming_message():
958+
await event_queue.put(end_event)
959+
logger.debug(f"Event queued (forced close): {type(end_event).__name__} (thread {input.thread_id}, queue size after: {event_queue.qsize()})")
960+
941961
async for ag_ui_event in event_translator.translate_lro_function_calls(
942962
adk_event
943963
):
@@ -1003,4 +1023,4 @@ async def close(self):
10031023
self._session_lookup_cache.clear()
10041024

10051025
# Stop session manager cleanup task
1006-
await self._session_manager.stop_cleanup_task()
1026+
await self._session_manager.stop_cleanup_task()

typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,24 @@ async def translate(
8787
if hasattr(adk_event, 'get_function_calls'):
8888
function_calls = adk_event.get_function_calls()
8989
if function_calls:
90-
logger.debug(f"ADK function calls detected: {len(function_calls)} calls")
91-
92-
# CRITICAL FIX: End any active text message stream before starting tool calls
93-
# Per AG-UI protocol: TEXT_MESSAGE_END must be sent before TOOL_CALL_START
94-
async for event in self.force_close_streaming_message():
95-
yield event
96-
97-
# NOW ACTUALLY YIELD THE EVENTS
98-
async for event in self._translate_function_calls(function_calls):
99-
yield event
90+
# Filter out long-running tool calls; those are handled by translate_lro_function_calls
91+
try:
92+
lro_ids = set(getattr(adk_event, 'long_running_tool_ids', []) or [])
93+
except Exception:
94+
lro_ids = set()
95+
96+
non_lro_calls = [fc for fc in function_calls if getattr(fc, 'id', None) not in lro_ids]
97+
98+
if non_lro_calls:
99+
logger.debug(f"ADK function calls detected (non-LRO): {len(non_lro_calls)} of {len(function_calls)} total")
100+
# CRITICAL FIX: End any active text message stream before starting tool calls
101+
# Per AG-UI protocol: TEXT_MESSAGE_END must be sent before TOOL_CALL_START
102+
async for event in self.force_close_streaming_message():
103+
yield event
104+
105+
# Yield only non-LRO function call events
106+
async for event in self._translate_function_calls(non_lro_calls):
107+
yield event
100108

101109
# Handle function responses and yield the tool response event
102110
# this is essential for scenerios when user has to render function response at frontend
@@ -466,4 +474,4 @@ def reset(self):
466474
self._streaming_message_id = None
467475
self._is_streaming = False
468476
self.long_running_tool_ids.clear()
469-
logger.debug("Reset EventTranslator state (including streaming state)")
477+
logger.debug("Reset EventTranslator state (including streaming state)")
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#!/usr/bin/env python
2+
"""Integration test: mixed partials with non-LRO calls before final LRO.
3+
4+
Scenario:
5+
- Stream text in partial chunks
6+
- Mid-stream, a non-LRO function call appears (should close text and emit tool events)
7+
- Finally, an LRO function call arrives (should close any open text and emit LRO tool events)
8+
9+
Asserts order, deduplication, and correct tool ids.
10+
"""
11+
12+
import pytest
13+
from unittest.mock import MagicMock, AsyncMock, Mock, patch
14+
15+
from ag_ui.core import (
16+
RunAgentInput, UserMessage
17+
)
18+
from ag_ui_adk import ADKAgent
19+
20+
21+
@pytest.fixture
22+
def adk_agent_instance():
23+
from google.adk.agents import Agent
24+
mock_agent = Mock(spec=Agent)
25+
mock_agent.name = "test_agent"
26+
return ADKAgent(adk_agent=mock_agent, app_name="test_app", user_id="test_user")
27+
28+
29+
@pytest.mark.asyncio
30+
async def test_mixed_partials_non_lro_then_lro(adk_agent_instance):
31+
# Helper to create partial text events
32+
def mk_partial(text):
33+
e = MagicMock()
34+
e.author = "assistant"
35+
e.content = MagicMock(); e.content.parts = [MagicMock(text=text)]
36+
e.partial = True
37+
e.turn_complete = False
38+
e.is_final_response = lambda: False
39+
# No function responses in these partials
40+
e.get_function_responses = lambda: []
41+
e.get_function_calls = lambda: []
42+
return e
43+
44+
# First partial text only
45+
evt1 = mk_partial("Hello")
46+
47+
# Second partial: text + non-LRO function call
48+
normal_id = "normal-999"
49+
normal_func = MagicMock(); normal_func.id = normal_id; normal_func.name = "regular_tool"; normal_func.args = {"b": 2}
50+
evt2 = mk_partial(" world")
51+
evt2.get_function_calls = lambda: [normal_func]
52+
evt2.long_running_tool_ids = []
53+
54+
# Final: LRO function call
55+
lro_id = "lro-777"
56+
lro_func = MagicMock(); lro_func.id = lro_id; lro_func.name = "long_running_tool"; lro_func.args = {"v": 1}
57+
lro_part = MagicMock(); lro_part.function_call = lro_func
58+
59+
evt3 = MagicMock()
60+
evt3.author = "assistant"
61+
evt3.content = MagicMock(); evt3.content.parts = [lro_part]
62+
evt3.partial = False
63+
evt3.turn_complete = True
64+
evt3.is_final_response = lambda: True
65+
evt3.get_function_calls = lambda: []
66+
evt3.get_function_responses = lambda: []
67+
evt3.long_running_tool_ids = [lro_id]
68+
69+
async def mock_run_async(*args, **kwargs):
70+
yield evt1
71+
yield evt2
72+
yield evt3
73+
74+
mock_runner = AsyncMock(); mock_runner.run_async = mock_run_async
75+
76+
sample_input = RunAgentInput(
77+
thread_id="thread_mixed",
78+
run_id="run_mixed",
79+
messages=[UserMessage(id="u1", role="user", content="go")],
80+
tools=[], context=[], state={}, forwarded_props={},
81+
)
82+
83+
with patch.object(adk_agent_instance, "_create_runner", return_value=mock_runner):
84+
events = []
85+
async for e in adk_agent_instance.run(sample_input):
86+
events.append(e)
87+
88+
types = [str(ev.type).split(".")[-1] for ev in events]
89+
90+
# Expect at least one START and 2 CONTENTs from streaming
91+
assert types.count("TEXT_MESSAGE_START") == 1
92+
assert types.count("TEXT_MESSAGE_CONTENT") >= 2
93+
94+
# Non-LRO tool call should appear exactly once
95+
normal_starts = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_START") and getattr(ev, "tool_call_id", None) == normal_id]
96+
normal_args = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_ARGS") and getattr(ev, "tool_call_id", None) == normal_id]
97+
normal_ends = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_END") and getattr(ev, "tool_call_id", None) == normal_id]
98+
assert len(normal_starts) == len(normal_args) == len(normal_ends) == 1
99+
100+
# Ensure a TEXT_MESSAGE_END precedes the normal tool start
101+
text_ends = [i for i, t in enumerate(types) if t == "TEXT_MESSAGE_END"]
102+
assert len(text_ends) >= 1
103+
assert text_ends[-1] < normal_starts[0], "TEXT_MESSAGE_END must precede first non-LRO TOOL_CALL_START"
104+
105+
# LRO tool call should appear exactly once and after the non-LRO
106+
lro_starts = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_START") and getattr(ev, "tool_call_id", None) == lro_id]
107+
lro_args = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_ARGS") and getattr(ev, "tool_call_id", None) == lro_id]
108+
lro_ends = [i for i, ev in enumerate(events) if str(ev.type).endswith("TOOL_CALL_END") and getattr(ev, "tool_call_id", None) == lro_id]
109+
assert len(lro_starts) == len(lro_args) == len(lro_ends) == 1
110+
assert lro_starts[0] > normal_starts[0]
111+
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#!/usr/bin/env python
2+
"""Tests for LRO-aware routing and translator filtering.
3+
4+
These tests verify that:
5+
- EventTranslator.translate skips long-running tool calls and only emits non-LRO calls
6+
- translate_lro_function_calls emits events only for long-running tool calls
7+
"""
8+
9+
import asyncio
10+
from unittest.mock import MagicMock
11+
12+
from ag_ui.core import EventType
13+
from ag_ui_adk import EventTranslator
14+
15+
16+
async def test_translate_skips_lro_function_calls():
17+
"""Ensure non-LRO tool calls are emitted and LRO calls are skipped in translate."""
18+
translator = EventTranslator()
19+
20+
# Prepare mock ADK event
21+
adk_event = MagicMock()
22+
adk_event.author = "assistant"
23+
adk_event.content = MagicMock()
24+
adk_event.content.parts = [] # no text
25+
26+
# Two function calls, one is long-running
27+
lro_id = "tool-call-lro-1"
28+
normal_id = "tool-call-normal-2"
29+
30+
lro_call = MagicMock()
31+
lro_call.id = lro_id
32+
lro_call.name = "long_running_tool"
33+
lro_call.args = {"x": 1}
34+
35+
normal_call = MagicMock()
36+
normal_call.id = normal_id
37+
normal_call.name = "regular_tool"
38+
normal_call.args = {"y": 2}
39+
40+
adk_event.get_function_calls = lambda: [lro_call, normal_call]
41+
# Mark the long-running call id on the event
42+
adk_event.long_running_tool_ids = [lro_id]
43+
44+
events = []
45+
async for e in translator.translate(adk_event, "thread", "run"):
46+
events.append(e)
47+
48+
# We expect only the non-LRO tool call events to be emitted
49+
# Sequence: TOOL_CALL_START(normal), TOOL_CALL_ARGS(normal), TOOL_CALL_END(normal)
50+
event_types = [str(ev.type).split('.')[-1] for ev in events]
51+
assert event_types.count("TOOL_CALL_START") == 1
52+
assert event_types.count("TOOL_CALL_ARGS") == 1
53+
assert event_types.count("TOOL_CALL_END") == 1
54+
55+
# Ensure the emitted tool_call_id is the normal one
56+
ids = set(getattr(ev, 'tool_call_id', None) for ev in events)
57+
assert normal_id in ids
58+
assert lro_id not in ids
59+
60+
61+
async def test_translate_lro_function_calls_only_emits_lro():
62+
"""Ensure translate_lro_function_calls emits only for long-running calls."""
63+
translator = EventTranslator()
64+
65+
# Prepare mock ADK event with content parts containing function calls
66+
lro_id = "tool-call-lro-3"
67+
normal_id = "tool-call-normal-4"
68+
69+
lro_call = MagicMock()
70+
lro_call.id = lro_id
71+
lro_call.name = "long_running_tool"
72+
lro_call.args = {"a": 123}
73+
74+
normal_call = MagicMock()
75+
normal_call.id = normal_id
76+
normal_call.name = "regular_tool"
77+
normal_call.args = {"b": 456}
78+
79+
# Build parts with both calls
80+
lro_part = MagicMock()
81+
lro_part.function_call = lro_call
82+
normal_part = MagicMock()
83+
normal_part.function_call = normal_call
84+
85+
adk_event = MagicMock()
86+
adk_event.content = MagicMock()
87+
adk_event.content.parts = [lro_part, normal_part]
88+
adk_event.long_running_tool_ids = [lro_id]
89+
90+
events = []
91+
async for e in translator.translate_lro_function_calls(adk_event):
92+
events.append(e)
93+
94+
# Expect only the LRO call events
95+
# Sequence: TOOL_CALL_START(lro), TOOL_CALL_ARGS(lro), TOOL_CALL_END(lro)
96+
event_types = [str(ev.type).split('.')[-1] for ev in events]
97+
assert event_types == ["TOOL_CALL_START", "TOOL_CALL_ARGS", "TOOL_CALL_END"]
98+
for ev in events:
99+
assert getattr(ev, 'tool_call_id', None) == lro_id
100+
101+
102+
if __name__ == "__main__":
103+
asyncio.run(test_translate_skips_lro_function_calls())
104+
asyncio.run(test_translate_lro_function_calls_only_emits_lro())
105+
print("\n✅ LRO filtering tests ran to completion")
106+

0 commit comments

Comments
 (0)