Skip to content

Commit 7867525

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add support for API key in AdkApp
PiperOrigin-RevId: 825638989
1 parent 6737a70 commit 7867525

File tree

3 files changed

+71
-54
lines changed

3 files changed

+71
-54
lines changed

tests/unit/vertexai/genai/test_evals.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,9 +1070,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
10701070
)
10711071

10721072
mock_agent_engine = mock.Mock()
1073-
mock_agent_engine.async_create_session = mock.AsyncMock(
1074-
return_value={"id": "session1"}
1075-
)
1073+
mock_agent_engine.create_session.return_value = {"id": "session1"}
10761074
stream_query_return_value = [
10771075
{
10781076
"id": "1",
@@ -1088,13 +1086,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
10881086
},
10891087
]
10901088

1091-
async def _async_iterator(iterable):
1092-
for item in iterable:
1093-
yield item
1094-
1095-
mock_agent_engine.async_stream_query.return_value = _async_iterator(
1096-
stream_query_return_value
1097-
)
1089+
mock_agent_engine.stream_query.return_value = iter(stream_query_return_value)
10981090
mock_vertexai_client.return_value.agent_engines.get.return_value = (
10991091
mock_agent_engine
11001092
)
@@ -1108,10 +1100,10 @@ async def _async_iterator(iterable):
11081100
mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
11091101
name="projects/test-project/locations/us-central1/reasoningEngines/123"
11101102
)
1111-
mock_agent_engine.async_create_session.assert_called_once_with(
1103+
mock_agent_engine.create_session.assert_called_once_with(
11121104
user_id="123", state={"a": "1"}
11131105
)
1114-
mock_agent_engine.async_stream_query.assert_called_once_with(
1106+
mock_agent_engine.stream_query.assert_called_once_with(
11151107
user_id="123", session_id="session1", message="agent prompt"
11161108
)
11171109

@@ -1162,9 +1154,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
11621154
)
11631155

11641156
mock_agent_engine = mock.Mock()
1165-
mock_agent_engine.async_create_session = mock.AsyncMock(
1166-
return_value={"id": "session1"}
1167-
)
1157+
mock_agent_engine.create_session.return_value = {"id": "session1"}
11681158
stream_query_return_value = [
11691159
{
11701160
"id": "1",
@@ -1180,13 +1170,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
11801170
},
11811171
]
11821172

1183-
async def _async_iterator(iterable):
1184-
for item in iterable:
1185-
yield item
1186-
1187-
mock_agent_engine.async_stream_query.return_value = _async_iterator(
1188-
stream_query_return_value
1189-
)
1173+
mock_agent_engine.stream_query.return_value = iter(stream_query_return_value)
11901174
mock_vertexai_client.return_value.agent_engines.get.return_value = (
11911175
mock_agent_engine
11921176
)
@@ -1200,10 +1184,10 @@ async def _async_iterator(iterable):
12001184
mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
12011185
name="projects/test-project/locations/us-central1/reasoningEngines/123"
12021186
)
1203-
mock_agent_engine.async_create_session.assert_called_once_with(
1187+
mock_agent_engine.create_session.assert_called_once_with(
12041188
user_id="123", state={"a": "1"}
12051189
)
1206-
mock_agent_engine.async_stream_query.assert_called_once_with(
1190+
mock_agent_engine.stream_query.assert_called_once_with(
12071191
user_id="123", session_id="session1", message="agent prompt"
12081192
)
12091193

vertexai/_genai/_evals_common.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -278,12 +278,10 @@ def agent_run_wrapper(
278278
and type(agent_engine).__name__ == "AgentEngine"
279279
):
280280
agent_engine_instance = agent_engine
281-
return asyncio.run(
282-
inference_fn_arg(
283-
row=row_arg,
284-
contents=contents_arg,
285-
agent_engine=agent_engine_instance,
286-
)
281+
return inference_fn_arg(
282+
row=row_arg,
283+
contents=contents_arg,
284+
agent_engine=agent_engine_instance,
287285
)
288286

289287
future = executor.submit(
@@ -1265,7 +1263,7 @@ def _run_agent(
12651263
)
12661264

12671265

1268-
async def _execute_agent_run_with_retry(
1266+
def _execute_agent_run_with_retry(
12691267
row: pd.Series,
12701268
contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict],
12711269
agent_engine: types.AgentEngine,
@@ -1287,7 +1285,7 @@ async def _execute_agent_run_with_retry(
12871285
)
12881286
user_id = session_inputs.user_id
12891287
session_state = session_inputs.state
1290-
session = await agent_engine.async_create_session(
1288+
session = agent_engine.create_session(
12911289
user_id=user_id,
12921290
state=session_state,
12931291
)
@@ -1298,7 +1296,7 @@ async def _execute_agent_run_with_retry(
12981296
for attempt in range(max_retries):
12991297
try:
13001298
responses = []
1301-
async for event in agent_engine.async_stream_query(
1299+
for event in agent_engine.stream_query(
13021300
user_id=user_id,
13031301
session_id=session["id"],
13041302
message=contents,
@@ -1317,7 +1315,7 @@ async def _execute_agent_run_with_retry(
13171315
)
13181316
if attempt == max_retries - 1:
13191317
return {"error": f"Resource exhausted after retries: {e}"}
1320-
await asyncio.sleep(2**attempt)
1318+
time.sleep(2**attempt)
13211319
except Exception as e: # pylint: disable=broad-exception-caught
13221320
logger.error(
13231321
"Unexpected error during generate_content on attempt %d/%d: %s",
@@ -1328,7 +1326,7 @@ async def _execute_agent_run_with_retry(
13281326

13291327
if attempt == max_retries - 1:
13301328
return {"error": f"Failed after retries: {e}"}
1331-
await asyncio.sleep(1)
1329+
time.sleep(1)
13321330
return {"error": f"Failed to get agent run results after {max_retries} retries"}
13331331

13341332

vertexai/agent_engines/templates/adk.py

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,13 @@ def __init__(
502502
used for customizing the instrumentation logic of the Agent.
503503
If not provided, a default instrumentor builder will be used.
504504
This parameter is ignored if `enable_tracing` is False.
505+
express_mode_api_key (str):
506+
Optional. The API key to use for Express Mode. If not
507+
provided, the API key from the GOOGLE_API_KEY environment
508+
variable will be used. It will only be used if
509+
GOOGLE_GENAI_USE_VERTEXAI is true. Do not use Google AI Studio
510+
API key for this field. For more details, visit
511+
https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
505512
"""
506513
from google.cloud.aiplatform import initializer
507514

@@ -540,6 +547,7 @@ def __init__(
540547
"artifact_service_builder": artifact_service_builder,
541548
"memory_service_builder": memory_service_builder,
542549
"instrumentor_builder": instrumentor_builder,
550+
"express_mode_api_key": initializer.global_config.api_key,
543551
}
544552

545553
async def _init_session(
@@ -683,9 +691,14 @@ def set_up(self):
683691

684692
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
685693
project = self._tmpl_attrs.get("project")
686-
os.environ["GOOGLE_CLOUD_PROJECT"] = project
694+
if project:
695+
os.environ["GOOGLE_CLOUD_PROJECT"] = project
687696
location = self._tmpl_attrs.get("location")
688-
os.environ["GOOGLE_CLOUD_LOCATION"] = location
697+
if location:
698+
os.environ["GOOGLE_CLOUD_LOCATION"] = location
699+
express_mode_api_key = self._tmpl_attrs.get("express_mode_api_key")
700+
if express_mode_api_key and not project and not location:
701+
os.environ["GOOGLE_API_KEY"] = express_mode_api_key
689702

690703
# Disable content capture in custom ADK spans unless user enabled
691704
# tracing explicitly with the old flag
@@ -768,22 +781,37 @@ def tracing_enabled() -> bool:
768781
from google.adk.sessions.vertex_ai_session_service import (
769782
VertexAiSessionService,
770783
)
771-
772-
self._tmpl_attrs["session_service"] = VertexAiSessionService(
773-
project=project,
774-
location=location,
775-
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
776-
)
784+
if is_version_sufficient("1.18.0"):
785+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
786+
project=project,
787+
location=location,
788+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
789+
express_mode_api_key=os.environ.get("GOOGLE_API_KEY"),
790+
)
791+
else:
792+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
793+
project=project,
794+
location=location,
795+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
796+
)
777797
except (ImportError, AttributeError):
778798
from google.adk.sessions.vertex_ai_session_service_g3 import (
779799
VertexAiSessionService,
780800
)
781801

782-
self._tmpl_attrs["session_service"] = VertexAiSessionService(
783-
project=project,
784-
location=location,
785-
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
786-
)
802+
if is_version_sufficient("1.18.0"):
803+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
804+
project=project,
805+
location=location,
806+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
807+
express_mode_api_key=os.environ.get("GOOGLE_API_KEY"),
808+
)
809+
else:
810+
self._tmpl_attrs["session_service"] = VertexAiSessionService(
811+
project=project,
812+
location=location,
813+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
814+
)
787815

788816
else:
789817
self._tmpl_attrs["session_service"] = InMemorySessionService()
@@ -798,12 +826,19 @@ def tracing_enabled() -> bool:
798826
from google.adk.memory.vertex_ai_memory_bank_service import (
799827
VertexAiMemoryBankService,
800828
)
801-
802-
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
803-
project=project,
804-
location=location,
805-
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
806-
)
829+
if is_version_sufficient("1.18.0"):
830+
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
831+
project=project,
832+
location=location,
833+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
834+
express_mode_api_key=os.environ.get("GOOGLE_API_KEY"),
835+
)
836+
else:
837+
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
838+
project=project,
839+
location=location,
840+
agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"),
841+
)
807842
except (ImportError, AttributeError):
808843
# TODO(ysian): Handle this via _g3 import for google3.
809844
pass

0 commit comments

Comments
 (0)