 from __future__ import annotations
 
-import functools
 import time
-
-from google import generativeai
-from google.generativeai.types.content_types import (
-    add_object_type,
-    convert_to_nullable,
-    strip_titles,
-    unpack_defs,
+from functools import lru_cache
+from pathlib import Path
+
+from google import genai
+from google.genai.types import (
+    CountTokensConfig,
+    File,
+    GenerateContentConfig,
+    GenerateContentResponse,
+    Model,
 )
-from google.generativeai.types.generation_types import GenerateContentResponse
-from google.generativeai.types.model_types import Model
 from openai.types import CompletionUsage
 from openai.types.chat import (
     ChatCompletionMessage,
     ChatCompletionMessageParam,
     ChatCompletionToolChoiceOptionParam,
     completion_create_params,
 )
 from openai.types.chat.chat_completion import ChatCompletion, Choice
-from typing_extensions import Any, Dict, Iterable, List, Optional, Union
+from pydantic import BaseModel
+from typing_extensions import Any, Dict, Iterable, List, Optional, Type, Union
 
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
 from patchwork.common.client.llm.utils import json_schema_to_model
 
 
-@functools.lru_cache
-def _cached_list_model_from_google() -> list[Model]:
-    return list(generativeai.list_models())
-
-
 class GoogleLlmClient(LlmClient):
     __SAFETY_SETTINGS = [
         dict(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
@@ -43,20 +39,45 @@ class GoogleLlmClient(LlmClient):
 
     def __init__(self, api_key: str):
         self.__api_key = api_key
-        generativeai.configure(api_key=api_key)
+        self.client = genai.Client(api_key=api_key)
+
+    @lru_cache(maxsize=1)
+    def __get_models_info(self) -> list[Model]:
+        return list(self.client.models.list())
 
     def __get_model_limits(self, model: str) -> int:
-        for model_info in _cached_list_model_from_google():
-            if model_info.name == f"{self.__MODEL_PREFIX}{model}":
+        for model_info in self.__get_models_info():
+            if model_info.name == f"{self.__MODEL_PREFIX}{model}" and model_info.input_token_limit is not None:
                 return model_info.input_token_limit
         return 1_000_000
 
+    @lru_cache
     def get_models(self) -> set[str]:
-        return {model.name.removeprefix(self.__MODEL_PREFIX) for model in _cached_list_model_from_google()}
+        return {model_info.name.removeprefix(self.__MODEL_PREFIX) for model_info in self.__get_models_info()}
 
     def is_model_supported(self, model: str) -> bool:
         return model in self.get_models()
 
+    def __upload(self, file: Path | NotGiven) -> File | None:
+        if file is NOT_GIVEN:  # compare against the sentinel instance, not the NotGiven type
+            return None
+
+        # Reuse a previously uploaded file if the service still has it.
+        try:
+            file_ref = self.client.files.get(name=file.name)
+            if file_ref.error is None:
+                return file_ref
+        except Exception:
+            pass
+
+        # Otherwise upload it fresh.
+        try:
+            file_ref = self.client.files.upload(file=file)
+            if file_ref.error is None:
+                return file_ref
+        except Exception:
+            pass
+
+        return None
+
     def is_prompt_supported(
         self,
         messages: Iterable[ChatCompletionMessageParam],
@@ -74,11 +95,23 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         system, chat = self.__openai_messages_to_google_messages(messages)
-        gen_model = generativeai.GenerativeModel(model_name=model, system_instruction=system)
+
+        file_ref = self.__upload(file)
+        if file_ref is not None:
+            chat.append(file_ref)
+
         try:
-            token_count = gen_model.count_tokens(chat).total_tokens
+            token_response = self.client.models.count_tokens(
+                model=model,
+                contents=chat,
+                config=CountTokensConfig(
+                    system_instruction=system,
+                ),
+            )
+            token_count = token_response.total_tokens
         except Exception as e:
             return -1
         model_limit = self.__get_model_limits(model)
@@ -142,13 +175,15 @@ def chat_completion(
 
         system_content, contents = self.__openai_messages_to_google_messages(messages)
 
-        model_client = generativeai.GenerativeModel(
-            model_name=model,
-            safety_settings=self.__SAFETY_SETTINGS,
-            generation_config=NOT_GIVEN.remove_not_given(generation_dict),
-            system_instruction=system_content,
+        response = self.client.models.generate_content(
+            model=model,
+            contents=contents,
+            config=GenerateContentConfig(
+                system_instruction=system_content,
+                safety_settings=self.__SAFETY_SETTINGS,
+                **generation_dict,
+            ),
         )
-        response = model_client.generate_content(contents=contents)
         return self.__google_response_to_openai_response(response, model)
 
     @staticmethod
@@ -191,18 +226,9 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons
         )
 
     @staticmethod
-    def json_schema_to_google_schema(json_schema: dict[str, Any] | None) -> dict[str, Any] | None:
+    def json_schema_to_google_schema(json_schema: dict[str, Any] | None) -> Type[BaseModel] | None:
         if json_schema is None:
             return None
 
         model = json_schema_to_model(json_schema)
-        parameters = model.model_json_schema()
-        defs = parameters.pop("$defs", {})
-
-        for name, value in defs.items():
-            unpack_defs(value, defs)
-        unpack_defs(parameters, defs)
-        convert_to_nullable(parameters)
-        add_object_type(parameters)
-        strip_titles(parameters)
-        return parameters
+        return model
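
Taken together, the change swaps the deprecated module-level `google.generativeai` API for an instance-scoped `genai.Client`, while keeping the OpenAI-shaped `LlmClient` surface. A minimal usage sketch follows; the import path, model name, and API key are placeholders, and a positive return from `is_prompt_supported` is assumed to mean the prompt fits the model's input limit:

```python
# Sketch only: import path, model name, and key are illustrative assumptions.
from patchwork.common.client.llm.google_ import GoogleLlmClient

client = GoogleLlmClient(api_key="YOUR_GEMINI_API_KEY")

model = "gemini-1.5-flash"  # example; any name from client.get_models() works
messages = [{"role": "user", "content": "Say hello in one sentence."}]

# is_prompt_supported returns -1 when token counting fails; a positive value
# is assumed here to mean the prompt fits within the model's input limit.
if client.is_model_supported(model) and client.is_prompt_supported(messages=messages, model=model) > 0:
    completion = client.chat_completion(model=model, messages=messages)
    print(completion.choices[0].message.content)
```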
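One design note on the new caching: `functools.lru_cache` on the bound methods `__get_models_info` and `get_models` keys a module-level cache on `self`, so every `GoogleLlmClient` instance stays referenced for the life of the process. A per-instance alternative, sketched with only the standard library (the `ModelCatalog` name is illustrative, not part of the diff):

```python
from functools import cached_property


class ModelCatalog:
    """Illustrative stand-in for the client wrapper above."""

    def __init__(self, client):
        self.client = client  # e.g. a genai.Client

    @cached_property
    def models_info(self) -> list:
        # Computed once, then stored on the instance itself, so the cached
        # listing is released when the object is garbage-collected.
        return list(self.client.models.list())
```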