Skip to content

Commit 88fb299

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add support for API key in AdkApp
PiperOrigin-RevId: 825638989
1 parent 02ab764 commit 88fb299

File tree

1 file changed

+51
-17
lines changed
  • vertexai/agent_engines/templates

1 file changed

+51
-17
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ def __init__(
540540
"artifact_service_builder": artifact_service_builder,
541541
"memory_service_builder": memory_service_builder,
542542
"instrumentor_builder": instrumentor_builder,
543+
"express_mode_api_key": initializer.global_config.api_key,
543544
}
544545

545546
async def _init_session(
@@ -683,9 +684,18 @@ def set_up(self):
683684

684685
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
685686
project = self._tmpl_attrs.get("project")
686-
os.environ["GOOGLE_CLOUD_PROJECT"] = project
687+
if project:
688+
os.environ["GOOGLE_CLOUD_PROJECT"] = project
687689
location = self._tmpl_attrs.get("location")
688-
os.environ["GOOGLE_CLOUD_LOCATION"] = location
690+
if location:
691+
os.environ["GOOGLE_CLOUD_LOCATION"] = location
692+
express_mode_api_key = self._tmpl_attrs.get("express_mode_api_key")
693+
if express_mode_api_key and not project:
694+
os.environ["GOOGLE_API_KEY"] = express_mode_api_key
695+
# Clear location if project is not set and express mode api key is
696+
# provided.
697+
os.environ.pop("GOOGLE_CLOUD_LOCATION", None)
698+
location = None
689699

690700
# Disable content capture in custom ADK spans unless user enabled
691701
# tracing explicitly with the old flag
@@ -769,21 +779,37 @@ def tracing_enabled() -> bool:
769779
VertexAiSessionService,
770780
)
771781

772-
self._tmpl_attrs["session_service"] = VertexAiSessionService(
773-
project=project,
774-
location=location,
775-
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
776-
)
782+
if is_version_sufficient("1.18.0"):
783+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
784+
project=project,
785+
location=location,
786+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
787+
express_mode_api_key=express_mode_api_key,
788+
)
789+
else:
790+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
791+
project=project,
792+
location=location,
793+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
794+
)
777795
except (ImportError, AttributeError):
778796
from google.adk.sessions.vertex_ai_session_service_g3 import (
779797
VertexAiSessionService,
780798
)
781799

782-
self._tmpl_attrs["session_service"] = VertexAiSessionService(
783-
project=project,
784-
location=location,
785-
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
786-
)
800+
if is_version_sufficient("1.18.0"):
801+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
802+
project=project,
803+
location=location,
804+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
805+
express_mode_api_key=express_mode_api_key,
806+
)
807+
else:
808+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
809+
project=project,
810+
location=location,
811+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
812+
)
787813

788814
else:
789815
self._tmpl_attrs["session_service"] = InMemorySessionService()
@@ -799,11 +825,19 @@ def tracing_enabled() -> bool:
799825
VertexAiMemoryBankService,
800826
)
801827

802-
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
803-
project=project,
804-
location=location,
805-
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
806-
)
828+
if is_version_sufficient("1.18.0"):
829+
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
830+
project=project,
831+
location=location,
832+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
833+
express_mode_api_key=express_mode_api_key,
834+
)
835+
else:
836+
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
837+
project=project,
838+
location=location,
839+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
840+
)
807841
except (ImportError, AttributeError):
808842
# TODO(ysian): Handle this via _g3 import for google3.
809843
pass

0 commit comments

Comments
 (0)