44"""
55import logging
66from typing import List
7-
87from ..models .enums import ModelProvider
98from ..models .provider_clients import ProviderClients
109from .providers .registry import ProviderRegistry
1110
1211
1312class ModelManager :
1413 """Class to manage models from different providers."""
15-
1614 @staticmethod
1715 def get_available_models (provider_clients : ProviderClients , provider : ModelProvider ) -> List [str ]:
1816 """Retrieve available models from a specific provider."""
1917 provider_impl = ProviderRegistry .get_provider (provider )
20-
2118 if not provider_impl :
2219 logging .error ("Provider %s not registered" , provider .value )
2320 return []
24-
2521 try :
2622 return provider_impl .get_models (provider_clients )
2723 except Exception as e :
@@ -33,12 +29,10 @@ def validate_model(provider_clients: ProviderClients, provider: ModelProvider, m
3329 """
3430 Validates whether the specified model is available for the given provider.
3531 Uses prefix matching so that a shorthand (e.g. "claude") will match a full model name.
36-
3732 Args:
3833 provider_clients (ProviderClients): The initialized provider clients
3934 provider (ModelProvider): The provider to check against
4035 model (str): The model name/ID to validate
41-
4236 Returns:
4337 bool: True if the model is valid, False otherwise
4438 """
@@ -51,19 +45,15 @@ def validate_model(provider_clients: ProviderClients, provider: ModelProvider, m
5145 def get_default_model (provider : ModelProvider ) -> str :
5246 """
5347 Returns the default model for a given provider.
54-
5548 Args:
5649 provider (ModelProvider): The provider to get the default model for
57-
5850 Returns:
5951 str: The default model ID
6052 """
6153 provider_impl = ProviderRegistry .get_provider (provider )
62-
6354 if not provider_impl :
6455 logging .warning ("Provider %s not registered, returning empty default" , provider .value )
6556 return ""
66-
6757 return provider_impl .get_default_model ()
6858
6959 @staticmethod
@@ -75,57 +65,18 @@ def verify_model_capabilities(
7565 ) -> bool :
7666 """
7767 Verifies if a model has the required capabilities.
78-
7968 Args:
8069 provider_clients (ProviderClients): The initialized provider clients
8170 provider (ModelProvider): The provider to check against
8271 model (str): The model to verify
8372 required_capability (str): The capability to check for
84-
8573 Returns:
8674 bool: True if the model has the required capability, False otherwise
8775 """
8876 # This is a simplified implementation - in a real-world scenario,
8977 # you might want to check model specifications/documentation
90-
9178 # For now, assume all models support translation
9279 # In the future, this could check model specs for specific capabilities
9380 if required_capability == "translation" :
9481 return ModelManager .validate_model (provider_clients , provider , model )
95-
9682 return False
97-
98- @staticmethod
99- def suggest_model (provider_clients : ProviderClients , provider : ModelProvider ,
100- task : str = "translation" ) -> str :
101- """
102- Suggests the best model for a given task and provider.
103-
104- Args:
105- provider_clients (ProviderClients): The initialized provider clients
106- provider (ModelProvider): The provider to use
107- task (str): The task the model will be used for
108-
109- Returns:
110- str: The suggested model ID
111- """
112- provider_impl = ProviderRegistry .get_provider (provider )
113-
114- if not provider_impl :
115- logging .warning ("Provider %s not registered" , provider .value )
116- return ""
117-
118- preferred_models = provider_impl .get_preferred_models (task )
119- available_models = ModelManager .get_available_models (provider_clients , provider )
120-
121- # Try to find a match from the preferred models list
122- for preferred in preferred_models :
123- for available in available_models :
124- if preferred in available .lower ():
125- return available
126-
127- # Fall back to the first available model or the default
128- if available_models :
129- return available_models [0 ]
130-
131- return provider_impl .get_default_model ()
0 commit comments