55import logging
66from typing import List
77
8- import requests
9-
108from ..models .enums import ModelProvider
119from ..models .provider_clients import ProviderClients
10+ from .providers .registry import ProviderRegistry
1211
1312
1413class ModelManager :
@@ -17,77 +16,17 @@ class ModelManager:
1716 @staticmethod
1817 def get_available_models (provider_clients : ProviderClients , provider : ModelProvider ) -> List [str ]:
1918 """Retrieve available models from a specific provider."""
20- models = []
19+ provider_impl = ProviderRegistry . get_provider ( provider )
2120
22- try :
23- if provider == ModelProvider .OPENAI :
24- if provider_clients .openai_client :
25- response = provider_clients .openai_client .models .list ()
26- models = [model .id for model in response .data ]
27- else :
28- logging .error ("OpenAI client not initialized" )
29-
30- elif provider == ModelProvider .ANTHROPIC :
31- if provider_clients .anthropic_client :
32- # Use Anthropic's models endpoint
33- try :
34- headers = {
35- "x-api-key" : provider_clients .anthropic_client .api_key ,
36- "anthropic-version" : "2023-06-01"
37- }
38- response = requests .get (
39- "https://api.anthropic.com/v1/models" ,
40- headers = headers ,
41- timeout = 15
42- )
43- response .raise_for_status ()
44- model_data = response .json ().get ("data" , [])
45- models = [model ["id" ] for model in model_data ]
46- except Exception as e :
47- logging .error ("Error fetching Anthropic models: %s" , str (e ))
48- # Fallback to commonly used models if API call fails
49- models = [
50- "claude-3-7-sonnet-latest" ,
51- "claude-3-5-haiku-latest" ,
52- "claude-3-5-sonnet-latest" ,
53- "claude-3-opus-20240229" ,
54- ]
55- else :
56- logging .error ("Anthropic client not initialized" )
57-
58- elif provider == ModelProvider .DEEPSEEK :
59- if provider_clients .deepseek_api_key :
60- headers = {
61- "Authorization" : f"Bearer { provider_clients .deepseek_api_key } " ,
62- "Content-Type" : "application/json"
63- }
64- response = requests .get (
65- f"{ provider_clients .deepseek_base_url } /models" ,
66- headers = headers ,
67- timeout = 15
68- )
69- response .raise_for_status ()
70- models = [model ["id" ] for model in response .json ().get ("data" , [])]
71- else :
72- logging .error ("DeepSeek API key not set" )
73-
74- elif provider == ModelProvider .AZURE_OPENAI :
75- return ModelManager ._get_azure_openai_models (provider_clients )
21+ if not provider_impl :
22+ logging .error ("Provider %s not registered" , provider .value )
23+ return []
7624
25+ try :
26+ return provider_impl .get_models (provider_clients )
7727 except Exception as e :
7828 logging .error ("Error fetching models from %s: %s" , provider .value , str (e ))
79-
80- return models
81-
82- @staticmethod
83- def _get_azure_openai_models (provider_clients : ProviderClients ) -> List [str ]:
84- """Retrieve models from Azure OpenAI."""
85- if provider_clients .azure_openai_client :
86- response = provider_clients .azure_openai_client .models .list ()
87- return [model .id for model in response .data ]
88-
89- logging .error ("Azure OpenAI client not initialized" )
90- return []
29+ return []
9130
9231 @staticmethod
9332 def validate_model (provider_clients : ProviderClients , provider : ModelProvider , model : str ) -> bool :
@@ -119,13 +58,13 @@ def get_default_model(provider: ModelProvider) -> str:
11958 Returns:
12059 str: The default model ID
12160 """
122- default_models = {
123- ModelProvider . OPENAI : "gpt-4o-mini" ,
124- ModelProvider . ANTHROPIC : "claude-3-5-haiku-latest" ,
125- ModelProvider . DEEPSEEK : "deepseek-chat" ,
126- ModelProvider . AZURE_OPENAI : "gpt-35-turbo" ,
127- }
128- return default_models . get ( provider , "" )
61+ provider_impl = ProviderRegistry . get_provider ( provider )
62+
63+ if not provider_impl :
64+ logging . warning ( "Provider %s not registered, returning empty default" , provider . value )
65+ return ""
66+
67+ return provider_impl . get_default_model ( )
12968
13069 @staticmethod
13170 def verify_model_capabilities (
@@ -170,25 +109,23 @@ def suggest_model(provider_clients: ProviderClients, provider: ModelProvider,
170109 Returns:
171110 str: The suggested model ID
172111 """
173- # For translation tasks, prefer more capable models when available
174- if task == "translation" :
175- preferred_models = {
176- ModelProvider .OPENAI : ["gpt-4" , "gpt-4o" , "gpt-3.5-turbo" ],
177- ModelProvider .ANTHROPIC : ["claude-3-opus" , "claude-3-5-sonnet" , "claude-3-5-haiku" ],
178- ModelProvider .DEEPSEEK : ["deepseek-chat" ]
179- }
180-
181- available_models = ModelManager .get_available_models (provider_clients , provider )
182-
183- # Try to find a match from the preferred models list
184- for preferred in preferred_models .get (provider , []):
185- for available in available_models :
186- if preferred in available .lower ():
187- return available
188-
189- # Fall back to the first available model or the default
190- if available_models :
191- return available_models [0 ]
192-
193- # Default to the standard default model
194- return ModelManager .get_default_model (provider )
112+ provider_impl = ProviderRegistry .get_provider (provider )
113+
114+ if not provider_impl :
115+ logging .warning ("Provider %s not registered" , provider .value )
116+ return ""
117+
118+ preferred_models = provider_impl .get_preferred_models (task )
119+ available_models = ModelManager .get_available_models (provider_clients , provider )
120+
121+ # Try to find a match from the preferred models list
122+ for preferred in preferred_models :
123+ for available in available_models :
124+ if preferred in available .lower ():
125+ return available
126+
127+ # Fall back to the first available model or the default
128+ if available_models :
129+ return available_models [0 ]
130+
131+ return provider_impl .get_default_model ()
0 commit comments