11from __future__ import annotations
22
3- import functools
43import time
5-
6- from google import generativeai
7- from google .generativeai .types .content_types import (
8- add_object_type ,
9- convert_to_nullable ,
10- strip_titles ,
11- unpack_defs ,
4+ from functools import lru_cache
5+ from pathlib import Path
6+
7+ import magic
8+ from google import genai
9+ from google .genai import types
10+ from google .genai .types import (
11+ CountTokensConfig ,
12+ File ,
13+ GenerateContentConfig ,
14+ GenerateContentResponse ,
15+ Model ,
16+ Part ,
1217)
13- from google .generativeai .types .generation_types import GenerateContentResponse
14- from google .generativeai .types .model_types import Model
1518from openai .types import CompletionUsage
1619from openai .types .chat import (
1720 ChatCompletionMessage ,
2730from pydantic_ai .settings import ModelSettings
2831from pydantic_ai .usage import Usage
2932from typing_extensions import Any , AsyncIterator , Dict , Iterable , List , Optional , Union
33+ from pydantic import BaseModel
3034
3135from patchwork .common .client .llm .protocol import NOT_GIVEN , LlmClient , NotGiven
3236from patchwork .common .client .llm .utils import json_schema_to_model
3337
3438
35- @functools .lru_cache
36- def _cached_list_model_from_google () -> list [Model ]:
37- return list (generativeai .list_models ())
38-
39-
4039class GoogleLlmClient (LlmClient ):
4140 __SAFETY_SETTINGS = [
4241 dict (category = "HARM_CATEGORY_HATE_SPEECH" , threshold = "BLOCK_NONE" ),
@@ -48,7 +47,11 @@ class GoogleLlmClient(LlmClient):
4847
def __init__(self, api_key: str):
    """Store the API key and create the google-genai client for this instance.

    Args:
        api_key: Key passed straight through to ``genai.Client``.
    """
    self.__api_key = api_key
    self.client = genai.Client(api_key=api_key)
    # Filled lazily by __get_models_info(); None means "not fetched yet".
    self._models_info_cache: list[Model] | None = None

def __get_models_info(self) -> list[Model]:
    """Return the models reported by ``self.client.models.list()``, fetched once.

    The listing is memoised on the instance rather than with
    ``functools.lru_cache``: an lru_cache on a method keys on ``self``, which
    keeps the client object alive through the cache and, with ``maxsize=1``,
    is evicted as soon as a second client instance asks for its models.

    Returns:
        The materialised list of model descriptions. If the listing call
        raises, nothing is cached and the exception propagates.
    """
    # getattr keeps this safe for instances created without running __init__.
    cached = getattr(self, "_models_info_cache", None)
    if cached is None:
        cached = list(self.client.models.list())
        self._models_info_cache = cached
    return cached
5255
5356 def __get_pydantic_model (self , model_settings : ModelSettings | None ) -> Model :
5457 if model_settings is None :
@@ -86,17 +89,47 @@ def system(self) -> str:
8689 return "google-gla"
8790
def __get_model_limits(self, model: str) -> int:
    """Look up the input token limit for ``model``.

    Args:
        model: Model name without the ``__MODEL_PREFIX`` prefix.

    Returns:
        The ``input_token_limit`` of the first listed model whose name equals
        the prefixed ``model`` and whose limit is set; ``1_000_000`` when no
        such entry exists.
    """
    known_limits = (
        info.input_token_limit
        for info in self.__get_models_info()
        if info.name == f"{self.__MODEL_PREFIX}{model}" and info.input_token_limit is not None
    )
    # next() with a default mirrors "return on first match, else fall back".
    return next(known_limits, 1_000_000)
9396
def get_models(self) -> set[str]:
    """Return the names of all listed models with ``__MODEL_PREFIX`` stripped.

    The set is built once per instance and reused on later calls. It is
    stored on the instance instead of behind ``functools.lru_cache``, which
    would key on ``self`` and keep the instance alive through the cache.

    Returns:
        A set of model names. Entries whose ``name`` is ``None`` are skipped
        rather than raising ``AttributeError`` on ``removeprefix``.
    """
    names = getattr(self, "_model_names_cache", None)
    if names is None:
        names = {
            model_info.name.removeprefix(self.__MODEL_PREFIX)
            for model_info in self.__get_models_info()
            if model_info.name is not None
        }
        self._model_names_cache = names
    return names
96100
def is_model_supported(self, model: str) -> bool:
    """Report whether ``model`` is one of the names returned by ``get_models()``."""
    supported_names = self.get_models()
    return model in supported_names
99103
def __upload(self, file: Path | NotGiven) -> Part | File | None:
    """Turn ``file`` into something that can be appended to request contents.

    Strategies are tried in order, each one best-effort:

    1. Detect the MIME type with ``magic`` and wrap the bytes in an inline
       ``Part``.
    2. Look up a previously uploaded file under the sanitised file name.
    3. Upload the file under that sanitised name.

    Args:
        file: Path of the file to attach, or the not-given sentinel (or
            ``None``) when the caller supplied no file.

    Returns:
        An inline ``Part``, a ``File`` reference without an error, or
        ``None`` when no file was given or every strategy failed.

    Raises:
        OSError: If the file cannot be read; reading happens outside the
            best-effort ``try`` blocks.
    """
    import logging

    # Callers in this file default to the NOT_GIVEN *instance*; an identity
    # check against the NotGiven class never matches it, so test the type.
    if file is None or isinstance(file, NotGiven):
        return None

    log = logging.getLogger(__name__)
    file_bytes = file.read_bytes()

    try:
        mime_type = magic.Magic(mime=True, uncompress=True).from_buffer(file_bytes)
        return types.Part.from_bytes(data=file_bytes, mime_type=mime_type)
    except Exception:
        # Still best-effort, but leave a trace instead of failing silently.
        log.debug("Could not inline %s; trying the Files API", file, exc_info=True)

    # The name sent to the Files API is the lower-cased file name reduced to
    # its alphanumeric characters.
    cleaned_name = "".join(char for char in file.name.lower() if char.isalnum())

    try:
        file_ref = self.client.files.get(name=cleaned_name)
        if file_ref.error is None:
            return file_ref
    except Exception:
        log.debug("No reusable upload named %r", cleaned_name, exc_info=True)

    try:
        file_ref = self.client.files.upload(file=file, config=dict(name=cleaned_name))
        if file_ref.error is None:
            return file_ref
    except Exception:
        log.debug("Uploading %s as %r failed", file, cleaned_name, exc_info=True)

    return None
132+
100133 def is_prompt_supported (
101134 self ,
102135 messages : Iterable [ChatCompletionMessageParam ],
@@ -114,11 +147,23 @@ def is_prompt_supported(
114147 tool_choice : ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN ,
115148 top_logprobs : Optional [int ] | NotGiven = NOT_GIVEN ,
116149 top_p : Optional [float ] | NotGiven = NOT_GIVEN ,
150+ file : Path | NotGiven = NOT_GIVEN ,
117151 ) -> int :
118- system , chat = self .__openai_messages_to_google_messages (messages )
119- gen_model = generativeai .GenerativeModel (model_name = model , system_instruction = system )
152+ system , contents = self .__openai_messages_to_google_messages (messages )
153+
154+ file_ref = self .__upload (file )
155+ if file_ref is not None :
156+ contents .append (file_ref )
157+
120158 try :
121- token_count = gen_model .count_tokens (chat ).total_tokens
159+ token_response = self .client .models .count_tokens (
160+ model = model ,
161+ contents = contents ,
162+ config = CountTokensConfig (
163+ system_instruction = system ,
164+ ),
165+ )
166+ token_count = token_response .total_tokens
122167 except Exception as e :
123168 return - 1
124169 model_limit = self .__get_model_limits (model )
@@ -162,6 +207,7 @@ def chat_completion(
162207 tool_choice : ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN ,
163208 top_logprobs : Optional [int ] | NotGiven = NOT_GIVEN ,
164209 top_p : Optional [float ] | NotGiven = NOT_GIVEN ,
210+ file : Path | NotGiven = NOT_GIVEN ,
165211 ) -> ChatCompletion :
166212 generation_dict = dict (
167213 stop_sequences = [stop ] if isinstance (stop , str ) else stop ,
@@ -181,20 +227,25 @@ def chat_completion(
181227 )
182228
183229 system_content , contents = self .__openai_messages_to_google_messages (messages )
230+ file_ref = self .__upload (file )
231+ if file_ref is not None :
232+ contents .append (file_ref )
184233
185- model_client = generativeai .GenerativeModel (
186- model_name = model ,
187- safety_settings = self .__SAFETY_SETTINGS ,
188- generation_config = NOT_GIVEN .remove_not_given (generation_dict ),
189- system_instruction = system_content ,
234+ response = self .client .models .generate_content (
235+ model = model ,
236+ contents = contents ,
237+ config = GenerateContentConfig (
238+ system_instruction = system_content ,
239+ safety_settings = self .__SAFETY_SETTINGS ,
240+ ** NotGiven .remove_not_given (generation_dict ),
241+ ),
190242 )
191- response = model_client .generate_content (contents = contents )
192243 return self .__google_response_to_openai_response (response , model )
193244
194245 @staticmethod
195246 def __google_response_to_openai_response (google_response : GenerateContentResponse , model : str ) -> ChatCompletion :
196247 choices = []
197- for candidate in google_response .candidates :
248+ for index , candidate in enumerate ( google_response .candidates ) :
198249 # note that instead of system, from openai, its model, from google.
199250 parts = [part .text or part .inline_data for part in candidate .content .parts ]
200251
@@ -207,7 +258,7 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons
207258
208259 choice = Choice (
209260 finish_reason = finish_reason_map .get (candidate .finish_reason , "stop" ),
210- index = candidate . index ,
261+ index = index ,
211262 message = ChatCompletionMessage (
212263 content = "\n " .join (parts ),
213264 role = "assistant" ,
@@ -231,18 +282,9 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons
231282 )
232283
233284 @staticmethod
234- def json_schema_to_google_schema (json_schema : dict [str , Any ] | None ) -> dict [ str , Any ] | None :
285+ def json_schema_to_google_schema (json_schema : dict [str , Any ] | None ) -> Type [ BaseModel ] | None :
235286 if json_schema is None :
236287 return None
237288
238289 model = json_schema_to_model (json_schema )
239- parameters = model .model_json_schema ()
240- defs = parameters .pop ("$defs" , {})
241-
242- for name , value in defs .items ():
243- unpack_defs (value , defs )
244- unpack_defs (parameters , defs )
245- convert_to_nullable (parameters )
246- add_object_type (parameters )
247- strip_titles (parameters )
248- return parameters
290+ return model
0 commit comments