Skip to content

Commit 6dcb908

Browse files
authored
Add back flex API option with improved pydoc (#37491)
* Reapply "Support Vertex Flex API in GeminiModelHandler (#36982)" (#37051) This reverts commit 585ad41. * Add pydoc
1 parent 257fa5a commit 6dcb908

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

sdks/python/apache_beam/ml/inference/gemini_inference.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from google import genai
2727
from google.genai import errors
28+
from google.genai.types import HttpOptions
2829
from google.genai.types import Part
2930
from PIL.Image import Image
3031

@@ -108,6 +109,7 @@ def __init__(
108109
api_key: Optional[str] = None,
109110
project: Optional[str] = None,
110111
location: Optional[str] = None,
112+
use_vertex_flex_api: Optional[bool] = False,
111113
*,
112114
min_batch_size: Optional[int] = None,
113115
max_batch_size: Optional[int] = None,
@@ -139,6 +141,13 @@ def __init__(
139141
location: the GCP location (region) to use for Vertex AI requests. Setting this
140142
parameter routes requests to Vertex AI. If this parameter is provided,
141143
project must also be provided and api_key should not be set.
144+
use_vertex_flex_api: if true, use the Vertex Flex API. This is a
145+
cost-effective option for accessing Gemini models if you can tolerate
146+
longer response times and throttling. This is often beneficial for
147+
data processing workloads which usually have higher latency tolerance
148+
than live serving paths. See
149+
https://docs.cloud.google.com/vertex-ai/generative-ai/docs/flex-paygo
150+
for more details.
142151
min_batch_size: optional. the minimum batch size to use when batching
143152
inputs.
144153
max_batch_size: optional. the maximum batch size to use when batching
@@ -178,6 +187,8 @@ def __init__(
178187
self.location = location
179188
self.use_vertex = True
180189

190+
self.use_vertex_flex_api = use_vertex_flex_api
191+
181192
super().__init__(
182193
namespace='GeminiModelHandler',
183194
retry_filter=_retry_on_appropriate_service_error,
@@ -192,8 +203,19 @@ def create_client(self) -> genai.Client:
192203
provided when the GeminiModelHandler class is instantiated.
193204
"""
194205
if self.use_vertex:
195-
return genai.Client(
196-
vertexai=True, project=self.project, location=self.location)
206+
if self.use_vertex_flex_api:
207+
return genai.Client(
208+
vertexai=True,
209+
project=self.project,
210+
location=self.location,
211+
http_options=HttpOptions(
212+
api_version="v1",
213+
headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
214+
# The timeout is expressed in milliseconds (600000 ms = 10 minutes).
215+
timeout=600000))
216+
else:
217+
return genai.Client(
218+
vertexai=True, project=self.project, location=self.location)
197219
return genai.Client(api_key=self.api_key)
198220

199221
def request(

sdks/python/apache_beam/ml/inference/gemini_inference_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# pytype: skip-file
1818

1919
import unittest
20+
from unittest import mock
2021

2122
try:
2223
from google.genai import errors
@@ -81,5 +82,29 @@ def test_missing_all_params(self):
8182
)
8283

8384

85+
@mock.patch('apache_beam.ml.inference.gemini_inference.genai.Client')
@mock.patch('apache_beam.ml.inference.gemini_inference.HttpOptions')
class TestGeminiModelHandler(unittest.TestCase):
  """Tests for client construction in GeminiModelHandler.

  genai.Client and HttpOptions are both patched inside the gemini_inference
  module, so the tests only inspect the arguments those callables receive.
  """
  def test_create_client_with_flex_api(
      self, mock_http_options, mock_genai_client):
    """Enabling the flex option forwards flex-specific HttpOptions."""
    # Vertex routing arguments shared by the handler and the expected
    # genai.Client call.
    vertex_target = {"project": "test-project", "location": "us-central1"}
    flex_handler = GeminiModelHandler(
        model_name="gemini-pro",
        request_fn=generate_from_string,
        use_vertex_flex_api=True,
        **vertex_target)

    flex_handler.create_client()

    # HttpOptions must carry the API version, the flex request-type header
    # and the timeout value.
    expected_http_kwargs = {
        "api_version": "v1",
        "headers": {"X-Vertex-AI-LLM-Request-Type": "flex"},
        "timeout": 600000,
    }
    mock_http_options.assert_called_with(**expected_http_kwargs)

    # The client must be built for Vertex AI with the patched HttpOptions
    # instance passed through unchanged.
    mock_genai_client.assert_called_with(
        vertexai=True,
        http_options=mock_http_options.return_value,
        **vertex_target)
107+
108+
84109
if __name__ == '__main__':
85110
unittest.main()

0 commit comments

Comments
 (0)