Skip to content

Commit 94a1052

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add support for app input in AdkApp template
PiperOrigin-RevId: 825039994
1 parent 9ef8d05 commit 94a1052

File tree

1 file changed

+40
-3
lines changed
  • vertexai/agent_engines/templates

1 file changed

+40
-3
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@
3737
except (ImportError, AttributeError):
3838
Event = Any
3939

40+
try:
41+
from google.adk.apps import App
42+
43+
App = App
44+
except (ImportError, AttributeError):
45+
App = Any
46+
4047
try:
4148
from google.adk.agents import BaseAgent
4249

@@ -449,7 +456,8 @@ class AdkApp:
449456
def __init__(
450457
self,
451458
*,
452-
agent: "BaseAgent",
459+
agent: "BaseAgent" = None,
460+
app: "App" = None,
453461
app_name: Optional[str] = None,
454462
plugins: Optional[List["BasePlugin"]] = None,
455463
enable_tracing: Optional[bool] = None,
@@ -505,10 +513,28 @@ def __init__(
505513
)
506514
raise ValueError(msg)
507515

516+
if not agent and not app:
517+
raise ValueError("One of `agent` or `app` must be provided.")
518+
if app:
519+
if app_name:
520+
raise ValueError(
521+
"When app is provided, app_name should not be provided."
522+
)
523+
if agent:
524+
raise ValueError(
525+
"When app is provided, agent should not be provided."
526+
)
527+
if plugins:
528+
raise ValueError(
529+
"When app is provided, plugins should not be provided and"
530+
" should be provided in the app instead."
531+
)
532+
508533
self._tmpl_attrs: Dict[str, Any] = {
509534
"project": initializer.global_config.project,
510535
"location": initializer.global_config.location,
511536
"agent": agent,
537+
"app": app,
512538
"app_name": app_name,
513539
"plugins": plugins,
514540
"enable_tracing": enable_tracing,
@@ -625,6 +651,7 @@ def clone(self):
625651

626652
return self.__class__(
627653
agent=copy.deepcopy(self._tmpl_attrs.get("agent")),
654+
app=copy.deepcopy(self._tmpl_attrs.get("app")),
628655
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
629656
app_name=self._tmpl_attrs.get("app_name"),
630657
plugins=self._tmpl_attrs.get("plugins"),
@@ -775,18 +802,28 @@ def tracing_enabled() -> bool:
775802

776803
self._tmpl_attrs["runner"] = Runner(
777804
agent=self._tmpl_attrs.get("agent"),
805+
app=self._tmpl_attrs.get("app"),
778806
plugins=self._tmpl_attrs.get("plugins"),
779807
session_service=self._tmpl_attrs.get("session_service"),
780808
artifact_service=self._tmpl_attrs.get("artifact_service"),
781809
memory_service=self._tmpl_attrs.get("memory_service"),
782-
app_name=self._tmpl_attrs.get("app_name"),
810+
app_name=(
811+
self._tmpl_attrs.get("app").name
812+
if self._tmpl_attrs.get("app")
813+
else self._tmpl_attrs.get("app_name")
814+
),
783815
)
784816
self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService()
785817
self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService()
786818
self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService()
787819
self._tmpl_attrs["in_memory_runner"] = Runner(
788-
app_name=self._tmpl_attrs.get("app_name"),
820+
app_name=(
821+
self._tmpl_attrs.get("app").name
822+
if self._tmpl_attrs.get("app")
823+
else self._tmpl_attrs.get("app_name")
824+
),
789825
agent=self._tmpl_attrs.get("agent"),
826+
app=self._tmpl_attrs.get("app"),
790827
plugins=self._tmpl_attrs.get("plugins"),
791828
session_service=self._tmpl_attrs.get("in_memory_session_service"),
792829
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),

0 commit comments

Comments
 (0)