2525
2626from google import genai
2727from google .genai import errors
28+ from google .genai .types import HttpOptions
2829from google .genai .types import Part
2930from PIL .Image import Image
3031
@@ -108,6 +109,7 @@ def __init__(
108109 api_key : Optional [str ] = None ,
109110 project : Optional [str ] = None ,
110111 location : Optional [str ] = None ,
112+ use_vertex_flex_api : Optional [bool ] = False ,
111113 * ,
112114 min_batch_size : Optional [int ] = None ,
113115 max_batch_size : Optional [int ] = None ,
@@ -139,6 +141,13 @@ def __init__(
139141 location: the GCP location (region) to use for Vertex AI requests. Setting this
140142 parameter routes requests to Vertex AI. If this parameter is provided,
141143 project must also be provided and api_key should not be set.
144+ use_vertex_flex_api: if true, use the Vertex Flex API. This is a
145+ cost-effective option for accessing Gemini models if you can tolerate
146+ longer response times and throttling. This is often beneficial for
147+ data processing workloads which usually have higher latency tolerance
148+ than live serving paths. See
149+ https://docs.cloud.google.com/vertex-ai/generative-ai/docs/flex-paygo
150+ for more details.
142151 min_batch_size: optional. the minimum batch size to use when batching
143152 inputs.
144153 max_batch_size: optional. the maximum batch size to use when batching
@@ -178,6 +187,8 @@ def __init__(
178187 self .location = location
179188 self .use_vertex = True
180189
190+ self .use_vertex_flex_api = use_vertex_flex_api
191+
181192 super ().__init__ (
182193 namespace = 'GeminiModelHandler' ,
183194 retry_filter = _retry_on_appropriate_service_error ,
@@ -192,8 +203,19 @@ def create_client(self) -> genai.Client:
192203 provided when the GeminiModelHandler class is instantiated.
193204 """
194205 if self .use_vertex :
195- return genai .Client (
196- vertexai = True , project = self .project , location = self .location )
206+ if self .use_vertex_flex_api :
207+ return genai .Client (
208+ vertexai = True ,
209+ project = self .project ,
210+ location = self .location ,
211+ http_options = HttpOptions (
212+ api_version = "v1" ,
213+ headers = {"X-Vertex-AI-LLM-Request-Type" : "flex" },
214+ # Timeout is specified in milliseconds (600000 ms = 10 minutes).
215+ timeout = 600000 ))
216+ else :
217+ return genai .Client (
218+ vertexai = True , project = self .project , location = self .location )
197219 return genai .Client (api_key = self .api_key )
198220
199221 def request (
0 commit comments