Skip to content

Commit a2e07d5

Browse files
yeesiancopybara-github
authored andcommitted
feat: Allow list of events to be passed to AdkApp.async_stream_query
PiperOrigin-RevId: 844834249
1 parent 47be102 commit a2e07d5

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,42 @@ def __init__(self, name: str, model: str):
7878
"streaming_mode": "sse",
7979
"max_llm_calls": 500,
8080
}
81+
_TEST_SESSION_EVENTS = [{"author": "user",
82+
"content": {"parts": [{"text": "What is the exchange rate from US dollars to "
83+
"Swedish krona on 2025-09-25?"}],
84+
"role": "user"},
85+
"id": "8967297909049524224",
86+
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
87+
"timestamp": 1765832134.629513},
88+
{"author": "currency_exchange_agent",
89+
"content": {"parts": [{"functionCall": {"args": {"currency_date": "2025-09-25",
90+
"currency_from": "USD",
91+
"currency_to": "SEK"},
92+
"id": "adk-136738ad-9e57-4cfb-8e23-b0f3e50a37d7",
93+
"name": "get_exchange_rate"}}],
94+
"role": "model"},
95+
"id": "3155402589927899136",
96+
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
97+
"timestamp": 1765832134.723713},
98+
{"author": "currency_exchange_agent",
99+
"content": {"parts": [{"functionResponse": {"id": "adk-136738ad-9e57-4cfb-8e23-b0f3e50a37d7",
100+
"name": "get_exchange_rate",
101+
"response": {"amount": 1,
102+
"base": "USD",
103+
"date": "2025-09-25",
104+
"rates": {"SEK": 9.4118}}}}],
105+
"role": "user"},
106+
"id": "1678221912150376448",
107+
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
108+
"timestamp": 1765832135.764961},
109+
{"author": "currency_exchange_agent",
110+
"content": {"parts": [{"text": "The exchange rate from US dollars to Swedish "
111+
"krona on 2025-09-25 is 1 USD to 9.4118 SEK."}],
112+
"role": "model"},
113+
"id": "2470855446567583744",
114+
"invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065",
115+
"timestamp": 1765832135.853299}
116+
]
81117

82118

83119
@pytest.fixture(scope="module")
@@ -392,6 +428,46 @@ async def test_async_stream_query(self):
392428
events.append(event)
393429
assert len(events) == 1
394430

431+
@pytest.mark.asyncio
432+
async def test_async_stream_query_with_empty_session_events(self):
433+
app = reasoning_engines.AdkApp(
434+
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
435+
)
436+
assert app._tmpl_attrs.get("runner") is None
437+
app.set_up()
438+
app._tmpl_attrs["runner"] = _MockRunner()
439+
events = []
440+
async for event in app.async_stream_query(
441+
user_id=_TEST_USER_ID,
442+
session_events=[],
443+
message="test message",
444+
):
445+
events.append(event)
446+
assert app._tmpl_attrs.get("session_service") is not None
447+
sessions = app.list_sessions(user_id=_TEST_USER_ID)
448+
assert len(sessions.sessions) == 1
449+
450+
@pytest.mark.asyncio
451+
async def test_async_stream_query_with_session_events(
452+
self,
453+
):
454+
app = reasoning_engines.AdkApp(
455+
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
456+
)
457+
assert app._tmpl_attrs.get("runner") is None
458+
app.set_up()
459+
app._tmpl_attrs["runner"] = _MockRunner()
460+
events = []
461+
async for event in app.async_stream_query(
462+
user_id=_TEST_USER_ID,
463+
session_events=_TEST_SESSION_EVENTS,
464+
message="on the day after that?",
465+
):
466+
events.append(event)
467+
assert app._tmpl_attrs.get("session_service") is not None
468+
sessions = app.list_sessions(user_id=_TEST_USER_ID)
469+
assert len(sessions.sessions) == 1
470+
395471
@pytest.mark.asyncio
396472
@mock.patch.dict(
397473
os.environ,

vertexai/agent_engines/templates/adk.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ async def async_stream_query(
932932
message: Union[str, Dict[str, Any]],
933933
user_id: str,
934934
session_id: Optional[str] = None,
935+
session_events: Optional[List[Dict[str, Any]]] = None,
935936
run_config: Optional[Dict[str, Any]] = None,
936937
**kwargs,
937938
) -> AsyncIterable[Dict[str, Any]]:
@@ -944,7 +945,11 @@ async def async_stream_query(
944945
Required. The ID of the user.
945946
session_id (str):
946947
Optional. The ID of the session. If not provided, a new
947-
session will be created for the user.
948+
session will be created for the user. If this is specified, then
949+
`session_events` will be ignored.
950+
session_events (Optional[List[Dict[str, Any]]]):
951+
Optional. The session events to use for the query. This will be
952+
used to initialize the session if `session_id` is not provided.
948953
run_config (Optional[Dict[str, Any]]):
949954
Optional. The run config to use for the query. If you want to
950955
pass in a `run_config` pydantic object, you can pass in a dict
@@ -955,6 +960,11 @@ async def async_stream_query(
955960
956961
Yields:
957962
Event dictionaries asynchronously.
963+
964+
Raises:
965+
TypeError: If message is not a string or a dictionary representing
966+
a Content object.
967+
ValueError: If both session_id and session_events are specified.
958968
"""
959969
from vertexai.agent_engines import _utils
960970
from google.genai import types
@@ -971,9 +981,25 @@ async def async_stream_query(
971981

972982
if not self._tmpl_attrs.get("runner"):
973983
self.set_up()
984+
if session_id and session_events:
985+
raise ValueError(
986+
"Only one of session_id and session_events should be specified."
987+
)
974988
if not session_id:
975989
session = await self.async_create_session(user_id=user_id)
976990
session_id = session.id
991+
if session_events is not None:
992+
# We allow for session_events to be an empty list.
993+
from google.adk.events.event import Event
994+
995+
session_service = self._tmpl_attrs.get("session_service")
996+
for event in session_events:
997+
if not isinstance(event, Event):
998+
event = Event.model_validate(event)
999+
await session_service.append_event(
1000+
session=session,
1001+
event=event,
1002+
)
9771003

9781004
run_config = _validate_run_config(run_config)
9791005
if run_config:

0 commit comments

Comments
 (0)