Skip to content

Commit 6eca9b1

Browse files
author
bram
committed
Made more generic
1 parent 3756f48 commit 6eca9b1

File tree

3 files changed

+78
-75
lines changed

3 files changed

+78
-75
lines changed

python_gpt_po/models/provider_clients.py

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,45 @@ def __init__(self):
2020
self.azure_openai_client = None
2121
self.anthropic_client = None
2222
self.deepseek_api_key = None
23-
self.deepseek_base_url = "https://api.deepseek.com/v1"
24-
self.ollama_base_url = "http://localhost:11434"
25-
self.ollama_timeout = 120
23+
self.deepseek_base_url = None
24+
self.ollama_base_url = None
25+
self.ollama_timeout = None
26+
27+
def _get_setting(self, args: Namespace, arg_name: str, env_var: str = None,
28+
config_provider: str = None, config_key: str = None, default: any = None) -> any:
29+
"""Generic helper to get a setting from CLI args, env var, config, or default.
30+
31+
Args:
32+
args: Command line arguments
33+
arg_name: Attribute name in args (e.g., 'ollama_base_url')
34+
env_var: Environment variable name (optional)
35+
config_provider: Provider name for config lookup (optional)
36+
config_key: Config key name (optional)
37+
default: Default value
38+
39+
Returns:
40+
Setting value from first available source
41+
"""
42+
# Priority 1: CLI argument
43+
if hasattr(args, arg_name) and getattr(args, arg_name):
44+
return getattr(args, arg_name)
45+
46+
# Priority 2: Environment variable
47+
if env_var:
48+
env_value = os.getenv(env_var)
49+
if env_value:
50+
return env_value
51+
52+
# Priority 3: Config file
53+
if config_provider and config_key:
54+
from ..utils.config_loader import ConfigLoader
55+
folder_path = args.folder if hasattr(args, 'folder') else None
56+
config_value = ConfigLoader.get_provider_setting(config_provider, config_key, None, folder_path)
57+
if config_value is not None:
58+
return config_value
59+
60+
# Priority 4: Default
61+
return default
2662

2763
def initialize_clients(self, args: Namespace) -> Dict[str, str]:
2864
"""Initialize API clients for all providers with available keys.
@@ -32,17 +68,29 @@ def initialize_clients(self, args: Namespace) -> Dict[str, str]:
3268
Returns:
3369
Dict[str, str]: Dictionary of API keys for each provider
3470
"""
35-
openai_key = args.openai_key or args.api_key or os.getenv("OPENAI_API_KEY", "")
71+
# OpenAI
72+
openai_key = self._get_setting(args, 'openai_key', 'OPENAI_API_KEY', 'openai', 'api_key', '')
73+
if not openai_key and hasattr(args, 'api_key'):
74+
openai_key = args.api_key
3675
if openai_key:
3776
self.openai_client = OpenAI(api_key=openai_key)
3877

39-
azure_openai_key = args.azure_openai_key or os.getenv("AZURE_OPENAI_API_KEY", "")
78+
# Azure OpenAI
79+
azure_openai_key = self._get_setting(
80+
args, 'azure_openai_key', 'AZURE_OPENAI_API_KEY', 'azure_openai', 'api_key', ''
81+
)
4082
if azure_openai_key:
41-
endpoint = args.azure_openai_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
83+
endpoint = self._get_setting(
84+
args, 'azure_openai_endpoint', 'AZURE_OPENAI_ENDPOINT',
85+
'azure_openai', 'endpoint', None
86+
)
4287
if not endpoint:
4388
raise ValueError("Missing Azure OpenAI endpoint.")
4489

45-
api_version = args.azure_openai_api_version or os.getenv("AZURE_OPENAI_API_VERSION")
90+
api_version = self._get_setting(
91+
args, 'azure_openai_api_version', 'AZURE_OPENAI_API_VERSION',
92+
'azure_openai', 'api_version', None
93+
)
4694
if not api_version:
4795
raise ValueError("Missing Azure OpenAI API version.")
4896

@@ -52,19 +100,34 @@ def initialize_clients(self, args: Namespace) -> Dict[str, str]:
52100
api_version=api_version
53101
)
54102

55-
antropic_key = args.anthropic_key or os.getenv("ANTHROPIC_API_KEY", "")
103+
# Anthropic
104+
antropic_key = self._get_setting(
105+
args, 'anthropic_key', 'ANTHROPIC_API_KEY', 'anthropic', 'api_key', ''
106+
)
56107
if antropic_key:
57108
self.anthropic_client = Anthropic(api_key=antropic_key)
58109

59-
deepseek_key = args.deepseek_key or os.getenv("DEEPSEEK_API_KEY", "")
110+
# DeepSeek
111+
deepseek_key = self._get_setting(
112+
args, 'deepseek_key', 'DEEPSEEK_API_KEY', 'deepseek', 'api_key', ''
113+
)
60114
if deepseek_key:
61115
self.deepseek_api_key = deepseek_key
62116

63-
# Ollama configuration (no API key needed)
64-
if hasattr(args, 'ollama_base_url') and args.ollama_base_url:
65-
self.ollama_base_url = args.ollama_base_url
66-
if hasattr(args, 'ollama_timeout') and args.ollama_timeout:
67-
self.ollama_timeout = args.ollama_timeout
117+
self.deepseek_base_url = self._get_setting(
118+
args, 'deepseek_base_url', 'DEEPSEEK_BASE_URL',
119+
'deepseek', 'base_url', 'https://api.deepseek.com/v1'
120+
)
121+
122+
# Ollama
123+
self.ollama_base_url = self._get_setting(
124+
args, 'ollama_base_url', 'OLLAMA_BASE_URL',
125+
'ollama', 'base_url', 'http://localhost:11434'
126+
)
127+
self.ollama_timeout = self._get_setting(
128+
args, 'ollama_timeout', None,
129+
'ollama', 'timeout', 120
130+
)
68131

69132
return {
70133
ModelProvider.OPENAI.value: openai_key,

python_gpt_po/utils/cli.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -221,17 +221,7 @@ def parse_args() -> Namespace:
221221
parser.print_help()
222222
sys.exit(0)
223223

224-
args = parser.parse_args()
225-
226-
# Apply config file defaults for Ollama if not provided via CLI
227-
if not args.ollama_base_url:
228-
from .config_loader import ConfigLoader
229-
args.ollama_base_url = ConfigLoader.get_ollama_base_url(args.folder if hasattr(args, 'folder') else None)
230-
if not args.ollama_timeout:
231-
from .config_loader import ConfigLoader
232-
args.ollama_timeout = ConfigLoader.get_ollama_timeout(args.folder if hasattr(args, 'folder') else None)
233-
234-
return args
224+
return parser.parse_args()
235225

236226

237227
def show_help_and_exit():

python_gpt_po/utils/config_loader.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -214,56 +214,6 @@ def get_default_model(cls, provider: str, start_path: Optional[str] = None) -> O
214214
default_models = config.get('default_models', {})
215215
return default_models.get(provider)
216216

217-
@classmethod
218-
def get_ollama_base_url(cls, start_path: Optional[str] = None) -> str:
219-
"""
220-
Get Ollama base URL from config, env var, or default.
221-
222-
Priority:
223-
1. Environment variable OLLAMA_BASE_URL
224-
2. Config file provider.ollama.base_url
225-
3. Default: http://localhost:11434
226-
227-
Args:
228-
start_path: Directory to start searching from
229-
230-
Returns:
231-
Ollama base URL
232-
"""
233-
# Priority 1: Environment variable
234-
env_url = os.getenv('OLLAMA_BASE_URL')
235-
if env_url:
236-
return env_url
237-
238-
# Priority 2: Config file, Priority 3: Default
239-
return cls.get_provider_setting('ollama', 'base_url', 'http://localhost:11434', start_path)
240-
241-
@classmethod
242-
def get_ollama_model(cls, start_path: Optional[str] = None) -> str:
243-
"""
244-
Get default Ollama model from config or default.
245-
246-
Args:
247-
start_path: Directory to start searching from
248-
249-
Returns:
250-
Ollama model name
251-
"""
252-
return cls.get_provider_setting('ollama', 'model', 'llama3.2', start_path)
253-
254-
@classmethod
255-
def get_ollama_timeout(cls, start_path: Optional[str] = None) -> int:
256-
"""
257-
Get Ollama request timeout from config or default.
258-
259-
Args:
260-
start_path: Directory to start searching from
261-
262-
Returns:
263-
Timeout in seconds
264-
"""
265-
return cls.get_provider_setting('ollama', 'timeout', 120, start_path)
266-
267217
@classmethod
268218
def get_provider_setting(cls, provider: str, setting: str, default: Any = None,
269219
start_path: Optional[str] = None) -> Any:

0 commit comments

Comments
 (0)