Skip to content

Commit 60c0bbf

Browse files
Merge pull request #48 from pescheckit/feature_splitting-model-in-seperate-classes
Splitting model in seperate classes
2 parents 8d8d189 + 4d58e73 commit 60c0bbf

File tree

11 files changed

+469
-114
lines changed

11 files changed

+469
-114
lines changed

python_gpt_po/services/model_manager.py

Lines changed: 35 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
import logging
66
from typing import List
77

8-
import requests
9-
108
from ..models.enums import ModelProvider
119
from ..models.provider_clients import ProviderClients
10+
from .providers.registry import ProviderRegistry
1211

1312

1413
class ModelManager:
@@ -17,77 +16,17 @@ class ModelManager:
1716
@staticmethod
1817
def get_available_models(provider_clients: ProviderClients, provider: ModelProvider) -> List[str]:
1918
"""Retrieve available models from a specific provider."""
20-
models = []
19+
provider_impl = ProviderRegistry.get_provider(provider)
2120

22-
try:
23-
if provider == ModelProvider.OPENAI:
24-
if provider_clients.openai_client:
25-
response = provider_clients.openai_client.models.list()
26-
models = [model.id for model in response.data]
27-
else:
28-
logging.error("OpenAI client not initialized")
29-
30-
elif provider == ModelProvider.ANTHROPIC:
31-
if provider_clients.anthropic_client:
32-
# Use Anthropic's models endpoint
33-
try:
34-
headers = {
35-
"x-api-key": provider_clients.anthropic_client.api_key,
36-
"anthropic-version": "2023-06-01"
37-
}
38-
response = requests.get(
39-
"https://api.anthropic.com/v1/models",
40-
headers=headers,
41-
timeout=15
42-
)
43-
response.raise_for_status()
44-
model_data = response.json().get("data", [])
45-
models = [model["id"] for model in model_data]
46-
except Exception as e:
47-
logging.error("Error fetching Anthropic models: %s", str(e))
48-
# Fallback to commonly used models if API call fails
49-
models = [
50-
"claude-3-7-sonnet-latest",
51-
"claude-3-5-haiku-latest",
52-
"claude-3-5-sonnet-latest",
53-
"claude-3-opus-20240229",
54-
]
55-
else:
56-
logging.error("Anthropic client not initialized")
57-
58-
elif provider == ModelProvider.DEEPSEEK:
59-
if provider_clients.deepseek_api_key:
60-
headers = {
61-
"Authorization": f"Bearer {provider_clients.deepseek_api_key}",
62-
"Content-Type": "application/json"
63-
}
64-
response = requests.get(
65-
f"{provider_clients.deepseek_base_url}/models",
66-
headers=headers,
67-
timeout=15
68-
)
69-
response.raise_for_status()
70-
models = [model["id"] for model in response.json().get("data", [])]
71-
else:
72-
logging.error("DeepSeek API key not set")
73-
74-
elif provider == ModelProvider.AZURE_OPENAI:
75-
return ModelManager._get_azure_openai_models(provider_clients)
21+
if not provider_impl:
22+
logging.error("Provider %s not registered", provider.value)
23+
return []
7624

25+
try:
26+
return provider_impl.get_models(provider_clients)
7727
except Exception as e:
7828
logging.error("Error fetching models from %s: %s", provider.value, str(e))
79-
80-
return models
81-
82-
@staticmethod
83-
def _get_azure_openai_models(provider_clients: ProviderClients) -> List[str]:
84-
"""Retrieve models from Azure OpenAI."""
85-
if provider_clients.azure_openai_client:
86-
response = provider_clients.azure_openai_client.models.list()
87-
return [model.id for model in response.data]
88-
89-
logging.error("Azure OpenAI client not initialized")
90-
return []
29+
return []
9130

9231
@staticmethod
9332
def validate_model(provider_clients: ProviderClients, provider: ModelProvider, model: str) -> bool:
@@ -119,13 +58,13 @@ def get_default_model(provider: ModelProvider) -> str:
11958
Returns:
12059
str: The default model ID
12160
"""
122-
default_models = {
123-
ModelProvider.OPENAI: "gpt-4o-mini",
124-
ModelProvider.ANTHROPIC: "claude-3-5-haiku-latest",
125-
ModelProvider.DEEPSEEK: "deepseek-chat",
126-
ModelProvider.AZURE_OPENAI: "gpt-35-turbo",
127-
}
128-
return default_models.get(provider, "")
61+
provider_impl = ProviderRegistry.get_provider(provider)
62+
63+
if not provider_impl:
64+
logging.warning("Provider %s not registered, returning empty default", provider.value)
65+
return ""
66+
67+
return provider_impl.get_default_model()
12968

13069
@staticmethod
13170
def verify_model_capabilities(
@@ -170,25 +109,23 @@ def suggest_model(provider_clients: ProviderClients, provider: ModelProvider,
170109
Returns:
171110
str: The suggested model ID
172111
"""
173-
# For translation tasks, prefer more capable models when available
174-
if task == "translation":
175-
preferred_models = {
176-
ModelProvider.OPENAI: ["gpt-4", "gpt-4o", "gpt-3.5-turbo"],
177-
ModelProvider.ANTHROPIC: ["claude-3-opus", "claude-3-5-sonnet", "claude-3-5-haiku"],
178-
ModelProvider.DEEPSEEK: ["deepseek-chat"]
179-
}
180-
181-
available_models = ModelManager.get_available_models(provider_clients, provider)
182-
183-
# Try to find a match from the preferred models list
184-
for preferred in preferred_models.get(provider, []):
185-
for available in available_models:
186-
if preferred in available.lower():
187-
return available
188-
189-
# Fall back to the first available model or the default
190-
if available_models:
191-
return available_models[0]
192-
193-
# Default to the standard default model
194-
return ModelManager.get_default_model(provider)
112+
provider_impl = ProviderRegistry.get_provider(provider)
113+
114+
if not provider_impl:
115+
logging.warning("Provider %s not registered", provider.value)
116+
return ""
117+
118+
preferred_models = provider_impl.get_preferred_models(task)
119+
available_models = ModelManager.get_available_models(provider_clients, provider)
120+
121+
# Try to find a match from the preferred models list
122+
for preferred in preferred_models:
123+
for available in available_models:
124+
if preferred in available.lower():
125+
return available
126+
127+
# Fall back to the first available model or the default
128+
if available_models:
129+
return available_models[0]
130+
131+
return provider_impl.get_default_model()
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""
2+
Provider implementations for model management.
3+
"""
4+
from . import provider_init # noqa: F401 - Auto-registers providers
5+
from .base import ModelProviderInterface
6+
from .registry import ProviderRegistry
7+
8+
__all__ = ["ModelProviderInterface", "ProviderRegistry"]
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Anthropic provider implementation.
3+
"""
4+
import logging
5+
from typing import List
6+
7+
import requests
8+
9+
from ...models.provider_clients import ProviderClients
10+
from .base import ModelProviderInterface
11+
12+
13+
class AnthropicProvider(ModelProviderInterface):
14+
"""Anthropic model provider implementation."""
15+
16+
def get_models(self, provider_clients: ProviderClients) -> List[str]:
17+
"""Retrieve available models from Anthropic."""
18+
models = []
19+
20+
if not self.is_client_initialized(provider_clients):
21+
logging.error("Anthropic client not initialized")
22+
return models
23+
24+
try:
25+
headers = {
26+
"x-api-key": provider_clients.anthropic_client.api_key,
27+
"anthropic-version": "2023-06-01"
28+
}
29+
response = requests.get(
30+
"https://api.anthropic.com/v1/models",
31+
headers=headers,
32+
timeout=15
33+
)
34+
response.raise_for_status()
35+
model_data = response.json().get("data", [])
36+
models = [model["id"] for model in model_data]
37+
except Exception as e:
38+
logging.error("Error fetching Anthropic models: %s", str(e))
39+
models = self.get_fallback_models()
40+
41+
return models
42+
43+
def get_default_model(self) -> str:
44+
"""Get the default Anthropic model."""
45+
return "claude-3-5-haiku-latest"
46+
47+
def get_preferred_models(self, task: str = "translation") -> List[str]:
48+
"""Get preferred Anthropic models for a task."""
49+
if task == "translation":
50+
return ["claude-3-opus", "claude-3-5-sonnet", "claude-3-5-haiku"]
51+
return ["claude-3-5-haiku-latest"]
52+
53+
def is_client_initialized(self, provider_clients: ProviderClients) -> bool:
54+
"""Check if Anthropic client is initialized."""
55+
return provider_clients.anthropic_client is not None
56+
57+
def get_fallback_models(self) -> List[str]:
58+
"""Get fallback models for Anthropic."""
59+
return [
60+
"claude-3-7-sonnet-latest",
61+
"claude-3-5-haiku-latest",
62+
"claude-3-5-sonnet-latest",
63+
"claude-3-opus-20240229",
64+
]
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
Azure OpenAI provider implementation.
3+
"""
4+
import logging
5+
from typing import List
6+
7+
from ...models.provider_clients import ProviderClients
8+
from .base import ModelProviderInterface
9+
10+
11+
class AzureOpenAIProvider(ModelProviderInterface):
12+
"""Azure OpenAI model provider implementation."""
13+
14+
def get_models(self, provider_clients: ProviderClients) -> List[str]:
15+
"""Retrieve available models from Azure OpenAI."""
16+
models = []
17+
18+
if not self.is_client_initialized(provider_clients):
19+
logging.error("Azure OpenAI client not initialized")
20+
return models
21+
22+
try:
23+
response = provider_clients.azure_openai_client.models.list()
24+
models = [model.id for model in response.data]
25+
except Exception as e:
26+
logging.error("Error fetching Azure OpenAI models: %s", str(e))
27+
models = self.get_fallback_models()
28+
29+
return models
30+
31+
def get_default_model(self) -> str:
32+
"""Get the default Azure OpenAI model."""
33+
return "gpt-35-turbo"
34+
35+
def get_preferred_models(self, task: str = "translation") -> List[str]:
36+
"""Get preferred Azure OpenAI models for a task."""
37+
if task == "translation":
38+
return ["gpt-4", "gpt-35-turbo"]
39+
return ["gpt-35-turbo"]
40+
41+
def is_client_initialized(self, provider_clients: ProviderClients) -> bool:
42+
"""Check if Azure OpenAI client is initialized."""
43+
return provider_clients.azure_openai_client is not None
44+
45+
def get_fallback_models(self) -> List[str]:
46+
"""Get fallback models for Azure OpenAI."""
47+
return ["gpt-35-turbo", "gpt-4"]
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""
2+
Base interface for model providers.
3+
"""
4+
from abc import ABC, abstractmethod
5+
from typing import List
6+
7+
from ...models.provider_clients import ProviderClients
8+
9+
10+
class ModelProviderInterface(ABC):
11+
"""Abstract base class for model providers."""
12+
13+
@abstractmethod
14+
def get_models(self, provider_clients: ProviderClients) -> List[str]:
15+
"""Retrieve available models from the provider.
16+
17+
Args:
18+
provider_clients: Initialized provider clients
19+
20+
Returns:
21+
List of available model IDs
22+
"""
23+
24+
@abstractmethod
25+
def get_default_model(self) -> str:
26+
"""Get the default model for this provider.
27+
28+
Returns:
29+
Default model ID
30+
"""
31+
32+
@abstractmethod
33+
def get_preferred_models(self, task: str = "translation") -> List[str]:
34+
"""Get preferred models for a specific task.
35+
36+
Args:
37+
task: The task type (default: "translation")
38+
39+
Returns:
40+
List of preferred model IDs in order of preference
41+
"""
42+
43+
@abstractmethod
44+
def is_client_initialized(self, provider_clients: ProviderClients) -> bool:
45+
"""Check if the provider client is initialized.
46+
47+
Args:
48+
provider_clients: Provider clients instance
49+
50+
Returns:
51+
True if client is initialized, False otherwise
52+
"""
53+
54+
def get_fallback_models(self) -> List[str]:
55+
"""Get fallback models when API calls fail.
56+
57+
Returns:
58+
List of fallback model IDs
59+
"""
60+
return []

0 commit comments

Comments
 (0)