Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,16 @@ async def translate(


# Handle state changes
if hasattr(adk_event, 'actions') and adk_event.actions and hasattr(adk_event.actions, 'state_delta') and adk_event.actions.state_delta:
yield self._create_state_delta_event(
adk_event.actions.state_delta, thread_id, run_id
)
if hasattr(adk_event, 'actions') and adk_event.actions:
if hasattr(adk_event.actions, 'state_delta') and adk_event.actions.state_delta:
yield self._create_state_delta_event(
adk_event.actions.state_delta, thread_id, run_id
)

if hasattr(adk_event.actions, 'state_snapshot'):
state_snapshot = adk_event.actions.state_snapshot
if state_snapshot is not None:
yield self._create_state_snapshot_event(state_snapshot)


# Handle custom events or metadata
Expand Down Expand Up @@ -583,23 +589,9 @@ def _create_state_snapshot_event(
A StateSnapshotEvent
"""

FullSnapShot = {
"context": {
"conversation": [],
"user": {
"name": state_snapshot.get("user_name", ""),
"timezone": state_snapshot.get("timezone", "UTC")
},
"app": {
"version": state_snapshot.get("app_version", "unknown")
}
},
"state": state_snapshot.get("custom_state", {})
}

return StateSnapshotEvent(
type=EventType.STATE_SNAPSHOT,
snapshot=FullSnapShot
snapshot=state_snapshot
)

async def force_close_streaming_message(self) -> AsyncGenerator[BaseEvent, None]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ag_ui.core import (
EventType, TextMessageStartEvent, TextMessageContentEvent, TextMessageEndEvent,
ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent,
StateDeltaEvent, CustomEvent
StateDeltaEvent, StateSnapshotEvent, CustomEvent
)
from google.adk.events import Event as ADKEvent
from ag_ui_adk.event_translator import EventTranslator
Expand Down Expand Up @@ -185,6 +185,7 @@ async def test_translate_state_delta_event(self, translator, mock_adk_event):
# Mock event with state delta
mock_actions = MagicMock()
mock_actions.state_delta = {"key1": "value1", "key2": "value2"}
mock_actions.state_snapshot = None
mock_adk_event.actions = mock_actions

events = []
Expand All @@ -201,6 +202,55 @@ async def test_translate_state_delta_event(self, translator, mock_adk_event):
assert any(patch["path"] == "/key1" and patch["value"] == "value1" for patch in patches)
assert any(patch["path"] == "/key2" and patch["value"] == "value2" for patch in patches)

@pytest.mark.asyncio
async def test_translate_state_snapshot_event_passthrough(self, translator, mock_adk_event):
"""Test state snapshot events preserve the ADK payload."""

state_snapshot = {
"user_name": "Alice",
"timezone": "UTC",
"custom_state": {
"view": {"active_tab": "details"},
"progress": 0.75,
},
"extra_field": [1, 2, 3],
}

mock_adk_event.actions = SimpleNamespace(
state_delta=None,
state_snapshot=state_snapshot,
)

events = []
async for event in translator.translate(mock_adk_event, "thread_1", "run_1"):
events.append(event)

snapshot_events = [event for event in events if isinstance(event, StateSnapshotEvent)]
assert snapshot_events, "Expected a StateSnapshotEvent to be emitted"

snapshot_event = snapshot_events[0]
assert snapshot_event.type == EventType.STATE_SNAPSHOT
assert snapshot_event.snapshot == state_snapshot
assert snapshot_event.snapshot["user_name"] == "Alice"
assert snapshot_event.snapshot["custom_state"]["view"]["active_tab"] == "details"
assert "extra_field" in snapshot_event.snapshot

def test_create_state_snapshot_event_passthrough(self, translator):
"""Direct helper should forward the snapshot unchanged."""

state_snapshot = {
"user_name": "Bob",
"custom_state": {"step": 3},
"timezone": "PST",
}

event = translator._create_state_snapshot_event(state_snapshot)

assert isinstance(event, StateSnapshotEvent)
assert event.type == EventType.STATE_SNAPSHOT
assert event.snapshot == state_snapshot
assert set(event.snapshot.keys()) == {"user_name", "custom_state", "timezone"}

@pytest.mark.asyncio
async def test_translate_custom_event(self, translator, mock_adk_event):
"""Test custom event creation."""
Expand Down Expand Up @@ -857,14 +907,15 @@ async def test_complex_event_with_multiple_features(self, translator, mock_adk_e
async for event in translator.translate(mock_adk_event, "thread_1", "run_1"):
events.append(event)

# Should have text events, state delta, and custom event
assert len(events) == 5 # START, CONTENT, STATE_DELTA, CUSTOM , END
# Should have text events, state delta, state snapshot, and custom event
assert len(events) == 6 # START, CONTENT, STATE_DELTA, STATE_SNAPSHOT, CUSTOM, END

# Check event types
event_types = [type(event) for event in events]
assert TextMessageStartEvent in event_types
assert TextMessageContentEvent in event_types
assert StateDeltaEvent in event_types
assert StateSnapshotEvent in event_types
assert CustomEvent in event_types
assert TextMessageEndEvent in event_types

Expand Down
Loading