11from __future__ import annotations
22
3- import functools
43import time
5-
6- from google import generativeai
7- from google .generativeai .types .content_types import (
8- add_object_type ,
9- convert_to_nullable ,
10- strip_titles ,
11- unpack_defs ,
4+ from functools import lru_cache
5+ from pathlib import Path
6+
7+ import magic
8+ from google import genai
9+ from google .genai import types
10+ from google .genai .types import (
11+ CountTokensConfig ,
12+ File ,
13+ GenerateContentConfig ,
14+ GenerateContentResponse ,
15+ Model ,
16+ Part ,
1217)
13- from google .generativeai .types .generation_types import GenerateContentResponse
14- from google .generativeai .types .model_types import Model
1518from openai .types import CompletionUsage
1619from openai .types .chat import (
1720 ChatCompletionMessage ,
2124 completion_create_params ,
2225)
2326from openai .types .chat .chat_completion import ChatCompletion , Choice
24- from typing_extensions import Any , Dict , Iterable , List , Optional , Union
27+ from pydantic import BaseModel
28+ from typing_extensions import Any , Dict , Iterable , List , Optional , Type , Union
2529
2630from patchwork .common .client .llm .protocol import NOT_GIVEN , LlmClient , NotGiven
2731from patchwork .common .client .llm .utils import json_schema_to_model
2832
2933
30- @functools .lru_cache
31- def _cached_list_model_from_google () -> list [Model ]:
32- return list (generativeai .list_models ())
33-
34-
3534class GoogleLlmClient (LlmClient ):
3635 __SAFETY_SETTINGS = [
3736 dict (category = "HARM_CATEGORY_HATE_SPEECH" , threshold = "BLOCK_NONE" ),
@@ -43,20 +42,54 @@ class GoogleLlmClient(LlmClient):
4342
4443 def __init__ (self , api_key : str ):
4544 self .__api_key = api_key
46- generativeai .configure (api_key = api_key )
45+ self .client = genai .Client (api_key = api_key )
46+
47+ @lru_cache (maxsize = 1 )
48+ def __get_models_info (self ) -> list [Model ]:
49+ return list (self .client .models .list ())
4750
4851 def __get_model_limits (self , model : str ) -> int :
49- for model_info in _cached_list_model_from_google ():
50- if model_info .name == f"{ self .__MODEL_PREFIX } { model } " :
52+ for model_info in self . __get_models_info ():
53+ if model_info .name == f"{ self .__MODEL_PREFIX } { model } " and model_info . input_token_limit is not None :
5154 return model_info .input_token_limit
5255 return 1_000_000
5356
57+ @lru_cache
5458 def get_models (self ) -> set [str ]:
55- return {model .name .removeprefix (self .__MODEL_PREFIX ) for model in _cached_list_model_from_google ()}
59+ return {model_info .name .removeprefix (self .__MODEL_PREFIX ) for model_info in self . __get_models_info ()}
5660
5761 def is_model_supported (self , model : str ) -> bool :
5862 return model in self .get_models ()
5963
64+ def __upload (self , file : Path | NotGiven ) -> Part | File | None :
65+ if file is NotGiven :
66+ return None
67+
68+ file_bytes = file .read_bytes ()
69+
70+ try :
71+ mime_type = magic .Magic (mime = True , uncompress = True ).from_buffer (file_bytes )
72+ return types .Part .from_bytes (data = file_bytes , mime_type = mime_type )
73+ except Exception as e :
74+ pass
75+
76+ cleaned_name = "" .join ([char for char in file .name .lower () if char .isalnum ()])
77+ try :
78+ file_ref = self .client .files .get (name = cleaned_name )
79+ if file_ref .error is None :
80+ return file_ref
81+ except Exception as e :
82+ pass
83+
84+ try :
85+ file_ref = self .client .files .upload (file = file , config = dict (name = cleaned_name ))
86+ if file_ref .error is None :
87+ return file_ref
88+ except Exception as e :
89+ pass
90+
91+ return None
92+
6093 def is_prompt_supported (
6194 self ,
6295 messages : Iterable [ChatCompletionMessageParam ],
@@ -74,11 +107,23 @@ def is_prompt_supported(
74107 tool_choice : ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN ,
75108 top_logprobs : Optional [int ] | NotGiven = NOT_GIVEN ,
76109 top_p : Optional [float ] | NotGiven = NOT_GIVEN ,
110+ file : Path | NotGiven = NOT_GIVEN ,
77111 ) -> int :
78- system , chat = self .__openai_messages_to_google_messages (messages )
79- gen_model = generativeai .GenerativeModel (model_name = model , system_instruction = system )
112+ system , contents = self .__openai_messages_to_google_messages (messages )
113+
114+ file_ref = self .__upload (file )
115+ if file_ref is not None :
116+ contents .append (file_ref )
117+
80118 try :
81- token_count = gen_model .count_tokens (chat ).total_tokens
119+ token_response = self .client .models .count_tokens (
120+ model = model ,
121+ contents = contents ,
122+ config = CountTokensConfig (
123+ system_instruction = system ,
124+ ),
125+ )
126+ token_count = token_response .total_tokens
82127 except Exception as e :
83128 return - 1
84129 model_limit = self .__get_model_limits (model )
@@ -122,6 +167,7 @@ def chat_completion(
122167 tool_choice : ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN ,
123168 top_logprobs : Optional [int ] | NotGiven = NOT_GIVEN ,
124169 top_p : Optional [float ] | NotGiven = NOT_GIVEN ,
170+ file : Path | NotGiven = NOT_GIVEN ,
125171 ) -> ChatCompletion :
126172 generation_dict = dict (
127173 stop_sequences = [stop ] if isinstance (stop , str ) else stop ,
@@ -141,20 +187,25 @@ def chat_completion(
141187 )
142188
143189 system_content , contents = self .__openai_messages_to_google_messages (messages )
190+ file_ref = self .__upload (file )
191+ if file_ref is not None :
192+ contents .append (file_ref )
144193
145- model_client = generativeai .GenerativeModel (
146- model_name = model ,
147- safety_settings = self .__SAFETY_SETTINGS ,
148- generation_config = NOT_GIVEN .remove_not_given (generation_dict ),
149- system_instruction = system_content ,
194+ response = self .client .models .generate_content (
195+ model = model ,
196+ contents = contents ,
197+ config = GenerateContentConfig (
198+ system_instruction = system_content ,
199+ safety_settings = self .__SAFETY_SETTINGS ,
200+ ** NotGiven .remove_not_given (generation_dict ),
201+ ),
150202 )
151- response = model_client .generate_content (contents = contents )
152203 return self .__google_response_to_openai_response (response , model )
153204
154205 @staticmethod
155206 def __google_response_to_openai_response (google_response : GenerateContentResponse , model : str ) -> ChatCompletion :
156207 choices = []
157- for candidate in google_response .candidates :
208+ for index , candidate in enumerate ( google_response .candidates ) :
158209 # note that instead of system, from openai, its model, from google.
159210 parts = [part .text or part .inline_data for part in candidate .content .parts ]
160211
@@ -167,7 +218,7 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons
167218
168219 choice = Choice (
169220 finish_reason = finish_reason_map .get (candidate .finish_reason , "stop" ),
170- index = candidate . index ,
221+ index = index ,
171222 message = ChatCompletionMessage (
172223 content = "\n " .join (parts ),
173224 role = "assistant" ,
@@ -191,18 +242,9 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons
191242 )
192243
193244 @staticmethod
194- def json_schema_to_google_schema (json_schema : dict [str , Any ] | None ) -> dict [ str , Any ] | None :
245+ def json_schema_to_google_schema (json_schema : dict [str , Any ] | None ) -> Type [ BaseModel ] | None :
195246 if json_schema is None :
196247 return None
197248
198249 model = json_schema_to_model (json_schema )
199- parameters = model .model_json_schema ()
200- defs = parameters .pop ("$defs" , {})
201-
202- for name , value in defs .items ():
203- unpack_defs (value , defs )
204- unpack_defs (parameters , defs )
205- convert_to_nullable (parameters )
206- add_object_type (parameters )
207- strip_titles (parameters )
208- return parameters
250+ return model
0 commit comments