Skip to content

Commit 10a95e4

Browse files
vertex-sdk-bot, copybara-github
authored and committed
feat: GenAI SDK client - Support agent engine sandbox http request in genai sdk
PiperOrigin-RevId: 816865842
1 parent cffa558 commit 10a95e4

File tree

3 files changed

+262
-0
lines changed

3 files changed

+262
-0
lines changed

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@
168168
"opentelemetry-exporter-otlp-proto-http < 2",
169169
"pydantic >= 2.11.1, < 3",
170170
"typing_extensions",
171+
"google-cloud-iam",
171172
]
172173

173174
evaluation_extra_require = [
@@ -256,6 +257,7 @@
256257
"bigframes; python_version>='3.10' and python_version<'3.14'",
257258
# google-api-core 2.x is required since kfp requires protobuf > 4
258259
"google-api-core >= 2.11, < 3.0.0",
260+
"google-cloud-iam",
259261
"grpcio-testing",
260262
"grpcio-tools >= 1.63.0; python_version>='3.13'",
261263
"ipython",
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import importlib
17+
import os
18+
19+
from unittest import mock
20+
from urllib.parse import urlencode
21+
22+
from google import auth
23+
from google.auth import credentials as auth_credentials
24+
from google.cloud import aiplatform
25+
import vertexai
26+
from google.cloud.aiplatform import initializer
27+
from vertexai._genai import _agent_engines_utils
28+
from vertexai._genai import types as _genai_types
29+
from google.genai import client
30+
from google.genai import types as genai_types
31+
import pytest
32+
33+
# Shared constants for the sandbox unit tests.
_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
_TEST_LOCATION = "us-central1"
_TEST_PROJECT = "test-project"
_TEST_RESOURCE_ID = "1028944691210842416"
_TEST_SANDBOX_ID = "sandbox-123"
# Resource names are composed from the pieces above so they stay consistent.
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_AGENT_ENGINE_RESOURCE_NAME = (
    f"{_TEST_PARENT}/reasoningEngines/{_TEST_RESOURCE_ID}"
)
_TEST_SANDBOX_RESOURCE_NAME = (
    f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}/sandboxes/{_TEST_SANDBOX_ID}"
)
_TEST_AGENT_ENGINE_ENV_KEY = "GOOGLE_CLOUD_AGENT_ENGINE_ENV"
_TEST_AGENT_ENGINE_ENV_VALUE = "test_env_value"
# NOTE(review): the literal in this copy was an obfuscated "[email protected]"
# placeholder, which is not a valid address. The value below follows the usual
# service-account email shape -- confirm against the upstream file.
_TEST_SERVICE_ACCOUNT_EMAIL = (
    f"test-service-account@{_TEST_PROJECT}.iam.gserviceaccount.com"
)
48+
49+
50+
@pytest.fixture(scope="module")
def google_auth_mock():
    """Patch `google.auth.default` for the duration of the test module.

    Yields:
        The mock standing in for `auth.default`; it returns a tuple of
        anonymous credentials and the test project id.
    """
    anonymous_default = (
        auth_credentials.AnonymousCredentials(),
        _TEST_PROJECT,
    )
    patcher = mock.patch.object(auth, "default")
    with patcher as patched_default:
        patched_default.return_value = anonymous_default
        yield patched_default
58+
59+
60+
@pytest.mark.usefixtures("google_auth_mock")
class TestSandbox:
    """Unit tests for the sandbox helpers on `client.agent_engines.sandboxes`."""

    def setup_method(self):
        """Reload the SDK modules and build a fresh client for each test."""
        importlib.reload(initializer)
        importlib.reload(aiplatform)
        importlib.reload(vertexai)
        # Patch the environment rather than assigning into os.environ so the
        # variable is restored in teardown_method and cannot leak into other
        # tests that run in the same process.
        self._env_patcher = mock.patch.dict(
            os.environ,
            {_TEST_AGENT_ENGINE_ENV_KEY: _TEST_AGENT_ENGINE_ENV_VALUE},
        )
        self._env_patcher.start()
        self.client = vertexai.Client(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            credentials=_TEST_CREDENTIALS,
        )

    def teardown_method(self):
        """Restore the environment and shut down the global thread pool."""
        self._env_patcher.stop()
        initializer.global_pool.shutdown(wait=True)

    @mock.patch.object(client.Client, "_get_api_client")
    def test_send_command(self, mock_get_api_client):
        """send_command targets the load balancer IP and forwards the token."""
        mock_sandbox = mock.Mock()
        mock_sandbox.connection_info.load_balancer_ip = "127.0.0.1"
        # No hostname, so the plain-http IP branch is exercised.
        mock_sandbox.connection_info.load_balancer_hostname = None
        mock_http_client = mock_get_api_client.return_value
        mock_http_client.request.return_value = genai_types.HttpResponse(
            body=b"{}", headers={}
        )

        self.client.agent_engines.sandboxes.send_command(
            http_method="GET",
            access_token="test_token",
            sandbox_environment=mock_sandbox,
            path="test/path",
        )

        call_args = mock_get_api_client.call_args
        assert call_args is not None
        _, kwargs = call_args
        http_options = kwargs["http_options"]
        assert http_options.base_url == "http://127.0.0.1/test/path"
        assert http_options.headers["Authorization"] == "Bearer test_token"

        mock_http_client.request.assert_called_with("GET", "test/path", {})

vertexai/_genai/sandboxes.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@
1919
import json
2020
import logging
2121
import mimetypes
22+
import secrets
23+
import time
2224
from typing import Any, Iterator, Optional, Union
2325
from urllib.parse import urlencode
2426

27+
from google import genai
28+
from google.cloud import iam_credentials_v1
2529
from google.genai import _api_module
2630
from google.genai import _common
31+
from google.genai import types as genai_types
2732
from google.genai._common import get_value_by_path as getv
2833
from google.genai._common import set_value_by_path as setv
2934
from google.genai.pagers import Pager
@@ -704,6 +709,161 @@ def delete(
704709
"""
705710
return self._delete(name=name, config=config)
706711

712+
def generate_access_token(
    self,
    service_account_email: str,
    sandbox_id: str,
    port: str = "8080",
    timeout: int = 3600,
) -> str:
    """Signs a JWT for a sandbox with a Google Cloud service account.

    The JWT payload is built locally and signed through the IAM Credentials
    `SignJwt` API on behalf of `service_account_email`.

    Args:
        service_account_email (str):
            Required. The email of the service account to use for signing.
        sandbox_id (str):
            Required. The resource name of the sandbox to generate a token for.
        port (str):
            Optional. The port to use for the token. Defaults to "8080".
        timeout (int):
            Optional. The timeout in seconds for the token. Defaults to 3600.

    Returns:
        str: The signed JWT.
    """
    client = iam_credentials_v1.IAMCredentialsClient()
    # "-" is used as the project segment of the service account resource name.
    name = f"projects/-/serviceAccounts/{service_account_email}"
    # Read the clock once so that `exp - iat` is exactly `timeout`; two
    # separate reads can straddle a second boundary.
    issued_at = int(time.time())
    payload = {
        "iat": issued_at,
        "exp": issued_at + timeout,
        "iss": service_account_email,
        "nonce": secrets.randbelow(1000000000) + 1,
        "aud": "vmaas-proxy-api",  # default audience for sandbox proxy
        "port": port,
        "sandbox_id": sandbox_id,
    }
    request = iam_credentials_v1.SignJwtRequest(
        name=name,
        payload=json.dumps(payload),
    )
    response = client.sign_jwt(request=request)
    return response.signed_jwt
751+
752+
def send_command(
    self,
    *,
    http_method: str,
    access_token: str,
    sandbox_environment: types.SandboxEnvironment,
    path: Optional[str] = None,
    query_params: Optional[dict[str, object]] = None,
    headers: Optional[dict[str, str]] = None,
    request_dict: Optional[dict[str, object]] = None,
) -> genai_types.HttpResponse:
    """Sends a command to the sandbox.

    Args:
        http_method (str):
            Required. The HTTP method to use for the command.
        access_token (str):
            Required. The access token to use for authorization.
        sandbox_environment (types.SandboxEnvironment):
            Required. The sandbox environment to send the command to.
        path (str):
            Optional. The path to send the command to.
        query_params (dict[str, object]):
            Optional. The query parameters to include in the command.
        headers (dict[str, str]):
            Optional. The headers to include in the command. The mapping is
            copied; the caller's dict is not modified.
        request_dict (dict[str, object]):
            Optional. The request body to include in the command.

    Returns:
        genai_types.HttpResponse: The response from the sandbox.

    Raises:
        ValueError: If the sandbox has no connection info, or the connection
            info has neither a load balancer hostname nor an IP.
    """
    # Copy before adding the Authorization header so the bearer token is never
    # written into a dict owned by the caller.
    headers = dict(headers) if headers else {}
    request_dict = request_dict or {}
    connection_info = sandbox_environment.connection_info
    if not connection_info:
        raise ValueError("Connection info is not available.")
    if connection_info.load_balancer_hostname:
        endpoint = "https://" + connection_info.load_balancer_hostname
    elif connection_info.load_balancer_ip:
        # NOTE(review): the IP branch uses plain http, so the bearer token is
        # sent unencrypted -- confirm this is only reachable on trusted networks.
        endpoint = "http://" + connection_info.load_balancer_ip
    else:
        raise ValueError("Load balancer hostname or ip is not available.")

    path = path or ""
    if query_params:
        path = f"{path}?{urlencode(query_params)}"
    headers["Authorization"] = f"Bearer {access_token}"
    # Join with exactly one "/" whether or not `path` already starts with one.
    endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path
    http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint)
    http_client = genai.Client(vertexai=True, http_options=http_options)
    # Full path is constructed in this function. The passed in path into request
    # function will not be used.
    response = http_client._api_client.request(http_method, path, request_dict)
    return genai_types.HttpResponse(
        headers=response.headers,
        body=response.body,
    )
810+
811+
def generate_browser_ws_headers(
    self,
    sandbox_environment: types.SandboxEnvironment,
    service_account_email: str,
    timeout: int = 3600,
) -> tuple[str, dict[str, str]]:
    """Generates the websocket upgrade headers for the browser.

    Fetches the websocket path from the sandbox's `/cdp_ws_endpoint` HTTP
    endpoint, builds the websocket URL from the sandbox connection info, and
    signs a second token for the websocket port.

    Args:
        sandbox_environment (types.SandboxEnvironment):
            Required. The sandbox environment to generate websocket headers for.
        service_account_email (str):
            Required. The email of the service account to use for signing.
        timeout (int):
            Optional. The timeout in seconds for the token. Defaults to 3600.

    Returns:
        tuple[str, dict[str, str]]: A tuple containing the websocket URL and
        the headers for websocket upgrade.

    Raises:
        ValueError: If the websocket endpoint could not be fetched, or the
            sandbox has no usable connection info.
    """
    sandbox_id = sandbox_environment.name
    # port 8080 is the default port for http endpoint.
    http_access_token = self.generate_access_token(
        service_account_email, sandbox_id, "8080", timeout
    )
    response = self.send_command(
        http_method="GET",
        access_token=http_access_token,
        sandbox_environment=sandbox_environment,
        path="/cdp_ws_endpoint",
    )
    if not response:
        raise ValueError("Failed to get the websocket endpoint.")
    ws_path = json.loads(response.body)["endpoint"]

    # Never fall back to a hard-coded endpoint: a sandbox without connection
    # info is an error, consistent with send_command.
    connection_info = sandbox_environment.connection_info
    if not connection_info:
        raise ValueError("Connection info is not available.")
    if connection_info.load_balancer_hostname:
        ws_url = "wss://" + connection_info.load_balancer_hostname
    elif connection_info.load_balancer_ip:
        ws_url = "ws://" + connection_info.load_balancer_ip
    else:
        raise ValueError("Load balancer hostname or ip is not available.")
    # Strip a leading "/" so the join never produces "//" in the URL.
    ws_url = ws_url + "/" + ws_path.lstrip("/")

    # port 9222 is the default port for the browser websocket endpoint.
    ws_access_token = self.generate_access_token(
        service_account_email, sandbox_id, "9222", timeout
    )

    # The token is carried in the websocket subprotocol header.
    headers = {"Sec-WebSocket-Protocol": f"binary, {ws_access_token}"}
    return ws_url, headers
866+
707867

708868
class AsyncSandboxes(_api_module.BaseModule):
709869

0 commit comments

Comments
 (0)