Skip to content

Commit 11c23a3

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client - Support agent engine sandbox http request in genai sdk
PiperOrigin-RevId: 853935707
1 parent 1c3b451 commit 11c23a3

File tree

3 files changed

+259
-0
lines changed

3 files changed

+259
-0
lines changed

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@
168168
"opentelemetry-exporter-otlp-proto-http < 2",
169169
"pydantic >= 2.11.1, < 3",
170170
"typing_extensions",
171+
"google-cloud-iam",
171172
]
172173

173174
evaluation_extra_require = [
@@ -256,6 +257,7 @@
256257
"bigframes; python_version>='3.10' and python_version<'3.14'",
257258
# google-api-core 2.x is required since kfp requires protobuf > 4
258259
"google-api-core >= 2.11, < 3.0.0",
260+
"google-cloud-iam",
259261
"grpcio-testing",
260262
"grpcio-tools >= 1.63.0; python_version>='3.13'",
261263
"ipython",
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import importlib
17+
import os
18+
19+
from unittest import mock
20+
21+
from google import auth
22+
from google.auth import credentials as auth_credentials
23+
from google.cloud import aiplatform
24+
import vertexai
25+
from google.cloud.aiplatform import initializer
26+
from google.genai import client
27+
from google.genai import types as genai_types
28+
import pytest
29+
30+
_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
31+
_TEST_LOCATION = "us-central1"
32+
_TEST_PROJECT = "test-project"
33+
_TEST_RESOURCE_ID = "1028944691210842416"
34+
_TEST_SANDBOX_ID = "sandbox-123"
35+
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
36+
_TEST_AGENT_ENGINE_RESOURCE_NAME = (
37+
f"{_TEST_PARENT}/reasoningEngines/{_TEST_RESOURCE_ID}"
38+
)
39+
_TEST_SANDBOX_RESOURCE_NAME = (
40+
f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}/sandboxes/{_TEST_SANDBOX_ID}"
41+
)
42+
_TEST_AGENT_ENGINE_ENV_KEY = "GOOGLE_CLOUD_AGENT_ENGINE_ENV"
43+
_TEST_AGENT_ENGINE_ENV_VALUE = "test_env_value"
44+
_TEST_SERVICE_ACCOUNT_EMAIL = "[email protected]"
45+
46+
47+
@pytest.fixture(scope="module")
48+
def google_auth_mock():
49+
with mock.patch.object(auth, "default") as google_auth_mock:
50+
google_auth_mock.return_value = (
51+
auth_credentials.AnonymousCredentials(),
52+
_TEST_PROJECT,
53+
)
54+
yield google_auth_mock
55+
56+
57+
@pytest.mark.usefixtures("google_auth_mock")
58+
class TestSandbox:
59+
def setup_method(self):
60+
importlib.reload(initializer)
61+
importlib.reload(aiplatform)
62+
importlib.reload(vertexai)
63+
os.environ[_TEST_AGENT_ENGINE_ENV_KEY] = _TEST_AGENT_ENGINE_ENV_VALUE
64+
self.client = vertexai.Client(
65+
project=_TEST_PROJECT,
66+
location=_TEST_LOCATION,
67+
credentials=_TEST_CREDENTIALS,
68+
)
69+
70+
def teardown_method(self):
71+
initializer.global_pool.shutdown(wait=True)
72+
73+
@mock.patch.object(client.Client, "_get_api_client")
74+
def test_send_command(self, mock_get_api_client):
75+
mock_sandbox = mock.Mock()
76+
mock_sandbox.connection_info.load_balancer_ip = "127.0.0.1"
77+
mock_sandbox.connection_info.load_balancer_hostname = None
78+
mock_http_client = mock_get_api_client.return_value
79+
mock_http_client.request.return_value = genai_types.HttpResponse(
80+
body=b"{}", headers={}
81+
)
82+
83+
self.client.agent_engines.sandboxes.send_command(
84+
http_method="GET",
85+
access_token="test_token",
86+
sandbox_environment=mock_sandbox,
87+
path="test/path",
88+
)
89+
90+
call_args = mock_get_api_client.call_args
91+
assert call_args is not None
92+
_, kwargs = call_args
93+
http_options = kwargs["http_options"]
94+
assert http_options.base_url == "http://127.0.0.1/test/path"
95+
assert http_options.headers["Authorization"] == "Bearer test_token"
96+
97+
mock_http_client.request.assert_called_with("GET", "test/path", {})

vertexai/_genai/sandboxes.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@
1919
import json
2020
import logging
2121
import mimetypes
22+
import secrets
23+
import time
2224
from typing import Any, Iterator, Optional, Union
2325
from urllib.parse import urlencode
2426

27+
from google import genai
28+
from google.cloud import iam_credentials_v1
2529
from google.genai import _api_module
2630
from google.genai import _common
31+
from google.genai import types as genai_types
2732
from google.genai._common import get_value_by_path as getv
2833
from google.genai._common import set_value_by_path as setv
2934
from google.genai.pagers import Pager
@@ -704,6 +709,161 @@ def delete(
704709
"""
705710
return self._delete(name=name, config=config)
706711

712+
def generate_access_token(
713+
self,
714+
service_account_email: str,
715+
sandbox_id: str,
716+
port: str = "8080",
717+
timeout: int = 3600,
718+
) -> str:
719+
"""Signs a JWT with a Google Cloud service account.
720+
721+
Args:
722+
service_account_email (str):
723+
Required. The email of the service account to use for signing.
724+
sandbox_id (str):
725+
Required. The resource name of the sandbox to generate a token for.
726+
port (str):
727+
Optional. The port to use for the token. Defaults to "8080".
728+
timeout (int):
729+
Optional. The timeout in seconds for the token. Defaults to 3600.
730+
731+
Returns:
732+
str: The signed JWT.
733+
"""
734+
client = iam_credentials_v1.IAMCredentialsClient()
735+
name = f"projects/-/serviceAccounts/{service_account_email}"
736+
custom_claims = {"port": port, "sandbox_id": sandbox_id}
737+
payload = {
738+
"iat": int(time.time()),
739+
"exp": int(time.time()) + timeout,
740+
"iss": service_account_email,
741+
"nonce": secrets.randbelow(1000000000) + 1,
742+
"aud": "vmaas-proxy-api", # default audience for sandbox proxy
743+
**custom_claims,
744+
}
745+
request = iam_credentials_v1.SignJwtRequest(
746+
name=name,
747+
payload=json.dumps(payload),
748+
)
749+
response = client.sign_jwt(request=request)
750+
return response.signed_jwt
751+
752+
def send_command(
753+
self,
754+
*,
755+
http_method: str,
756+
access_token: str,
757+
sandbox_environment: types.SandboxEnvironment,
758+
path: str = None,
759+
query_params: Optional[dict[str, object]] = None,
760+
headers: Optional[dict[str, str]] = None,
761+
request_dict: Optional[dict[str, object]] = None,
762+
) -> genai_types.HttpResponse:
763+
"""Sends a command to the sandbox.
764+
765+
Args:
766+
http_method (str):
767+
Required. The HTTP method to use for the command.
768+
access_token (str):
769+
Required. The access token to use for authorization.
770+
sandbox_environment (types.SandboxEnvironment):
771+
Required. The sandbox environment to send the command to.
772+
path (str):
773+
Optional. The path to send the command to.
774+
query_params (dict[str, object]):
775+
Optional. The query parameters to include in the command.
776+
headers (dict[str, str]):
777+
Optional. The headers to include in the command.
778+
request_dict (dict[str, object]):
779+
Optional. The request body to include in the command.
780+
781+
Returns:
782+
genai_types.HttpResponse: The response from the sandbox.
783+
"""
784+
headers = headers or {}
785+
request_dict = request_dict or {}
786+
connection_info = sandbox_environment.connection_info
787+
if not connection_info:
788+
raise ValueError("Connection info is not available.")
789+
if connection_info.load_balancer_hostname:
790+
endpoint = "https://" + connection_info.load_balancer_hostname
791+
elif connection_info.load_balancer_ip:
792+
endpoint = "http://" + connection_info.load_balancer_ip
793+
else:
794+
raise ValueError("Load balancer hostname or ip is not available.")
795+
796+
path = path or ""
797+
if query_params:
798+
path = f"{path}?{urlencode(query_params)}"
799+
headers["Authorization"] = f"Bearer {access_token}"
800+
endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path
801+
http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint)
802+
http_client = genai.Client(vertexai=True, http_options=http_options)
803+
# Full path is constructed in this function. The passed in path into request
804+
# function will not be used.
805+
response = http_client._api_client.request(http_method, path, request_dict)
806+
return genai_types.HttpResponse(
807+
headers=response.headers,
808+
body=response.body,
809+
)
810+
811+
def generate_browser_ws_headers(
812+
self,
813+
sandbox_environment: types.SandboxEnvironment,
814+
service_account_email: str,
815+
timeout: int = 3600,
816+
) -> tuple[str, dict[str, str]]:
817+
"""Generates the websocket upgrade headers for the browser.
818+
819+
Args:
820+
sandbox_environment (types.SandboxEnvironment):
821+
Required. The sandbox environment to generate websocket headers for.
822+
service_account_email (str):
823+
Required. The email of the service account to use for signing.
824+
timeout (int):
825+
Optional. The timeout in seconds for the token. Defaults to 3600.
826+
827+
Returns:
828+
tuple[str, dict[str, str]]: A tuple containing the websocket URL and
829+
the headers for websocket upgrade.
830+
"""
831+
sandbox_id = sandbox_environment.name
832+
# port 8080 is the default port for http endpoint.
833+
http_access_token = self.generate_access_token(
834+
service_account_email, sandbox_id, "8080", timeout
835+
)
836+
response = self.send_command(
837+
http_method="GET",
838+
access_token=http_access_token,
839+
sandbox_environment=sandbox_environment,
840+
path="/cdp_ws_endpoint",
841+
)
842+
if not response:
843+
raise ValueError("Failed to get the websocket endpoint.")
844+
body_dict = json.loads(response.body)
845+
ws_path = body_dict["endpoint"]
846+
847+
ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
848+
if sandbox_environment and sandbox_environment.connection_info:
849+
connection_info = sandbox_environment.connection_info
850+
if connection_info.load_balancer_hostname:
851+
ws_url = "wss://" + connection_info.load_balancer_hostname
852+
elif connection_info.load_balancer_ip:
853+
ws_url = "ws://" + connection_info.load_balancer_ip
854+
else:
855+
raise ValueError("Load balancer hostname or ip is not available.")
856+
ws_url = ws_url + "/" + ws_path
857+
858+
# port 9222 is the default port for the browser websocket endpoint.
859+
ws_access_token = self.generate_access_token(
860+
service_account_email, sandbox_id, "9222", timeout
861+
)
862+
863+
headers = {}
864+
headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}"
865+
return ws_url, headers
866+
707867

708868
class AsyncSandboxes(_api_module.BaseModule):
709869

0 commit comments

Comments
 (0)