Skip to content

Commit fbcb247

Browse files
Tongzhou-Jiangcopybara-github
authored andcommitted
feat: Alow VertexAiSession for streaming_agent_run_with_events
PiperOrigin-RevId: 824600367
1 parent 09bf9a9 commit fbcb247

File tree

2 files changed

+58
-37
lines changed
  • vertexai
    • agent_engines/templates
    • preview/reasoning_engines/templates

2 files changed

+58
-37
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,6 @@ async def _init_session(
550550
):
551551
"""Initializes the session, and returns the session id."""
552552
from google.adk.events.event import Event
553-
import random
554553

555554
session_state = None
556555
if request.authorizations:
@@ -559,14 +558,9 @@ async def _init_session(
559558
auth = _Authorization(**auth)
560559
session_state[f"temp:{auth_id}"] = auth.access_token
561560

562-
if request.session_id:
563-
session_id = request.session_id
564-
else:
565-
session_id = f"temp_session_{random.randbytes(8).hex()}"
566561
session = await session_service.create_session(
567562
app_name=self._tmpl_attrs.get("app_name"),
568563
user_id=request.user_id,
569-
session_id=session_id,
570564
state=session_state,
571565
)
572566
if not session:
@@ -1012,43 +1006,60 @@ async def streaming_agent_run_with_events(self, request_json: str):
10121006

10131007
import json
10141008
from google.genai import types
1009+
from google.genai.errors import ClientError
10151010

10161011
request = _StreamRunRequest(**json.loads(request_json))
10171012
if not self._tmpl_attrs.get("in_memory_runner"):
10181013
self.set_up()
1014+
if not self._tmpl_attrs.get("runner"):
1015+
self.set_up()
10191016
# Prepare the in-memory session.
10201017
if not self._tmpl_attrs.get("in_memory_artifact_service"):
10211018
self.set_up()
1019+
if not self._tmpl_attrs.get("artifact_service"):
1020+
self.set_up()
10221021
if not self._tmpl_attrs.get("in_memory_session_service"):
10231022
self.set_up()
1024-
session_service = self._tmpl_attrs.get("in_memory_session_service")
1025-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1023+
if not self._tmpl_attrs.get("session_service"):
1024+
self.set_up()
1025+
session_service = self._tmpl_attrs.get("session_service")
1026+
artifact_service = self._tmpl_attrs.get("artifact_service")
10261027
app = self._tmpl_attrs.get("app")
1028+
runner = self._tmpl_attrs.get("runner")
10271029
# Try to get the session, if it doesn't exist, create a new one.
1028-
session = None
10291030
if request.session_id:
10301031
try:
10311032
session = await session_service.get_session(
10321033
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
10331034
user_id=request.user_id,
10341035
session_id=request.session_id,
10351036
)
1036-
except RuntimeError:
1037-
pass
1038-
if not session:
1039-
# Fall back to create session if the session is not found.
1040-
session = await self._init_session(
1041-
session_service=session_service,
1042-
artifact_service=artifact_service,
1043-
request=request,
1037+
except ClientError:
1038+
# Fall back to create session if the session is not found.
1039+
# Specifying session_id on creation is not supported,
1040+
# so session id will be regenerated.
1041+
session = await self._init_session(
1042+
session_service=session_service,
1043+
artifact_service=artifact_service,
1044+
request=request,
1045+
)
1046+
else:
1047+
# Not providing a session ID will create a new in-memory session.
1048+
session_service = self._tmpl_attrs.get("in_memory_session_service")
1049+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1050+
runner = self._tmpl_attrs.get("in_memory_runner")
1051+
session = await session_service.create_session(
1052+
app_name=self._tmpl_attrs.get("app_name"),
1053+
user_id=request.user_id,
1054+
session_id=request.session_id,
10441055
)
10451056
if not session:
10461057
raise RuntimeError("Session initialization failed.")
10471058

10481059
# Run the agent
10491060
message_for_agent = types.Content(**request.message)
10501061
try:
1051-
async for event in self._tmpl_attrs.get("in_memory_runner").run_async(
1062+
async for event in runner.run_async(
10521063
user_id=request.user_id,
10531064
session_id=session.id,
10541065
new_message=message_for_agent,

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,6 @@ async def _init_session(
502502
):
503503
"""Initializes the session, and returns the session id."""
504504
from google.adk.events.event import Event
505-
import random
506505

507506
session_state = None
508507
if request.authorizations:
@@ -511,14 +510,9 @@ async def _init_session(
511510
auth = _Authorization(**auth)
512511
session_state[f"temp:{auth_id}"] = auth.access_token
513512

514-
if request.session_id:
515-
session_id = request.session_id
516-
else:
517-
session_id = f"temp_session_{random.randbytes(8).hex()}"
518513
session = await session_service.create_session(
519514
app_name=self._tmpl_attrs.get("app_name"),
520515
user_id=request.user_id,
521-
session_id=session_id,
522516
state=session_state,
523517
)
524518
if not session:
@@ -881,44 +875,60 @@ async def async_stream_query(
881875
def streaming_agent_run_with_events(self, request_json: str):
882876
import json
883877
from google.genai import types
878+
from google.genai.errors import ClientError
884879

885880
event_queue = queue.Queue(maxsize=1)
886881

887882
async def _invoke_agent_async():
888883
request = _StreamRunRequest(**json.loads(request_json))
889884
if not self._tmpl_attrs.get("in_memory_runner"):
890885
self.set_up()
886+
if not self._tmpl_attrs.get("runner"):
887+
self.set_up()
891888
# Prepare the in-memory session.
892889
if not self._tmpl_attrs.get("in_memory_artifact_service"):
893890
self.set_up()
891+
if not self._tmpl_attrs.get("artifact_service"):
892+
self.set_up()
894893
if not self._tmpl_attrs.get("in_memory_session_service"):
895894
self.set_up()
896-
session_service = self._tmpl_attrs.get("in_memory_session_service")
897-
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
898-
# Try to get the session, if it doesn't exist, create a new one.
899-
session = None
895+
if not self._tmpl_attrs.get("session_service"):
896+
self.set_up()
897+
session_service = self._tmpl_attrs.get("session_service")
898+
artifact_service = self._tmpl_attrs.get("artifact_service")
899+
runner = self._tmpl_attrs.get("runner")
900900
if request.session_id:
901901
try:
902902
session = await session_service.get_session(
903903
app_name=self._tmpl_attrs.get("app_name"),
904904
user_id=request.user_id,
905905
session_id=request.session_id,
906906
)
907-
except RuntimeError:
908-
pass
909-
if not session:
910-
# Fall back to create session if the session is not found.
911-
session = await self._init_session(
912-
session_service=session_service,
913-
artifact_service=artifact_service,
914-
request=request,
907+
except ClientError:
908+
# Fall back to create session if the session is not found.
909+
# Specifying session_id on creation is not supported,
910+
# so session id will be regenerated.
911+
session = await self._init_session(
912+
session_service=session_service,
913+
artifact_service=artifact_service,
914+
request=request,
915+
)
916+
else:
917+
# Not providing a session ID will create a new in-memory session.
918+
session_service = self._tmpl_attrs.get("in_memory_session_service")
919+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
920+
runner = self._tmpl_attrs.get("in_memory_runner")
921+
session = await session_service.create_session(
922+
app_name=self._tmpl_attrs.get("app_name"),
923+
user_id=request.user_id,
924+
session_id=request.session_id,
915925
)
916926
if not session:
917927
raise RuntimeError("Session initialization failed.")
918928
# Run the agent.
919929
message_for_agent = types.Content(**request.message)
920930
try:
921-
for event in self._tmpl_attrs.get("in_memory_runner").run(
931+
for event in runner.run(
922932
user_id=request.user_id,
923933
session_id=session.id,
924934
new_message=message_for_agent,

0 commit comments

Comments
 (0)