Skip to content

Commit 0cfe889

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add support for app input in AdkApp template
PiperOrigin-RevId: 825039994
1 parent 9ae5f35 commit 0cfe889

File tree

1 file changed

+50
-10
lines changed
  • vertexai/agent_engines/templates

1 file changed

+50
-10
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@
3737
except (ImportError, AttributeError):
3838
Event = Any
3939

40+
try:
41+
from google.adk.apps import App
42+
43+
App = App
44+
except (ImportError, AttributeError):
45+
App = Any
46+
4047
try:
4148
from google.adk.agents import BaseAgent
4249

@@ -449,7 +456,8 @@ class AdkApp:
449456
def __init__(
450457
self,
451458
*,
452-
agent: "BaseAgent",
459+
agent: "BaseAgent" = None,
460+
app: "App" = None,
453461
app_name: Optional[str] = None,
454462
plugins: Optional[List["BasePlugin"]] = None,
455463
enable_tracing: Optional[bool] = None,
@@ -505,10 +513,28 @@ def __init__(
505513
)
506514
raise ValueError(msg)
507515

516+
if not agent and not app:
517+
raise ValueError("One of `agent` or `app` must be provided.")
518+
if app:
519+
if app_name:
520+
raise ValueError(
521+
"When app is provided, app_name should not be provided."
522+
)
523+
if agent:
524+
raise ValueError(
525+
"When app is provided, agent should not be provided."
526+
)
527+
if plugins:
528+
raise ValueError(
529+
"When app is provided, plugins should not be provided and"
530+
" should be provided in the app instead."
531+
)
532+
508533
self._tmpl_attrs: Dict[str, Any] = {
509534
"project": initializer.global_config.project,
510535
"location": initializer.global_config.location,
511536
"agent": agent,
537+
"app": app,
512538
"app_name": app_name,
513539
"plugins": plugins,
514540
"enable_tracing": enable_tracing,
@@ -625,6 +651,7 @@ def clone(self):
625651

626652
return self.__class__(
627653
agent=copy.deepcopy(self._tmpl_attrs.get("agent")),
654+
app=copy.deepcopy(self._tmpl_attrs.get("app")),
628655
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
629656
app_name=self._tmpl_attrs.get("app_name"),
630657
plugins=self._tmpl_attrs.get("plugins"),
@@ -775,18 +802,24 @@ def tracing_enabled() -> bool:
775802

776803
self._tmpl_attrs["runner"] = Runner(
777804
agent=self._tmpl_attrs.get("agent"),
805+
app=self._tmpl_attrs.get("app"),
778806
plugins=self._tmpl_attrs.get("plugins"),
779807
session_service=self._tmpl_attrs.get("session_service"),
780808
artifact_service=self._tmpl_attrs.get("artifact_service"),
781809
memory_service=self._tmpl_attrs.get("memory_service"),
782-
app_name=self._tmpl_attrs.get("app_name"),
810+
app_name=(
811+
None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("app_name")
812+
),
783813
)
784814
self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService()
785815
self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService()
786816
self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService()
787817
self._tmpl_attrs["in_memory_runner"] = Runner(
788-
app_name=self._tmpl_attrs.get("app_name"),
818+
app_name=(
819+
None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("app_name")
820+
),
789821
agent=self._tmpl_attrs.get("agent"),
822+
app=self._tmpl_attrs.get("app"),
790823
plugins=self._tmpl_attrs.get("plugins"),
791824
session_service=self._tmpl_attrs.get("in_memory_session_service"),
792825
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
@@ -968,12 +1001,13 @@ async def streaming_agent_run_with_events(self, request_json: str):
9681001
self.set_up()
9691002
session_service = self._tmpl_attrs.get("in_memory_session_service")
9701003
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1004+
app = self._tmpl_attrs.get("app")
9711005
# Try to get the session, if it doesn't exist, create a new one.
9721006
session = None
9731007
if request.session_id:
9741008
try:
9751009
session = await session_service.get_session(
976-
app_name=self._tmpl_attrs.get("app_name"),
1010+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
9771011
user_id=request.user_id,
9781012
session_id=request.session_id,
9791013
)
@@ -1006,8 +1040,9 @@ async def streaming_agent_run_with_events(self, request_json: str):
10061040
yield converted_event
10071041
finally:
10081042
if session and not request.session_id:
1043+
app = self._tmpl_attrs.get("app")
10091044
await session_service.delete_session(
1010-
app_name=self._tmpl_attrs.get("app_name"),
1045+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
10111046
user_id=request.user_id,
10121047
session_id=session.id,
10131048
)
@@ -1039,8 +1074,9 @@ async def async_get_session(
10391074
"""
10401075
if not self._tmpl_attrs.get("session_service"):
10411076
self.set_up()
1077+
app = self._tmpl_attrs.get("app")
10421078
session = await self._tmpl_attrs.get("session_service").get_session(
1043-
app_name=self._tmpl_attrs.get("app_name"),
1079+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
10441080
user_id=user_id,
10451081
session_id=session_id,
10461082
**kwargs,
@@ -1116,8 +1152,9 @@ async def async_list_sessions(self, *, user_id: str, **kwargs):
11161152
"""
11171153
if not self._tmpl_attrs.get("session_service"):
11181154
self.set_up()
1155+
app = self._tmpl_attrs.get("app")
11191156
return await self._tmpl_attrs.get("session_service").list_sessions(
1120-
app_name=self._tmpl_attrs.get("app_name"),
1157+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
11211158
user_id=user_id,
11221159
**kwargs,
11231160
)
@@ -1188,8 +1225,9 @@ async def async_create_session(
11881225
"""
11891226
if not self._tmpl_attrs.get("session_service"):
11901227
self.set_up()
1228+
app = self._tmpl_attrs.get("app")
11911229
session = await self._tmpl_attrs.get("session_service").create_session(
1192-
app_name=self._tmpl_attrs.get("app_name"),
1230+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
11931231
user_id=user_id,
11941232
session_id=session_id,
11951233
state=state,
@@ -1269,8 +1307,9 @@ async def async_delete_session(
12691307
"""
12701308
if not self._tmpl_attrs.get("session_service"):
12711309
self.set_up()
1310+
app = self._tmpl_attrs.get("app")
12721311
await self._tmpl_attrs.get("session_service").delete_session(
1273-
app_name=self._tmpl_attrs.get("app_name"),
1312+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
12741313
user_id=user_id,
12751314
session_id=session_id,
12761315
**kwargs,
@@ -1359,8 +1398,9 @@ async def async_search_memory(self, *, user_id: str, query: str):
13591398
"""
13601399
if not self._tmpl_attrs.get("memory_service"):
13611400
self.set_up()
1401+
app = self._tmpl_attrs.get("app")
13621402
return await self._tmpl_attrs.get("memory_service").search_memory(
1363-
app_name=self._tmpl_attrs.get("app_name"),
1403+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
13641404
user_id=user_id,
13651405
query=query,
13661406
)

0 commit comments

Comments
 (0)