diff --git a/python_gpt_po/models/enums.py b/python_gpt_po/models/enums.py index 58803d4..103dde3 100644 --- a/python_gpt_po/models/enums.py +++ b/python_gpt_po/models/enums.py @@ -10,3 +10,10 @@ class ModelProvider(Enum): OPENAI = "openai" ANTHROPIC = "anthropic" DEEPSEEK = "deepseek" + + +ModelProviderList = [ + ModelProvider.OPENAI.value, + ModelProvider.ANTHROPIC.value, + ModelProvider.DEEPSEEK.value +] diff --git a/python_gpt_po/models/provider_clients.py b/python_gpt_po/models/provider_clients.py index 11e9cfe..f451535 100644 --- a/python_gpt_po/models/provider_clients.py +++ b/python_gpt_po/models/provider_clients.py @@ -7,6 +7,8 @@ from anthropic import Anthropic from openai import OpenAI +from .enums import ModelProvider + class ProviderClients: """Class to store API clients for various providers.""" @@ -23,11 +25,11 @@ def initialize_clients(self, api_keys: Dict[str, str]): Args: api_keys (Dict[str, str]): Dictionary of provider names to API keys """ - if api_keys.get("openai"): - self.openai_client = OpenAI(api_key=api_keys["openai"]) + if api_keys.get(ModelProvider.OPENAI.value): + self.openai_client = OpenAI(api_key=api_keys[ModelProvider.OPENAI.value]) - if api_keys.get("anthropic"): - self.anthropic_client = Anthropic(api_key=api_keys["anthropic"]) + if api_keys.get(ModelProvider.ANTHROPIC.value): + self.anthropic_client = Anthropic(api_key=api_keys[ModelProvider.ANTHROPIC.value]) - if api_keys.get("deepseek"): - self.deepseek_api_key = api_keys["deepseek"] + if api_keys.get(ModelProvider.DEEPSEEK.value): + self.deepseek_api_key = api_keys[ModelProvider.DEEPSEEK.value] diff --git a/python_gpt_po/utils/cli.py b/python_gpt_po/utils/cli.py index e105afd..e0da78f 100644 --- a/python_gpt_po/utils/cli.py +++ b/python_gpt_po/utils/cli.py @@ -9,7 +9,7 @@ import sys from typing import Dict, List, Optional -from ..models.enums import ModelProvider +from ..models.enums import ModelProvider, ModelProviderList from .helpers import get_version @@ -99,7 +99,7 @@ def parse_args(): # Provider settings provider_group.add_argument( "--provider", - choices=["openai", "anthropic", "deepseek"], + choices=ModelProviderList, help="AI provider to use (default: first provider with available API key)" ) provider_group.add_argument( @@ -248,9 +248,9 @@ def get_api_keys_from_args(args) -> Dict[str, str]: Dict[str, str]: Dictionary of provider names to API keys """ return { - "openai": args.openai_key or args.api_key or os.getenv("OPENAI_API_KEY", ""), - "anthropic": args.anthropic_key or os.getenv("ANTHROPIC_API_KEY", ""), - "deepseek": args.deepseek_key or os.getenv("DEEPSEEK_API_KEY", "") + ModelProvider.OPENAI.value: args.openai_key or args.api_key or os.getenv("OPENAI_API_KEY", ""), + ModelProvider.ANTHROPIC.value: args.anthropic_key or os.getenv("ANTHROPIC_API_KEY", ""), + ModelProvider.DEEPSEEK.value: args.deepseek_key or os.getenv("DEEPSEEK_API_KEY", "") } @@ -264,7 +264,7 @@ def auto_select_provider(api_keys: Dict[str, str]) -> Optional[ModelProvider]: Returns: Optional[ModelProvider]: The auto-selected provider or None if no keys available """ - for provider_name in ["openai", "anthropic", "deepseek"]: + for provider_name in ModelProviderList: if api_keys.get(provider_name): provider = ModelProvider(provider_name) logging.info("Auto-selected provider: %s (based on available API key)", provider_name)