Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python_gpt_po/models/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ class ModelProvider(Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
DEEPSEEK = "deepseek"


ModelProviderList = [
ModelProvider.OPENAI.value,
ModelProvider.ANTHROPIC.value,
ModelProvider.DEEPSEEK.value
]
14 changes: 8 additions & 6 deletions python_gpt_po/models/provider_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from anthropic import Anthropic
from openai import OpenAI

from .enums import ModelProvider


class ProviderClients:
"""Class to store API clients for various providers."""
Expand All @@ -23,11 +25,11 @@ def initialize_clients(self, api_keys: Dict[str, str]):
Args:
api_keys (Dict[str, str]): Dictionary of provider names to API keys
"""
if api_keys.get("openai"):
self.openai_client = OpenAI(api_key=api_keys["openai"])
if api_keys.get(ModelProvider.OPENAI.value):
self.openai_client = OpenAI(api_key=api_keys[ModelProvider.OPENAI.value])

if api_keys.get("anthropic"):
self.anthropic_client = Anthropic(api_key=api_keys["anthropic"])
if api_keys.get(ModelProvider.ANTHROPIC.value):
self.anthropic_client = Anthropic(api_key=api_keys[ModelProvider.ANTHROPIC.value])

if api_keys.get("deepseek"):
self.deepseek_api_key = api_keys["deepseek"]
if api_keys.get(ModelProvider.DEEPSEEK.value):
self.deepseek_api_key = api_keys[ModelProvider.DEEPSEEK.value]
12 changes: 6 additions & 6 deletions python_gpt_po/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
from typing import Dict, List, Optional

from ..models.enums import ModelProvider
from ..models.enums import ModelProvider, ModelProviderList
from .helpers import get_version


Expand Down Expand Up @@ -99,7 +99,7 @@ def parse_args():
# Provider settings
provider_group.add_argument(
"--provider",
choices=["openai", "anthropic", "deepseek"],
choices=ModelProviderList,
help="AI provider to use (default: first provider with available API key)"
)
provider_group.add_argument(
Expand Down Expand Up @@ -248,9 +248,9 @@ def get_api_keys_from_args(args) -> Dict[str, str]:
Dict[str, str]: Dictionary of provider names to API keys
"""
return {
"openai": args.openai_key or args.api_key or os.getenv("OPENAI_API_KEY", ""),
"anthropic": args.anthropic_key or os.getenv("ANTHROPIC_API_KEY", ""),
"deepseek": args.deepseek_key or os.getenv("DEEPSEEK_API_KEY", "")
ModelProvider.OPENAI.value: args.openai_key or args.api_key or os.getenv("OPENAI_API_KEY", ""),
ModelProvider.ANTHROPIC.value: args.anthropic_key or os.getenv("ANTHROPIC_API_KEY", ""),
ModelProvider.DEEPSEEK.value: args.deepseek_key or os.getenv("DEEPSEEK_API_KEY", "")
}


Expand All @@ -264,7 +264,7 @@ def auto_select_provider(api_keys: Dict[str, str]) -> Optional[ModelProvider]:
Returns:
Optional[ModelProvider]: The auto-selected provider or None if no keys available
"""
for provider_name in ["openai", "anthropic", "deepseek"]:
for provider_name in ModelProviderList:
if api_keys.get(provider_name):
provider = ModelProvider(provider_name)
logging.info("Auto-selected provider: %s (based on available API key)", provider_name)
Expand Down