88import logging
99
1010import torch
11- from .ModelWrapper import (LLMWrapper , HuggingFaceWrapper , OpenAIWrapper , AnthropicWrapper ,
11+ from .ModelWrapper import (LLMWrapper , HuggingFaceWrapper , OpenAIWrapper , GeminiWrapper ,
1212 DecoderOnlyWrapper , EncoderOnlyWrapper , EncoderDecoderWrapper )
1313
1414
@@ -31,7 +31,7 @@ def load_model(
3131
3232 Args:
3333 model_path_or_name: Path to local model or HuggingFace model name
34- model_type: Type of model ("auto", "huggingface", "openai", "anthropic")
34+ model_type: Type of model ("auto", "huggingface", "openai", "gemini", " anthropic")
3535 device: Device for computation
3636 **kwargs: Additional arguments for model loading
3737
@@ -48,29 +48,32 @@ def load_model(
4848 return self ._load_huggingface_model (model_path_or_name , device , ** kwargs )
4949 elif model_type == "openai" :
5050 return self ._load_openai_model (model_path_or_name , ** kwargs )
51- elif model_type == "anthropic " :
52- return self ._load_anthropic_model (model_path_or_name , ** kwargs )
51+ elif model_type == "gemini " :
52+ return self ._load_gemini_model (model_path_or_name , ** kwargs )
5353 else :
54- raise ValueError (f"Unsupported model type: { model_type } " )
54+ raise ValueError (f"Unsupported model type: { model_type } . Supported: huggingface, openai, gemini " )
5555
5656 def _detect_model_type (self , model_path_or_name : str ) -> str :
5757 """Auto-detect model type based on path/name patterns."""
58- # Check for OpenAI model names
59- openai_models = [
60- "gpt-3.5-turbo" , "gpt-4" , "gpt-4-turbo" , "gpt-4o" ,
61- "text-davinci-003" , "text-curie-001" , "text-babbage-001"
58+ # Check for OpenAI model names (including newer models)
59+ openai_prefixes = [
60+ "gpt-3.5" , "gpt-4" , "gpt-4o" , "gpt-4-turbo" ,
61+ "o1-" , "o3-" , # Reasoning models
62+ "text-davinci" , "text-curie" , "text-babbage" , "text-ada" ,
6263 ]
6364
64- if any (model_path_or_name .startswith (name ) for name in openai_models ):
65+ model_lower = model_path_or_name .lower ()
66+ if any (model_lower .startswith (prefix ) for prefix in openai_prefixes ):
6567 return "openai"
66-
67- # Check for Anthropic model names
68- anthropic_models = [
69- "claude-3" , "claude-2" , "claude-instant"
68+
69+ # Check for Google Gemini model names
70+ gemini_prefixes = [
71+ "gemini-" ,
72+ "models/gemini-" ,
73+ "gemini-pro" , # Older naming
7074 ]
71-
72- if any (model_path_or_name .startswith (name ) for name in anthropic_models ):
73- return "anthropic"
75+ if any (model_lower .startswith (prefix ) for prefix in gemini_prefixes ):
76+ return "gemini"
7477
7578 # Check if it's a local path
7679 if os .path .exists (model_path_or_name ):
@@ -206,35 +209,56 @@ def _load_openai_model(
206209 """Load OpenAI model."""
207210 # Get API key from environment if not provided
208211 if api_key is None :
209- api_key = os .getenv ("OPENAI_API_KEY" )
212+ api_key = os .getenv ("OPENAI_API_KEY" ) or os . getenv ( "APIKEY_OPENAI" )
210213
211214 if api_key is None :
212215 raise ValueError ("OpenAI API key required. Set OPENAI_API_KEY environment variable." )
216+
217+ allowed_kwargs = {
218+ "batch_poll_interval_seconds" ,
219+ "batch_timeout_seconds" ,
220+ "batch_max_requests" ,
221+ "prefer_batch_api" ,
222+ }
223+ wrapper_kwargs = {key : value for key , value in kwargs .items () if key in allowed_kwargs }
213224
214225 return OpenAIWrapper (
215226 model_name = model_name ,
216227 api_key = api_key ,
217- ** kwargs
228+ ** wrapper_kwargs
218229 )
219-
220- def _load_anthropic_model (
230+
231+ def _load_gemini_model (
221232 self ,
222233 model_name : str ,
223234 api_key : Optional [str ] = None ,
224235 ** kwargs
225- ) -> AnthropicWrapper :
226- """Load Anthropic model."""
227- # Get API key from environment if not provided
236+ ) -> GeminiWrapper :
237+ """Load Gemini model."""
228238 if api_key is None :
229- api_key = os .getenv ("ANTHROPIC_API_KEY" )
230-
239+ api_key = (
240+ os .getenv ("GEMINI_API_KEY" )
241+ or os .getenv ("GOOGLE_API_KEY" )
242+ or os .getenv ("APIKEY_GOOGLE" )
243+ )
244+
231245 if api_key is None :
232- raise ValueError ("Anthropic API key required. Set ANTHROPIC_API_KEY environment variable." )
233-
234- return AnthropicWrapper (
246+ raise ValueError ("Gemini API key required. Set GEMINI_API_KEY or GOOGLE_API_KEY." )
247+
248+ allowed_kwargs = {
249+ "api_base" ,
250+ "batch_poll_interval_seconds" ,
251+ "batch_timeout_seconds" ,
252+ "batch_max_requests" ,
253+ "batch_max_payload_bytes" ,
254+ "prefer_batch_api" ,
255+ }
256+ wrapper_kwargs = {key : value for key , value in kwargs .items () if key in allowed_kwargs }
257+
258+ return GeminiWrapper (
235259 model_name = model_name ,
236260 api_key = api_key ,
237- ** kwargs
261+ ** wrapper_kwargs
238262 )
239263
240264 def _is_unsupported_model (self , model_name : str ) -> bool :
@@ -272,8 +296,8 @@ def list_available_models(self, model_type: str = "huggingface") -> Dict[str, An
272296 return self ._list_huggingface_models ()
273297 elif model_type == "openai" :
274298 return self ._list_openai_models ()
275- elif model_type == "anthropic " :
276- return self ._list_anthropic_models ()
299+ elif model_type == "gemini " :
300+ return self ._list_gemini_models ()
277301 else :
278302 return {}
279303
@@ -333,17 +357,15 @@ def _list_openai_models(self) -> Dict[str, Any]:
333357 "text-ada-001"
334358 ]
335359 }
336-
337- def _list_anthropic_models (self ) -> Dict [str , Any ]:
338- """List available Anthropic models."""
360+
361+ def _list_gemini_models (self ) -> Dict [str , Any ]:
362+ """List available Gemini models."""
339363 return {
340364 "chat_models" : [
341- "claude-3-opus-20240229" ,
342- "claude-3-sonnet-20240229" ,
343- "claude-3-haiku-20240307" ,
344- "claude-2.1" ,
345- "claude-2.0" ,
346- "claude-instant-1.2"
365+ "gemini-2.0-flash" ,
366+ "gemini-2.0-flash-lite" ,
367+ "gemini-1.5-pro" ,
368+ "gemini-1.5-flash" ,
347369 ]
348370 }
349371
@@ -367,7 +389,7 @@ def get_model_info(self, model_path_or_name: str) -> Dict[str, Any]:
367389
368390 if model_type == "huggingface" :
369391 info .update (self ._get_huggingface_info (model_path_or_name ))
370- elif model_type in ["openai" , "anthropic" ]:
392+ elif model_type in ["openai" , "gemini" , " anthropic" ]:
371393 info .update ({"requires_api_key" : True })
372394
373395 return info
@@ -404,7 +426,7 @@ def load_model(
404426
405427 Args:
406428 model_path_or_name: Model identifier
407- model_type: Model type ("auto", "huggingface", "openai", "anthropic")
429+ model_type: Model type ("auto", "huggingface", "openai", "gemini", " anthropic")
408430 device: Computation device
409431 config_dict: Model configuration dictionary
410432 **kwargs: Additional model loading arguments
0 commit comments