Skip to content

Commit 446f2e9

Browse files
Merge pull request #49 from evgeny-l/adk-middleware-before-agent-callback
fix: support of before_agent_callback in case of a direct content response
2 parents 7abe5b9 + 60bd9a1 commit 446f2e9

File tree

4 files changed

+200
-7
lines changed

4 files changed

+200
-7
lines changed

typescript-sdk/integrations/adk-middleware/src/adk_middleware/adk_agent.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,8 +858,12 @@ async def _run_adk_in_background(
858858
new_message=new_message,
859859
run_config=run_config
860860
):
861-
if not adk_event.is_final_response():
862-
# Translate and emit events
861+
862+
final_response = adk_event.is_final_response()
863+
has_content = adk_event.content and hasattr(adk_event.content, 'parts') and adk_event.content.parts
864+
865+
if not final_response or (not adk_event.usage_metadata and has_content):
866+
# Translate and emit events
863867
async for ag_ui_event in event_translator.translate(
864868
adk_event,
865869
input.thread_id,

typescript-sdk/integrations/adk-middleware/src/adk_middleware/event_translator.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,36 @@ async def _translate_text_content(
168168
should_send_end = is_final_response and not is_partial
169169

170170
logger.info(f"📥 Text event - partial={is_partial}, turn_complete={turn_complete}, "
171-
f"is_final_response={is_final_response}, should_send_end={should_send_end}, "
172-
f"currently_streaming={self._is_streaming}")
173-
174-
# Skip final response events to avoid duplicate content, but send END if streaming
171+
f"is_final_response={is_final_response}, should_send_end={should_send_end}, "
172+
f"currently_streaming={self._is_streaming}")
173+
175174
if is_final_response:
175+
176+
# If a final text response wasn't streamed (not generated by an LLM) then deliver it in 3 events
177+
if not self._is_streaming and not adk_event.usage_metadata and should_send_end:
178+
logger.info(f"⏭️ Deliver non-llm response via message events "
179+
f"event_id={adk_event.id}")
180+
181+
combined_text = "".join(text_parts)
182+
message_events = [
183+
TextMessageStartEvent(
184+
type=EventType.TEXT_MESSAGE_START,
185+
message_id=adk_event.id,
186+
role="assistant"
187+
),
188+
TextMessageContentEvent(
189+
type=EventType.TEXT_MESSAGE_CONTENT,
190+
message_id=adk_event.id,
191+
delta=combined_text
192+
),
193+
TextMessageEndEvent(
194+
type=EventType.TEXT_MESSAGE_END,
195+
message_id=adk_event.id
196+
)
197+
]
198+
for msg in message_events:
199+
yield msg
200+
176201
logger.info("⏭️ Skipping final response event (content already streamed)")
177202

178203
# If we're currently streaming, this final response means we should end the stream

typescript-sdk/integrations/adk-middleware/tests/test_event_translator_comprehensive.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def mock_adk_event_with_content(self):
5050
event.partial = False
5151
event.turn_complete = True
5252
event.is_final_response = False
53+
event.usage_metadata = {'tokens': 22}
5354
return event
5455

5556
@pytest.mark.asyncio
@@ -275,7 +276,26 @@ async def test_translate_text_content_final_response_no_streaming(self, translat
275276
events.append(event)
276277

277278
assert len(events) == 0 # No events
278-
279+
280+
@pytest.mark.asyncio
281+
async def test_translate_text_content_final_response_from_agent_callback(self, translator, mock_adk_event_with_content):
282+
"""Test final response when it was received from an agent callback function."""
283+
mock_adk_event_with_content.is_final_response = True
284+
mock_adk_event_with_content.usage_metadata = None
285+
286+
# Not streaming
287+
translator._is_streaming = False
288+
289+
events = []
290+
async for event in translator.translate(mock_adk_event_with_content, "thread_1", "run_1"):
291+
events.append(event)
292+
293+
assert len(events) == 3 # START, CONTENT , END
294+
assert isinstance(events[0], TextMessageStartEvent)
295+
assert isinstance(events[1], TextMessageContentEvent)
296+
assert events[1].delta == mock_adk_event_with_content.content.parts[0].text
297+
assert isinstance(events[2], TextMessageEndEvent)
298+
279299
@pytest.mark.asyncio
280300
async def test_translate_text_content_empty_text(self, translator, mock_adk_event):
281301
"""Test text content with empty text."""

typescript-sdk/integrations/adk-middleware/tests/test_text_events.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ag_ui.core import RunAgentInput, UserMessage
1111
from adk_middleware import ADKAgent
1212
from google.adk.agents import Agent
13+
from google.genai import types
1314

1415

1516
async def test_message_events():
@@ -87,6 +88,142 @@ async def test_message_events():
8788
return validate_message_event_pattern(start_count, end_count, content_count, text_message_events)
8889

8990

91+
async def test_message_events_from_before_agent_callback():
92+
"""Test that we get proper message events with correct START/CONTENT/END patterns,
93+
even if we return the message from before_agent_callback.
94+
"""
95+
96+
if not os.getenv("GOOGLE_API_KEY"):
97+
print("⚠️ GOOGLE_API_KEY not set - using mock test")
98+
return await test_with_mock()
99+
100+
print("🧪 Testing with real Google ADK agent...")
101+
102+
event_message = "This message was not generated."
103+
def return_predefined_message(callback_context):
104+
return types.Content(
105+
parts=[types.Part(text=event_message)],
106+
role="model" # Assign model role to the overriding response
107+
)
108+
109+
# Create real agent
110+
agent = Agent(
111+
name="test_agent",
112+
instruction="You are a helpful assistant. Keep responses brief.",
113+
before_agent_callback=return_predefined_message
114+
)
115+
116+
# Create middleware with direct agent embedding
117+
adk_agent = ADKAgent(
118+
adk_agent=agent,
119+
app_name="test_app",
120+
user_id="test_user",
121+
use_in_memory_services=True,
122+
)
123+
124+
# Test input
125+
test_input = RunAgentInput(
126+
thread_id="test_thread",
127+
run_id="test_run",
128+
messages=[
129+
UserMessage(
130+
id="msg_1",
131+
role="user",
132+
content="Say hello in exactly 3 words."
133+
)
134+
],
135+
state={},
136+
context=[],
137+
tools=[],
138+
forwarded_props={}
139+
)
140+
141+
print("🚀 Running test request...")
142+
143+
events = []
144+
text_message_events = []
145+
146+
try:
147+
async for event in adk_agent.run(test_input):
148+
events.append(event)
149+
event_type = str(event.type)
150+
print(f"📧 {event_type}")
151+
152+
# Track text message events specifically
153+
if "TEXT_MESSAGE" in event_type:
154+
text_message_events.append(event_type)
155+
156+
except Exception as e:
157+
print(f"❌ Error during test: {e}")
158+
return False
159+
160+
print(f"\n📊 Results:")
161+
print(f" Total events: {len(events)}")
162+
print(f" Text message events: {text_message_events}")
163+
164+
# Analyze message event patterns
165+
start_count = text_message_events.count("EventType.TEXT_MESSAGE_START")
166+
end_count = text_message_events.count("EventType.TEXT_MESSAGE_END")
167+
content_count = text_message_events.count("EventType.TEXT_MESSAGE_CONTENT")
168+
169+
print(f" START events: {start_count}")
170+
print(f" END events: {end_count}")
171+
print(f" CONTENT events: {content_count}")
172+
173+
pattern_is_valid = validate_message_event_pattern(start_count, end_count, content_count, text_message_events)
174+
if not pattern_is_valid:
175+
return False
176+
177+
expected_text_events = [
178+
{
179+
"type": "EventType.TEXT_MESSAGE_START",
180+
},
181+
{
182+
"type": "EventType.TEXT_MESSAGE_CONTENT",
183+
"delta": event_message
184+
},
185+
{
186+
"type": "EventType.TEXT_MESSAGE_END",
187+
}
188+
]
189+
return validate_message_events(events, expected_text_events)
190+
191+
192+
def validate_message_events(events, expected_events):
193+
"""Compare expected events by type and delta (if delta exists)."""
194+
# Filter events to only those specified in expected_events
195+
event_types_to_check = {expected["type"] for expected in expected_events}
196+
197+
filtered_events = []
198+
for event in events:
199+
event_type_str = f"EventType.{event.type.value}"
200+
if event_type_str in event_types_to_check:
201+
filtered_events.append(event)
202+
203+
if len(filtered_events) != len(expected_events):
204+
print(f"❌ Event count mismatch: expected {len(expected_events)}, got {len(filtered_events)}")
205+
return False
206+
207+
for i, (event, expected) in enumerate(zip(filtered_events, expected_events)):
208+
# Check event type
209+
event_type_str = f"EventType.{event.type.value}"
210+
if event_type_str != expected["type"]:
211+
print(f"❌ Event {i}: type mismatch - expected {expected['type']}, got {event_type_str}")
212+
return False
213+
214+
# Check delta if specified
215+
if "delta" in expected:
216+
if not hasattr(event, 'delta'):
217+
print(f"❌ Event {i}: expected delta field but event has none")
218+
return False
219+
if event.delta != expected["delta"]:
220+
print(f"❌ Event {i}: delta mismatch - expected '{expected['delta']}', got '{event.delta}'")
221+
return False
222+
223+
print("✅ All expected events validated successfully")
224+
return True
225+
226+
90227
def validate_message_event_pattern(start_count, end_count, content_count, text_message_events):
91228
"""Validate that message events follow proper patterns."""
92229

@@ -319,6 +456,13 @@ async def test_text_message_events():
319456
assert result, "Text message events test failed"
320457

321458

459+
@pytest.mark.asyncio
460+
async def test_text_message_events_from_before_agent_callback():
461+
"""Test that we get proper message events with correct START/CONTENT/END patterns."""
462+
result = await test_message_events_from_before_agent_callback()
463+
assert result, "Text message events for before_agent_callback test failed"
464+
465+
322466
@pytest.mark.asyncio
323467
async def test_message_event_edge_cases():
324468
"""Test edge cases for message event patterns."""

0 commit comments

Comments
 (0)