diff --git a/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py b/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py index 0a7c51c0a..922bd0279 100644 --- a/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py +++ b/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py @@ -137,6 +137,9 @@ def __init__(self): # Track streaming message state self._streaming_message_id: Optional[str] = None # Current streaming message ID self._is_streaming: bool = False # Whether we're currently streaming a message + self._current_stream_text: str = "" # Accumulates text for the active stream + self._last_streamed_text: Optional[str] = None # Snapshot of most recently streamed text + self._last_streamed_run_id: Optional[str] = None # Run identifier for the last streamed text self.long_running_tool_ids: List[str] = [] # Track the long running tool IDs async def translate( @@ -179,6 +182,7 @@ async def translate( return # Handle text content + # --- THIS IS THE RESTORED LINE --- if adk_event.content and hasattr(adk_event.content, 'parts') and adk_event.content.parts: async for event in self._translate_text_content( adk_event, thread_id, run_id @@ -253,26 +257,34 @@ async def _translate_text_content( Yields: Text message events (START, CONTENT, END) """ + + # Check for is_final_response *before* checking for text. + # An empty final response is a valid stream-closing signal. + is_final_response = False + if hasattr(adk_event, 'is_final_response') and callable(adk_event.is_final_response): + is_final_response = adk_event.is_final_response() + elif hasattr(adk_event, 'is_final_response'): + is_final_response = adk_event.is_final_response + # Extract text from all parts text_parts = [] + # The check for adk_event.content.parts happens in the main translate method for part in adk_event.content.parts: - if part.text: + if part.text: # Note: part.text == "" is False text_parts.append(part.text) - if not text_parts: + # If no text AND it's not a final response, we can safely skip. + # Otherwise, we must continue to process the final_response signal. + if not text_parts and not is_final_response: return - - + + combined_text = "".join(text_parts) + # Use proper ADK streaming detection (handle None values) is_partial = getattr(adk_event, 'partial', False) turn_complete = getattr(adk_event, 'turn_complete', False) - # Check if this is the final response (complete message - skip to avoid duplication) - is_final_response = False - if hasattr(adk_event, 'is_final_response') and callable(adk_event.is_final_response): - is_final_response = adk_event.is_final_response() - elif hasattr(adk_event, 'is_final_response'): - is_final_response = adk_event.is_final_response + # (is_final_response is already calculated above) # Handle None values: if a turn is complete or a final chunk arrives, end streaming has_finish_reason = bool(getattr(adk_event, 'finish_reason', None)) @@ -287,58 +299,83 @@ async def _translate_text_content( f"should_send_end={should_send_end}, currently_streaming={self._is_streaming}") if is_final_response: + # This is the final, complete message event. - # If a final text response wasn't streamed (not generated by an LLM) then deliver it in 3 events - if not self._is_streaming and not adk_event.usage_metadata and should_send_end: - logger.info(f"⏭️ Deliver non-llm response via message events " - f"event_id={adk_event.id}") + # Case 1: A stream is actively running. We must close it. + if self._is_streaming and self._streaming_message_id: + logger.info("⏭️ Final response event received. Closing active stream.") + + if self._current_stream_text: + # Save the complete streamed text for de-duplication + self._last_streamed_text = self._current_stream_text + self._last_streamed_run_id = run_id + self._current_stream_text = "" + + end_event = TextMessageEndEvent( + type=EventType.TEXT_MESSAGE_END, + message_id=self._streaming_message_id + ) + logger.info(f"📤 TEXT_MESSAGE_END (from final response): {end_event.model_dump_json()}") + yield end_event + + self._streaming_message_id = None + self._is_streaming = False + logger.info("🏁 Streaming completed via final response") + return # We are done. + + # Case 2: No stream is active. + # This event contains the *entire* message. + # We must send it, *unless* it's a duplicate of a stream that *just* finished. + + # Check for duplicates from a *previous* stream in this *same run*. + is_duplicate = ( + self._last_streamed_run_id == run_id and + self._last_streamed_text is not None and + combined_text == self._last_streamed_text + ) - combined_text = "".join(text_parts) + if is_duplicate: + logger.info( + "⏭️ Skipping final response event (duplicate content detected from finished stream)" + ) + else: + # Not a duplicate, or no previous stream. Send the full message. + logger.info( + f"⏩ Delivering complete non-streamed message or final content event_id={adk_event.id}" + ) message_events = [ TextMessageStartEvent( type=EventType.TEXT_MESSAGE_START, - message_id=adk_event.id, - role="assistant" + message_id=adk_event.id, # Use event ID for non-streamed + role="assistant", ), TextMessageContentEvent( type=EventType.TEXT_MESSAGE_CONTENT, message_id=adk_event.id, - delta=combined_text + delta=combined_text, ), TextMessageEndEvent( type=EventType.TEXT_MESSAGE_END, - message_id=adk_event.id - ) + message_id=adk_event.id, + ), ] for msg in message_events: yield msg - logger.info("⏭️ Skipping final response event (content already streamed)") - - # If we're currently streaming, this final response means we should end the stream - if self._is_streaming and self._streaming_message_id: - end_event = TextMessageEndEvent( - type=EventType.TEXT_MESSAGE_END, - message_id=self._streaming_message_id - ) - logger.info(f"📤 TEXT_MESSAGE_END (from final response): {end_event.model_dump_json()}") - yield end_event - - # Reset streaming state - self._streaming_message_id = None - self._is_streaming = False - logger.info("🏁 Streaming completed via final response") - + # Clean up state regardless, as this is the end of the line for text. + self._current_stream_text = "" + self._last_streamed_text = None + self._last_streamed_run_id = None return + - combined_text = "".join(text_parts) # Don't add newlines for streaming - - # Handle streaming logic + # Handle streaming logic (if not is_final_response) if not self._is_streaming: # Start of new message - emit START event self._streaming_message_id = str(uuid.uuid4()) self._is_streaming = True - + self._current_stream_text = "" + start_event = TextMessageStartEvent( type=EventType.TEXT_MESSAGE_START, message_id=self._streaming_message_id, @@ -349,6 +386,7 @@ async def _translate_text_content( # Always emit content (unless empty) if combined_text: + self._current_stream_text += combined_text content_event = TextMessageContentEvent( type=EventType.TEXT_MESSAGE_CONTENT, message_id=self._streaming_message_id, @@ -365,8 +403,12 @@ async def _translate_text_content( ) logger.info(f"📤 TEXT_MESSAGE_END: {end_event.model_dump_json()}") yield end_event - + # Reset streaming state + if self._current_stream_text: + self._last_streamed_text = self._current_stream_text + self._last_streamed_run_id = run_id + self._current_stream_text = "" self._streaming_message_id = None self._is_streaming = False logger.info("🏁 Streaming completed, state reset") @@ -541,9 +583,23 @@ def _create_state_snapshot_event( A StateSnapshotEvent """ + FullSnapShot = { + "context": { + "conversation": [], + "user": { + "name": state_snapshot.get("user_name", ""), + "timezone": state_snapshot.get("timezone", "UTC") + }, + "app": { + "version": state_snapshot.get("app_version", "unknown") + } + }, + "state": state_snapshot.get("custom_state", {}) + } + return StateSnapshotEvent( type=EventType.STATE_SNAPSHOT, - snapshot=state_snapshot + snapshot=FullSnapShot ) async def force_close_streaming_message(self) -> AsyncGenerator[BaseEvent, None]: @@ -556,15 +612,16 @@ async def force_close_streaming_message(self) -> AsyncGenerator[BaseEvent, None] """ if self._is_streaming and self._streaming_message_id: logger.warning(f"🚨 Force-closing unterminated streaming message: {self._streaming_message_id}") - + end_event = TextMessageEndEvent( type=EventType.TEXT_MESSAGE_END, message_id=self._streaming_message_id ) logger.info(f"📤 TEXT_MESSAGE_END (forced): {end_event.model_dump_json()}") yield end_event - + # Reset streaming state + self._current_stream_text = "" self._streaming_message_id = None self._is_streaming = False logger.info("🔄 Streaming state reset after force-close") @@ -578,5 +635,9 @@ def reset(self): self._active_tool_calls.clear() self._streaming_message_id = None self._is_streaming = False + self._current_stream_text = "" + self._last_streamed_text = None + self._last_streamed_run_id = None self.long_running_tool_ids.clear() logger.debug("Reset EventTranslator state (including streaming state)") + \ No newline at end of file diff --git a/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py b/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py index 4944f1dd8..8ef9d17ea 100644 --- a/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py +++ b/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py @@ -341,7 +341,10 @@ async def test_translate_text_content_final_response_no_streaming(self, translat async for event in translator.translate(mock_adk_event_with_content, "thread_1", "run_1"): events.append(event) - assert len(events) == 0 # No events + assert len(events) == 3 # START, CONTENT, END for first final payload + assert isinstance(events[0], TextMessageStartEvent) + assert isinstance(events[1], TextMessageContentEvent) + assert isinstance(events[2], TextMessageEndEvent) @pytest.mark.asyncio async def test_translate_text_content_final_response_from_agent_callback(self, translator, mock_adk_event_with_content): @@ -362,6 +365,123 @@ async def test_translate_text_content_final_response_from_agent_callback(self, t assert events[1].delta == mock_adk_event_with_content.content.parts[0].text assert isinstance(events[2], TextMessageEndEvent) + @pytest.mark.asyncio + async def test_translate_text_content_final_response_after_stream_duplicate_suppressed(self, translator): + """Final LLM payload matching streamed text should be suppressed.""" + + stream_event = MagicMock(spec=ADKEvent) + stream_event.id = "event-1" + stream_event.author = "model" + stream_event.content = MagicMock() + stream_part = MagicMock() + stream_part.text = "Hello" + stream_event.content.parts = [stream_part] + stream_event.partial = False + stream_event.turn_complete = False + stream_event.is_final_response = False + stream_event.usage_metadata = {"tokens": 1} + + events = [] + async for event in translator.translate(stream_event, "thread_1", "run_1"): + events.append(event) + + assert len(events) == 2 # START + CONTENT + assert isinstance(events[0], TextMessageStartEvent) + assert isinstance(events[1], TextMessageContentEvent) + + final_stream_event = MagicMock(spec=ADKEvent) + final_stream_event.id = "event-2" + final_stream_event.author = "model" + final_stream_event.content = MagicMock() + final_stream_part = MagicMock() + final_stream_part.text = "" + final_stream_event.content.parts = [final_stream_part] + final_stream_event.partial = False + final_stream_event.turn_complete = True + final_stream_event.is_final_response = True + final_stream_event.usage_metadata = {"tokens": 1} + + events = [] + async for event in translator.translate(final_stream_event, "thread_1", "run_1"): + events.append(event) + + assert len(events) == 1 # END only + assert isinstance(events[0], TextMessageEndEvent) + + final_payload = MagicMock(spec=ADKEvent) + final_payload.id = "event-3" + final_payload.author = "model" + final_payload.content = MagicMock() + final_payload_part = MagicMock() + final_payload_part.text = "Hello" + final_payload.content.parts = [final_payload_part] + final_payload.partial = False + final_payload.turn_complete = True + final_payload.is_final_response = True + final_payload.usage_metadata = {"tokens": 2} + + events = [] + async for event in translator.translate(final_payload, "thread_1", "run_1"): + events.append(event) + + assert events == [] # duplicate suppressed + + @pytest.mark.asyncio + async def test_translate_text_content_final_response_after_stream_new_content(self, translator): + """Final LLM payload with new content should be emitted.""" + + stream_event = MagicMock(spec=ADKEvent) + stream_event.id = "event-1" + stream_event.author = "model" + stream_event.content = MagicMock() + stream_part = MagicMock() + stream_part.text = "Hello" + stream_event.content.parts = [stream_part] + stream_event.partial = False + stream_event.turn_complete = False + stream_event.is_final_response = False + stream_event.usage_metadata = {"tokens": 1} + + async for _ in translator.translate(stream_event, "thread_1", "run_1"): + pass + + final_stream_event = MagicMock(spec=ADKEvent) + final_stream_event.id = "event-2" + final_stream_event.author = "model" + final_stream_event.content = MagicMock() + final_stream_part = MagicMock() + final_stream_part.text = "" + final_stream_event.content.parts = [final_stream_part] + final_stream_event.partial = False + final_stream_event.turn_complete = True + final_stream_event.is_final_response = True + final_stream_event.usage_metadata = {"tokens": 1} + + async for _ in translator.translate(final_stream_event, "thread_1", "run_1"): + pass + + final_payload = MagicMock(spec=ADKEvent) + final_payload.id = "event-3" + final_payload.author = "model" + final_payload.content = MagicMock() + final_payload_part = MagicMock() + final_payload_part.text = "Hello again" + final_payload.content.parts = [final_payload_part] + final_payload.partial = False + final_payload.turn_complete = True + final_payload.is_final_response = True + final_payload.usage_metadata = {"tokens": 2} + + events = [] + async for event in translator.translate(final_payload, "thread_1", "run_1"): + events.append(event) + + assert len(events) == 3 + assert isinstance(events[0], TextMessageStartEvent) + assert isinstance(events[1], TextMessageContentEvent) + assert events[1].delta == "Hello again" + assert isinstance(events[2], TextMessageEndEvent) + @pytest.mark.asyncio async def test_translate_text_content_empty_text(self, translator, mock_adk_event): """Test text content with empty text."""