Skip to content

Commit 5b90d3d

Browse files
Addressing regression in ADK 0.3.2 (#626)
* Add state snapshot passthrough tests * Update complex event translator test for snapshot
1 parent 3c66dcc commit 5b90d3d

File tree

2 files changed

+65
-22
lines changed

2 files changed

+65
-22
lines changed

integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,16 @@ async def translate(
223223

224224

225225
# Handle state changes
226-
if hasattr(adk_event, 'actions') and adk_event.actions and hasattr(adk_event.actions, 'state_delta') and adk_event.actions.state_delta:
227-
yield self._create_state_delta_event(
228-
adk_event.actions.state_delta, thread_id, run_id
229-
)
226+
if hasattr(adk_event, 'actions') and adk_event.actions:
227+
if hasattr(adk_event.actions, 'state_delta') and adk_event.actions.state_delta:
228+
yield self._create_state_delta_event(
229+
adk_event.actions.state_delta, thread_id, run_id
230+
)
231+
232+
if hasattr(adk_event.actions, 'state_snapshot'):
233+
state_snapshot = adk_event.actions.state_snapshot
234+
if state_snapshot is not None:
235+
yield self._create_state_snapshot_event(state_snapshot)
230236

231237

232238
# Handle custom events or metadata
@@ -583,23 +589,9 @@ def _create_state_snapshot_event(
583589
A StateSnapshotEvent
584590
"""
585591

586-
FullSnapShot = {
587-
"context": {
588-
"conversation": [],
589-
"user": {
590-
"name": state_snapshot.get("user_name", ""),
591-
"timezone": state_snapshot.get("timezone", "UTC")
592-
},
593-
"app": {
594-
"version": state_snapshot.get("app_version", "unknown")
595-
}
596-
},
597-
"state": state_snapshot.get("custom_state", {})
598-
}
599-
600592
return StateSnapshotEvent(
601593
type=EventType.STATE_SNAPSHOT,
602-
snapshot=FullSnapShot
594+
snapshot=state_snapshot
603595
)
604596

605597
async def force_close_streaming_message(self) -> AsyncGenerator[BaseEvent, None]:

integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ag_ui.core import (
1414
EventType, TextMessageStartEvent, TextMessageContentEvent, TextMessageEndEvent,
1515
ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent,
16-
StateDeltaEvent, CustomEvent
16+
StateDeltaEvent, StateSnapshotEvent, CustomEvent
1717
)
1818
from google.adk.events import Event as ADKEvent
1919
from ag_ui_adk.event_translator import EventTranslator
@@ -185,6 +185,7 @@ async def test_translate_state_delta_event(self, translator, mock_adk_event):
185185
# Mock event with state delta
186186
mock_actions = MagicMock()
187187
mock_actions.state_delta = {"key1": "value1", "key2": "value2"}
188+
mock_actions.state_snapshot = None
188189
mock_adk_event.actions = mock_actions
189190

190191
events = []
@@ -201,6 +202,55 @@ async def test_translate_state_delta_event(self, translator, mock_adk_event):
201202
assert any(patch["path"] == "/key1" and patch["value"] == "value1" for patch in patches)
202203
assert any(patch["path"] == "/key2" and patch["value"] == "value2" for patch in patches)
203204

205+
@pytest.mark.asyncio
206+
async def test_translate_state_snapshot_event_passthrough(self, translator, mock_adk_event):
207+
"""Test state snapshot events preserve the ADK payload."""
208+
209+
state_snapshot = {
210+
"user_name": "Alice",
211+
"timezone": "UTC",
212+
"custom_state": {
213+
"view": {"active_tab": "details"},
214+
"progress": 0.75,
215+
},
216+
"extra_field": [1, 2, 3],
217+
}
218+
219+
mock_adk_event.actions = SimpleNamespace(
220+
state_delta=None,
221+
state_snapshot=state_snapshot,
222+
)
223+
224+
events = []
225+
async for event in translator.translate(mock_adk_event, "thread_1", "run_1"):
226+
events.append(event)
227+
228+
snapshot_events = [event for event in events if isinstance(event, StateSnapshotEvent)]
229+
assert snapshot_events, "Expected a StateSnapshotEvent to be emitted"
230+
231+
snapshot_event = snapshot_events[0]
232+
assert snapshot_event.type == EventType.STATE_SNAPSHOT
233+
assert snapshot_event.snapshot == state_snapshot
234+
assert snapshot_event.snapshot["user_name"] == "Alice"
235+
assert snapshot_event.snapshot["custom_state"]["view"]["active_tab"] == "details"
236+
assert "extra_field" in snapshot_event.snapshot
237+
238+
def test_create_state_snapshot_event_passthrough(self, translator):
239+
"""Direct helper should forward the snapshot unchanged."""
240+
241+
state_snapshot = {
242+
"user_name": "Bob",
243+
"custom_state": {"step": 3},
244+
"timezone": "PST",
245+
}
246+
247+
event = translator._create_state_snapshot_event(state_snapshot)
248+
249+
assert isinstance(event, StateSnapshotEvent)
250+
assert event.type == EventType.STATE_SNAPSHOT
251+
assert event.snapshot == state_snapshot
252+
assert set(event.snapshot.keys()) == {"user_name", "custom_state", "timezone"}
253+
204254
@pytest.mark.asyncio
205255
async def test_translate_custom_event(self, translator, mock_adk_event):
206256
"""Test custom event creation."""
@@ -857,14 +907,15 @@ async def test_complex_event_with_multiple_features(self, translator, mock_adk_e
857907
async for event in translator.translate(mock_adk_event, "thread_1", "run_1"):
858908
events.append(event)
859909

860-
# Should have text events, state delta, and custom event
861-
assert len(events) == 5 # START, CONTENT, STATE_DELTA, CUSTOM , END
910+
# Should have text events, state delta, state snapshot, and custom event
911+
assert len(events) == 6 # START, CONTENT, STATE_DELTA, STATE_SNAPSHOT, CUSTOM, END
862912

863913
# Check event types
864914
event_types = [type(event) for event in events]
865915
assert TextMessageStartEvent in event_types
866916
assert TextMessageContentEvent in event_types
867917
assert StateDeltaEvent in event_types
918+
assert StateSnapshotEvent in event_types
868919
assert CustomEvent in event_types
869920
assert TextMessageEndEvent in event_types
870921

0 commit comments

Comments
 (0)