Skip to content

Commit 8329b52

Browse files
committed
Support Vertex Flex API in GeminiModelHandler
1 parent c9d2e60 commit 8329b52

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

sdks/python/apache_beam/ml/inference/gemini_inference.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import Union
2525

2626
from google import genai
27+
from google.genai.types import HttpOptions
2728
from google.genai import errors
2829
from google.genai.types import Part
2930
from PIL.Image import Image
@@ -108,6 +109,7 @@ def __init__(
108109
api_key: Optional[str] = None,
109110
project: Optional[str] = None,
110111
location: Optional[str] = None,
112+
use_vertex_flex_api: Optional[bool]=False,
111113
*,
112114
min_batch_size: Optional[int] = None,
113115
max_batch_size: Optional[int] = None,
@@ -169,6 +171,8 @@ def __init__(
169171
self.location = location
170172
self.use_vertex = True
171173

174+
self.use_vertex_flex_api = use_vertex_flex_api
175+
172176
super().__init__(
173177
namespace='GeminiModelHandler',
174178
retry_filter=_retry_on_appropriate_service_error,
@@ -180,7 +184,17 @@ def create_client(self) -> genai.Client:
180184
provided when the GeminiModelHandler class is instantiated.
181185
"""
182186
if self.use_vertex:
183-
return genai.Client(
187+
if self.use_vertex_flex_api:
188+
return genai.Client(
189+
vertexai=True, project=self.project, location=self.location,
190+
http_options=HttpOptions(
191+
api_version="v1",
192+
headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
193+
# Timeout is expressed in milliseconds (600000 ms = 10 minutes).
194+
timeout = 600000
195+
))
196+
else:
197+
return genai.Client(
184198
vertexai=True, project=self.project, location=self.location)
185199
return genai.Client(api_key=self.api_key)
186200

sdks/python/apache_beam/ml/inference/gemini_inference_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,31 @@ def test_missing_all_params(self):
8181
)
8282

8383

84+
@unittest.mock.patch(
    'apache_beam.ml.inference.gemini_inference.genai.Client')
@unittest.mock.patch(
    'apache_beam.ml.inference.gemini_inference.HttpOptions')
class TestGeminiModelHandler(unittest.TestCase):
  """Checks how GeminiModelHandler.create_client() builds the genai client.

  Both ``genai.Client`` and ``HttpOptions`` are patched at class level, so no
  real client is constructed. Class-level patches are applied bottom-up, which
  is why every test method receives ``mock_http_options`` before
  ``mock_genai_client``.
  """
  def test_create_client_with_flex_api(
      self, mock_http_options, mock_genai_client):
    """The flex flag adds HttpOptions carrying the flex request header."""
    handler = GeminiModelHandler(
        model_name="gemini-pro",
        request_fn=generate_from_string,
        project="test-project",
        location="us-central1",
        use_vertex_flex_api=True)

    client = handler.create_client()

    # Exactly one HttpOptions must be built, with the flex header and the
    # millisecond timeout used by create_client().
    mock_http_options.assert_called_once_with(
        api_version="v1",
        headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
        timeout=600000,
    )
    mock_genai_client.assert_called_once_with(
        vertexai=True,
        project="test-project",
        location="us-central1",
        http_options=mock_http_options.return_value)
    # create_client() hands back the object produced by genai.Client.
    self.assertIs(client, mock_genai_client.return_value)

  def test_create_client_without_flex_api(
      self, mock_http_options, mock_genai_client):
    """Without the flex flag the Vertex client gets no custom HttpOptions."""
    handler = GeminiModelHandler(
        model_name="gemini-pro",
        request_fn=generate_from_string,
        project="test-project",
        location="us-central1")

    client = handler.create_client()

    # The default path must not build HttpOptions or pass http_options at all.
    mock_http_options.assert_not_called()
    mock_genai_client.assert_called_once_with(
        vertexai=True, project="test-project", location="us-central1")
    self.assertIs(client, mock_genai_client.return_value)
108+
109+
84110
if __name__ == '__main__':
85111
unittest.main()

0 commit comments

Comments
 (0)