@@ -540,6 +540,7 @@ def __init__(
540540 "artifact_service_builder" : artifact_service_builder ,
541541 "memory_service_builder" : memory_service_builder ,
542542 "instrumentor_builder" : instrumentor_builder ,
543+ "express_mode_api_key" : initializer .global_config .api_key ,
543544 }
544545
545546 async def _init_session (
@@ -683,9 +684,18 @@ def set_up(self):
683684
684685 os .environ ["GOOGLE_GENAI_USE_VERTEXAI" ] = "1"
685686 project = self ._tmpl_attrs .get ("project" )
686- os .environ ["GOOGLE_CLOUD_PROJECT" ] = project
687+ if project :
688+ os .environ ["GOOGLE_CLOUD_PROJECT" ] = project
687689 location = self ._tmpl_attrs .get ("location" )
688- os .environ ["GOOGLE_CLOUD_LOCATION" ] = location
690+ if location :
691+ os .environ ["GOOGLE_CLOUD_LOCATION" ] = location
692+ express_mode_api_key = self ._tmpl_attrs .get ("express_mode_api_key" )
693+ if express_mode_api_key and not project :
694+ os .environ ["GOOGLE_API_KEY" ] = express_mode_api_key
695+ # Clear location if project is not set and express mode api key is
696+ # provided.
697+ os .environ .pop ("GOOGLE_CLOUD_LOCATION" , None )
698+ location = None
689699
690700 # Disable content capture in custom ADK spans unless user enabled
691701 # tracing explicitly with the old flag
@@ -769,21 +779,37 @@ def tracing_enabled() -> bool:
769779 VertexAiSessionService ,
770780 )
771781
772- self ._tmpl_attrs ["session_service" ] = VertexAiSessionService (
773- project = project ,
774- location = location ,
775- agent_engine_id = os .environ .get ("GOOGLE_CLOUD_AGENT_ENGINE_ID" ),
776- )
782+ if is_version_sufficient ("1.18.0" ):
783+ self ._tmpl_attrs ["session_service" ] = VertexAiSessionService (
784+ project = project ,
785+ location = location ,
786+ agent_engine_id = os .environ .get ("GOOGLE_CLOUD_AGENT_ENGINE_ID" ),
787+ express_mode_api_key = express_mode_api_key ,
788+ )
789+ else :
790+ self ._tmpl_attrs ["session_service" ] = VertexAiSessionService (
791+ project = project ,
792+ location = location ,
793+ agent_engine_id = os .environ .get ("GOOGLE_CLOUD_AGENT_ENGINE_ID" ),
794+ )
777795 except (ImportError , AttributeError ):
778796 from google .adk .sessions .vertex_ai_session_service_g3 import (
779797 VertexAiSessionService ,
780798 )
781799
782- self ._tmpl_attrs ["session_service" ] = VertexAiSessionService (
783- project = project ,
784- location = location ,
785- agent_engine_id = os .environ .get ("GOOGLE_CLOUD_AGENT_ENGINE_ID" ),
786- )
800+ if is_version_sufficient ("1.18.0" ):
801+ self ._tmpl_attrs ["session_service" ] = VertexAiSessionService (
802+ project = project ,
803+ location = location ,
804+ agent_engine_id = os .environ .get ("GOOGLE_CLOUD_AGENT_ENGINE_ID" ),
805+ express_mode_api_key = express_mode_api_key ,
806+ )
807+ else :
808+ self ._tmpl_attrs ["session_service" ] = VertexAiSessionService (
809+ project = project ,
810+ location = location ,
811+ agent_engine_id = os .environ .get ("GOOGLE_CLOUD_AGENT_ENGINE_ID" ),
812+ )
787813
788814 else :
789815 self ._tmpl_attrs ["session_service" ] = InMemorySessionService ()
@@ -799,11 +825,19 @@ def tracing_enabled() -> bool:
799825 VertexAiMemoryBankService ,
800826 )
801827
802- self ._tmpl_attrs ["memory_service" ] = VertexAiMemoryBankService (
803- project = project ,
804- location = location ,
805- agent_engine_id = os .environ .get ("GOOGLE_CLOUD_AGENT_ENGINE_ID" ),
806- )
828+ if is_version_sufficient ("1.18.0" ):
829+ self ._tmpl_attrs ["memory_service" ] = VertexAiMemoryBankService (
830+ project = project ,
831+ location = location ,
832+ agent_engine_id = os .environ .get ("GOOGLE_CLOUD_AGENT_ENGINE_ID" ),
833+ express_mode_api_key = express_mode_api_key ,
834+ )
835+ else :
836+ self ._tmpl_attrs ["memory_service" ] = VertexAiMemoryBankService (
837+ project = project ,
838+ location = location ,
839+ agent_engine_id = os .environ .get ("GOOGLE_CLOUD_AGENT_ENGINE_ID" ),
840+ )
807841 except (ImportError , AttributeError ):
808842 # TODO(ysian): Handle this via _g3 import for google3.
809843 pass
0 commit comments