11from __future__ import annotations
22
3+ import os
34import time
4- from functools import lru_cache
5+ from functools import lru_cache , partial
56from pathlib import Path
67
78import magic
9+ import vertexai
810from google import genai
9- from google .auth .credentials import Credentials
11+ from google .auth .exceptions import GoogleAuthError
1012from google .genai import types
13+ from google .genai .errors import APIError
1114from google .genai .types import (
1215 CountTokensConfig ,
1316 File ,
4245 Type ,
4346 Union ,
4447)
48+ from vertexai .generative_models import GenerativeModel , SafetySetting
4549
4650from patchwork .common .client .llm .protocol import NOT_GIVEN , LlmClient , NotGiven
4751from patchwork .common .client .llm .utils import json_schema_to_model
52+ from patchwork .logger import logger
4853
4954
5055class GoogleLlmClient (LlmClient ):
@@ -53,6 +58,28 @@ class GoogleLlmClient(LlmClient):
5358 dict (category = "HARM_CATEGORY_SEXUALLY_EXPLICIT" , threshold = "BLOCK_NONE" ),
5459 dict (category = "HARM_CATEGORY_DANGEROUS_CONTENT" , threshold = "BLOCK_NONE" ),
5560 dict (category = "HARM_CATEGORY_HARASSMENT" , threshold = "BLOCK_NONE" ),
61+ dict (category = "HARM_CATEGORY_CIVIC_INTEGRITY" , threshold = "BLOCK_NONE" ),
62+ ]
63+ __VERTEX_SAFETY_SETTINGS = [
64+ SafetySetting (
65+ category = SafetySetting .HarmCategory .HARM_CATEGORY_HATE_SPEECH ,
66+ threshold = SafetySetting .HarmBlockThreshold .OFF ,
67+ ),
68+ SafetySetting (
69+ category = SafetySetting .HarmCategory .HARM_CATEGORY_DANGEROUS_CONTENT ,
70+ threshold = SafetySetting .HarmBlockThreshold .OFF ,
71+ ),
72+ SafetySetting (
73+ category = SafetySetting .HarmCategory .HARM_CATEGORY_SEXUALLY_EXPLICIT ,
74+ threshold = SafetySetting .HarmBlockThreshold .OFF ,
75+ ),
76+ SafetySetting (
77+ category = SafetySetting .HarmCategory .HARM_CATEGORY_HARASSMENT , threshold = SafetySetting .HarmBlockThreshold .OFF
78+ ),
79+ SafetySetting (
80+ category = SafetySetting .HarmCategory .HARM_CATEGORY_CIVIC_INTEGRITY ,
81+ threshold = SafetySetting .HarmBlockThreshold .OFF ,
82+ ),
5683 ]
5784 __MODEL_PREFIX = "models/"
5885
@@ -63,6 +90,12 @@ def __init__(self, api_key: Optional[str] = None, is_gcp: bool = False):
6390 self .client = genai .Client (api_key = api_key )
6491 else :
6592 self .client = genai .Client (api_key = api_key , vertexai = True )
93+ location = os .environ .get ("GOOGLE_CLOUD_LOCATION" , "global" )
94+ vertexai .init (
95+ project = os .environ .get ("GOOGLE_CLOUD_PROJECT" ),
96+ location = location ,
97+ api_endpoint = f"{ location } -aiplatform.googleapis.com" ,
98+ )
6699
67100 @lru_cache (maxsize = 1 )
68101 def __get_models_info (self ) -> list [Model ]:
@@ -173,6 +206,8 @@ def is_prompt_supported(
173206 top_p : Optional [float ] | NotGiven = NOT_GIVEN ,
174207 file : Path | NotGiven = NOT_GIVEN ,
175208 ) -> int :
209+ if self .__is_gcp :
210+ return 1
176211 system , contents = self .__openai_messages_to_google_messages (messages )
177212
178213 file_ref = self .__upload (file )
@@ -188,7 +223,12 @@ def is_prompt_supported(
188223 ),
189224 )
190225 token_count = token_response .total_tokens
226+ except GoogleAuthError :
227+ raise
228+ except APIError :
229+ raise
191230 except Exception as e :
231+ logger .debug (f"Error during token count at GoogleLlmClient: { e } " )
192232 return - 1
193233 model_limit = self .__get_model_limits (model )
194234 return model_limit - token_count
@@ -255,15 +295,25 @@ def chat_completion(
255295 if file_ref is not None :
256296 contents .append (file_ref )
257297
258- response = self .client .models .generate_content (
259- model = model ,
260- contents = contents ,
261- config = GenerateContentConfig (
262- system_instruction = system_content ,
263- safety_settings = self .__SAFETY_SETTINGS ,
264- ** NotGiven .remove_not_given (generation_dict ),
265- ),
266- )
298+ if not self .__is_gcp :
299+ generate_content_func = partial (
300+ self .client .models .generate_content ,
301+ model = model ,
302+ config = GenerateContentConfig (
303+ system_instruction = system_content ,
304+ safety_settings = self .__SAFETY_SETTINGS ,
305+ ** NotGiven .remove_not_given (generation_dict ),
306+ ),
307+ )
308+ else :
309+ vertexai_model = GenerativeModel (model , system_instruction = system_content )
310+ generate_content_func = partial (
311+ vertexai_model .generate_content ,
312+ safety_settings = self .__VERTEX_SAFETY_SETTINGS ,
313+ generation_config = NotGiven .remove_not_given (generation_dict ),
314+ )
315+
316+ response = generate_content_func (contents = contents )
267317 return self .__google_response_to_openai_response (response , model )
268318
269319 @staticmethod
0 commit comments