Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 105 additions & 44 deletions integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ def __init__(self):
# Track streaming message state
self._streaming_message_id: Optional[str] = None # Current streaming message ID
self._is_streaming: bool = False # Whether we're currently streaming a message
self._current_stream_text: str = "" # Accumulates text for the active stream
self._last_streamed_text: Optional[str] = None # Snapshot of most recently streamed text
self._last_streamed_run_id: Optional[str] = None # Run identifier for the last streamed text
self.long_running_tool_ids: List[str] = [] # Track the long running tool IDs

async def translate(
Expand Down Expand Up @@ -179,6 +182,7 @@ async def translate(
return

# Handle text content
# --- THIS IS THE RESTORED LINE ---
if adk_event.content and hasattr(adk_event.content, 'parts') and adk_event.content.parts:
async for event in self._translate_text_content(
adk_event, thread_id, run_id
Expand Down Expand Up @@ -253,26 +257,34 @@ async def _translate_text_content(
Yields:
Text message events (START, CONTENT, END)
"""

# Check for is_final_response *before* checking for text.
# An empty final response is a valid stream-closing signal.
is_final_response = False
if hasattr(adk_event, 'is_final_response') and callable(adk_event.is_final_response):
is_final_response = adk_event.is_final_response()
elif hasattr(adk_event, 'is_final_response'):
is_final_response = adk_event.is_final_response

# Extract text from all parts
text_parts = []
# The check for adk_event.content.parts happens in the main translate method
for part in adk_event.content.parts:
if part.text:
if part.text: # Note: part.text == "" is False
text_parts.append(part.text)

if not text_parts:
# If no text AND it's not a final response, we can safely skip.
# Otherwise, we must continue to process the final_response signal.
if not text_parts and not is_final_response:
return



combined_text = "".join(text_parts)

# Use proper ADK streaming detection (handle None values)
is_partial = getattr(adk_event, 'partial', False)
turn_complete = getattr(adk_event, 'turn_complete', False)

# Check if this is the final response (complete message - skip to avoid duplication)
is_final_response = False
if hasattr(adk_event, 'is_final_response') and callable(adk_event.is_final_response):
is_final_response = adk_event.is_final_response()
elif hasattr(adk_event, 'is_final_response'):
is_final_response = adk_event.is_final_response
# (is_final_response is already calculated above)

# Handle None values: if a turn is complete or a final chunk arrives, end streaming
has_finish_reason = bool(getattr(adk_event, 'finish_reason', None))
Expand All @@ -287,58 +299,83 @@ async def _translate_text_content(
f"should_send_end={should_send_end}, currently_streaming={self._is_streaming}")

if is_final_response:
# This is the final, complete message event.

# If a final text response wasn't streamed (not generated by an LLM) then deliver it in 3 events
if not self._is_streaming and not adk_event.usage_metadata and should_send_end:
logger.info(f"⏭️ Deliver non-llm response via message events "
f"event_id={adk_event.id}")
# Case 1: A stream is actively running. We must close it.
if self._is_streaming and self._streaming_message_id:
logger.info("⏭️ Final response event received. Closing active stream.")

if self._current_stream_text:
# Save the complete streamed text for de-duplication
self._last_streamed_text = self._current_stream_text
self._last_streamed_run_id = run_id
self._current_stream_text = ""

end_event = TextMessageEndEvent(
type=EventType.TEXT_MESSAGE_END,
message_id=self._streaming_message_id
)
logger.info(f"📤 TEXT_MESSAGE_END (from final response): {end_event.model_dump_json()}")
yield end_event

self._streaming_message_id = None
self._is_streaming = False
logger.info("🏁 Streaming completed via final response")
return # We are done.

# Case 2: No stream is active.
# This event contains the *entire* message.
# We must send it, *unless* it's a duplicate of a stream that *just* finished.

# Check for duplicates from a *previous* stream in this *same run*.
is_duplicate = (
self._last_streamed_run_id == run_id and
self._last_streamed_text is not None and
combined_text == self._last_streamed_text
)

combined_text = "".join(text_parts)
if is_duplicate:
logger.info(
"⏭️ Skipping final response event (duplicate content detected from finished stream)"
)
else:
# Not a duplicate, or no previous stream. Send the full message.
logger.info(
f"⏩ Delivering complete non-streamed message or final content event_id={adk_event.id}"
)
message_events = [
TextMessageStartEvent(
type=EventType.TEXT_MESSAGE_START,
message_id=adk_event.id,
role="assistant"
message_id=adk_event.id, # Use event ID for non-streamed
role="assistant",
),
TextMessageContentEvent(
type=EventType.TEXT_MESSAGE_CONTENT,
message_id=adk_event.id,
delta=combined_text
delta=combined_text,
),
TextMessageEndEvent(
type=EventType.TEXT_MESSAGE_END,
message_id=adk_event.id
)
message_id=adk_event.id,
),
]
for msg in message_events:
yield msg

logger.info("⏭️ Skipping final response event (content already streamed)")

# If we're currently streaming, this final response means we should end the stream
if self._is_streaming and self._streaming_message_id:
end_event = TextMessageEndEvent(
type=EventType.TEXT_MESSAGE_END,
message_id=self._streaming_message_id
)
logger.info(f"📤 TEXT_MESSAGE_END (from final response): {end_event.model_dump_json()}")
yield end_event

# Reset streaming state
self._streaming_message_id = None
self._is_streaming = False
logger.info("🏁 Streaming completed via final response")

# Clean up state regardless, as this is the end of the line for text.
self._current_stream_text = ""
self._last_streamed_text = None
self._last_streamed_run_id = None
return


combined_text = "".join(text_parts) # Don't add newlines for streaming

# Handle streaming logic
# Handle streaming logic (if not is_final_response)
if not self._is_streaming:
# Start of new message - emit START event
self._streaming_message_id = str(uuid.uuid4())
self._is_streaming = True

self._current_stream_text = ""

start_event = TextMessageStartEvent(
type=EventType.TEXT_MESSAGE_START,
message_id=self._streaming_message_id,
Expand All @@ -349,6 +386,7 @@ async def _translate_text_content(

# Always emit content (unless empty)
if combined_text:
self._current_stream_text += combined_text
content_event = TextMessageContentEvent(
type=EventType.TEXT_MESSAGE_CONTENT,
message_id=self._streaming_message_id,
Expand All @@ -365,8 +403,12 @@ async def _translate_text_content(
)
logger.info(f"📤 TEXT_MESSAGE_END: {end_event.model_dump_json()}")
yield end_event

# Reset streaming state
if self._current_stream_text:
self._last_streamed_text = self._current_stream_text
self._last_streamed_run_id = run_id
self._current_stream_text = ""
self._streaming_message_id = None
self._is_streaming = False
logger.info("🏁 Streaming completed, state reset")
Expand Down Expand Up @@ -541,9 +583,23 @@ def _create_state_snapshot_event(
A StateSnapshotEvent
"""

FullSnapShot = {
"context": {
"conversation": [],
"user": {
"name": state_snapshot.get("user_name", ""),
"timezone": state_snapshot.get("timezone", "UTC")
},
"app": {
"version": state_snapshot.get("app_version", "unknown")
}
},
"state": state_snapshot.get("custom_state", {})
}

return StateSnapshotEvent(
type=EventType.STATE_SNAPSHOT,
snapshot=state_snapshot
snapshot=FullSnapShot
)

async def force_close_streaming_message(self) -> AsyncGenerator[BaseEvent, None]:
Expand All @@ -556,15 +612,16 @@ async def force_close_streaming_message(self) -> AsyncGenerator[BaseEvent, None]
"""
if self._is_streaming and self._streaming_message_id:
logger.warning(f"🚨 Force-closing unterminated streaming message: {self._streaming_message_id}")

end_event = TextMessageEndEvent(
type=EventType.TEXT_MESSAGE_END,
message_id=self._streaming_message_id
)
logger.info(f"📤 TEXT_MESSAGE_END (forced): {end_event.model_dump_json()}")
yield end_event

# Reset streaming state
self._current_stream_text = ""
self._streaming_message_id = None
self._is_streaming = False
logger.info("🔄 Streaming state reset after force-close")
Expand All @@ -578,5 +635,9 @@ def reset(self):
self._active_tool_calls.clear()
self._streaming_message_id = None
self._is_streaming = False
self._current_stream_text = ""
self._last_streamed_text = None
self._last_streamed_run_id = None
self.long_running_tool_ids.clear()
logger.debug("Reset EventTranslator state (including streaming state)")

Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,10 @@ async def test_translate_text_content_final_response_no_streaming(self, translat
async for event in translator.translate(mock_adk_event_with_content, "thread_1", "run_1"):
events.append(event)

assert len(events) == 0 # No events
assert len(events) == 3 # START, CONTENT, END for first final payload
assert isinstance(events[0], TextMessageStartEvent)
assert isinstance(events[1], TextMessageContentEvent)
assert isinstance(events[2], TextMessageEndEvent)

@pytest.mark.asyncio
async def test_translate_text_content_final_response_from_agent_callback(self, translator, mock_adk_event_with_content):
Expand All @@ -362,6 +365,123 @@ async def test_translate_text_content_final_response_from_agent_callback(self, t
assert events[1].delta == mock_adk_event_with_content.content.parts[0].text
assert isinstance(events[2], TextMessageEndEvent)

@pytest.mark.asyncio
async def test_translate_text_content_final_response_after_stream_duplicate_suppressed(self, translator):
    """A final LLM payload whose text matches the already-streamed text must be suppressed."""

    def make_event(event_id, text, *, turn_complete, final, tokens):
        # Build a mock ADK event carrying exactly one text part.
        evt = MagicMock(spec=ADKEvent)
        evt.id = event_id
        evt.author = "model"
        evt.content = MagicMock()
        part = MagicMock()
        part.text = text
        evt.content.parts = [part]
        evt.partial = False
        evt.turn_complete = turn_complete
        evt.is_final_response = final
        evt.usage_metadata = {"tokens": tokens}
        return evt

    async def collect(evt):
        # Drain the translator's async generator into a plain list.
        produced = []
        async for item in translator.translate(evt, "thread_1", "run_1"):
            produced.append(item)
        return produced

    # First chunk opens the stream: expect START + CONTENT.
    opening = await collect(
        make_event("event-1", "Hello", turn_complete=False, final=False, tokens=1)
    )
    assert len(opening) == 2  # START + CONTENT
    assert isinstance(opening[0], TextMessageStartEvent)
    assert isinstance(opening[1], TextMessageContentEvent)

    # Empty final chunk closes the active stream: expect END only.
    closing = await collect(
        make_event("event-2", "", turn_complete=True, final=True, tokens=1)
    )
    assert len(closing) == 1  # END only
    assert isinstance(closing[0], TextMessageEndEvent)

    # A final payload repeating the streamed text must be dropped entirely.
    duplicate = await collect(
        make_event("event-3", "Hello", turn_complete=True, final=True, tokens=2)
    )
    assert duplicate == []  # duplicate suppressed

@pytest.mark.asyncio
async def test_translate_text_content_final_response_after_stream_new_content(self, translator):
    """A final LLM payload carrying text that was never streamed must be emitted in full."""

    def build(event_id, text, *, turn_complete, final, tokens):
        # Assemble a mock ADK event wrapping a single text part.
        evt = MagicMock(spec=ADKEvent)
        evt.id = event_id
        evt.author = "model"
        evt.content = MagicMock()
        part = MagicMock()
        part.text = text
        evt.content.parts = [part]
        evt.partial = False
        evt.turn_complete = turn_complete
        evt.is_final_response = final
        evt.usage_metadata = {"tokens": tokens}
        return evt

    # Stream one chunk, then close the stream with an empty final event;
    # the emitted events themselves are not asserted here.
    opening = build("event-1", "Hello", turn_complete=False, final=False, tokens=1)
    async for _ in translator.translate(opening, "thread_1", "run_1"):
        pass

    closing = build("event-2", "", turn_complete=True, final=True, tokens=1)
    async for _ in translator.translate(closing, "thread_1", "run_1"):
        pass

    # New text in the final payload is not a duplicate of the streamed text,
    # so the full START/CONTENT/END triplet must be produced.
    fresh_payload = build("event-3", "Hello again", turn_complete=True, final=True, tokens=2)
    received = []
    async for item in translator.translate(fresh_payload, "thread_1", "run_1"):
        received.append(item)

    assert len(received) == 3
    assert isinstance(received[0], TextMessageStartEvent)
    assert isinstance(received[1], TextMessageContentEvent)
    assert received[1].delta == "Hello again"
    assert isinstance(received[2], TextMessageEndEvent)

@pytest.mark.asyncio
async def test_translate_text_content_empty_text(self, translator, mock_adk_event):
"""Test text content with empty text."""
Expand Down
Loading