
Commit a8ffd40

Authored by Copilot and mawad-amd
Add early LLM credential validation to prevent wasted work (#150)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
1 parent 6bd1939 commit a8ffd40

File tree

3 files changed: 161 additions & 1 deletion


src/intelliperf/__main__.py

Lines changed: 8 additions & 0 deletions
@@ -213,6 +213,14 @@ def main():
     else:
         logging.basicConfig(level=logging.WARNING, format="[INTELLIPERF] %(levelname)s: %(message)s")
 
+    # Validate LLM credentials early before doing any work (only for non-diagnoseOnly formulas)
+    if args.formula != "diagnoseOnly":
+        from intelliperf.core.llm import validate_llm_credentials
+        from intelliperf.utils.env import get_llm_api_key
+
+        llm_key = get_llm_api_key()
+        validate_llm_credentials(api_key=llm_key, model=args.model, provider=args.provider)
+
     # Create an optimizer based on available IntelliPerf formulas.
     if args.formula == "bankConflict":
         from intelliperf.formulas.bank_conflict import bank_conflict
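
The call to get_llm_api_key() comes from intelliperf.utils.env, which this commit does not touch. For orientation, a minimal sketch of what such a helper might look like, assuming it resolves the key from an environment variable; the variable name LLM_GATEWAY_KEY below is hypothetical, not taken from the repository:

    import os
    import sys


    def get_llm_api_key() -> str:
        """Return the LLM API key from the environment, failing fast if it is missing."""
        # Hypothetical variable name; the real helper may consult a different one.
        key = os.environ.get("LLM_GATEWAY_KEY", "").strip()
        if not key:
            sys.exit("No LLM API key found; set LLM_GATEWAY_KEY before running intelliperf.")
        return key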

src/intelliperf/core/llm.py

Lines changed: 77 additions & 1 deletion
@@ -23,16 +23,92 @@
 ################################################################################
 
 
+import logging
 import sys
 from typing import Optional
 
 import dspy
 import requests
-
 from intelliperf.core.logger import Logger
 
 
+def validate_llm_credentials(api_key: str, model: str, provider: str) -> bool:
+    """Validate LLM credentials before doing any work.
+
+    Args:
+        api_key: The API key to validate
+        model: The model name to validate
+        provider: The provider URL to validate
+
+    Returns:
+        bool: True if credentials are valid, False otherwise.
+
+    Raises:
+        SystemExit: If credentials are invalid, exits with error message.
+    """
+    # Create a temporary LLM instance to validate credentials
+    temp_llm = LLM(api_key=api_key, system_prompt="test", model=model, provider=provider)
+    return temp_llm._validate_credentials()
+
+
 class LLM:
+    def _validate_credentials(self) -> bool:
+        """Validate that the LLM credentials, model, and provider are correct.
+
+        Returns:
+            bool: True if credentials are valid, False otherwise.
+
+        Raises:
+            SystemExit: If credentials are invalid, exits with error message.
+        """
+        try:
+            if self.use_amd:
+                # Test AMD/Azure credentials with a minimal request
+                body = {
+                    "messages": [
+                        {"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": "Hi"},
+                    ],
+                    "max_Tokens": 10,
+                    "max_Completion_Tokens": 10,
+                }
+                url = f"{self.provider}/engines/{self.model}/chat/completions"
+                resp = requests.post(url, json=body, headers=self.header, timeout=10)
+                resp.raise_for_status()
+                logging.info("LLM credentials validated successfully.")
+                return True
+            else:
+                # Test DSPy/OpenAI credentials with a minimal request
+                try:
+                    # Make a simple test call through dspy
+                    test_signature = "input: str -> output: str"
+                    chain = dspy.ChainOfThought(test_signature)
+                    chain(input="test")
+                    logging.info("LLM credentials validated successfully.")
+                    return True
+                except Exception as e:
+                    raise e
+        except requests.exceptions.HTTPError as e:
+            status_code = e.response.status_code if hasattr(e, "response") else None
+            error_msg = "LLM credential validation failed.\n"
+            error_msg += f"Provider: {self.provider}\n"
+            error_msg += f"Model: {self.model}\n"
+            if status_code == 401:
+                error_msg += "Error: Authentication failed. Please check your API key.\n"
+            elif status_code == 404:
+                error_msg += "Error: Model or endpoint not found. Please check the model name and provider URL.\n"
+            else:
+                error_msg += f"Error: HTTP {status_code} - {str(e)}\n"
+            logging.error(error_msg)
+            sys.exit(1)
+        except Exception as e:
+            error_msg = "LLM credential validation failed.\n"
+            error_msg += f"Provider: {self.provider}\n"
+            error_msg += f"Model: {self.model}\n"
+            error_msg += f"Error: {str(e)}\n"
+            logging.error(error_msg)
+            sys.exit(1)
+
     def _get_model_context_length(self) -> Optional[int]:
         """Query the model's max context length from the API"""
         import logging
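
Taken together, validate_llm_credentials() is a fail-fast gate: it either returns True or terminates the process with a descriptive message before any profiling or optimization work starts. A minimal sketch of calling it standalone; the key source and model name are placeholders, not taken from the commit:

    import os

    from intelliperf.core.llm import validate_llm_credentials

    # Exits with a descriptive error on a bad key, unknown model, or unreachable
    # provider; returns True otherwise, so later work only starts once the LLM is usable.
    ok = validate_llm_credentials(
        api_key=os.environ.get("OPENAI_API_KEY", ""),
        model="gpt-4",  # placeholder model name
        provider="openai",  # or an AMD gateway URL such as https://llm-api.amd.com/azure
    )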

tests/test_llm_validation.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+from unittest.mock import Mock, patch
+
+import pytest
+from intelliperf.core.llm import validate_llm_credentials
+
+
+def test_validate_credentials_amd_success():
+    """Test successful credential validation for AMD provider."""
+    with patch("requests.post") as mock_post, patch("dspy.LM"), patch("dspy.configure"):
+        mock_response = Mock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {"choices": [{"message": {"content": "test"}}]}
+        mock_post.return_value = mock_response
+
+        # Should not raise an exception
+        result = validate_llm_credentials(
+            api_key="test-key", model="test-model", provider="https://llm-api.amd.com/azure"
+        )
+        assert result is True
+
+
+def test_validate_credentials_amd_auth_failure():
+    """Test credential validation fails with 401 for AMD provider."""
+    with patch("requests.post") as mock_post, patch("dspy.LM"), patch("dspy.configure"):
+        mock_response = Mock()
+        mock_response.status_code = 401
+        mock_http_error = Exception("401 Unauthorized")
+        mock_http_error.response = mock_response
+        mock_response.raise_for_status.side_effect = mock_http_error
+        mock_post.return_value = mock_response
+
+        with pytest.raises(SystemExit):
+            validate_llm_credentials(api_key="bad-key", model="test-model", provider="https://llm-api.amd.com/azure")
+
+
+def test_validate_credentials_amd_model_not_found():
+    """Test credential validation fails with 404 for AMD provider."""
+    with patch("requests.post") as mock_post, patch("dspy.LM"), patch("dspy.configure"):
+        mock_response = Mock()
+        mock_response.status_code = 404
+        mock_http_error = Exception("404 Not Found")
+        mock_http_error.response = mock_response
+        mock_response.raise_for_status.side_effect = mock_http_error
+        mock_post.return_value = mock_response
+
+        with pytest.raises(SystemExit):
+            validate_llm_credentials(
+                api_key="test-key",
+                model="nonexistent-model",
+                provider="https://llm-api.amd.com/azure",
+            )
+
+
+def test_validate_credentials_openai_success():
+    """Test successful credential validation for OpenAI provider."""
+    with patch("dspy.LM"), patch("dspy.configure"), patch("dspy.ChainOfThought") as mock_chain:
+        # Mock the chain of thought call
+        mock_chain_instance = Mock()
+        mock_chain_instance.return_value = Mock(output="test")
+        mock_chain.return_value = mock_chain_instance
+
+        # Should not raise an exception
+        result = validate_llm_credentials(api_key="test-key", model="gpt-4", provider="openai")
+        assert result is True
+
+
+def test_validate_credentials_openai_failure():
+    """Test credential validation fails for OpenAI provider."""
+    with patch("dspy.LM"), patch("dspy.configure"), patch("dspy.ChainOfThought") as mock_chain:
+        # Mock the chain of thought to raise an exception
+        mock_chain_instance = Mock()
+        mock_chain_instance.side_effect = Exception("Authentication failed")
+        mock_chain.return_value = mock_chain_instance
+
+        with pytest.raises(SystemExit):
+            validate_llm_credentials(api_key="bad-key", model="gpt-4", provider="openai")
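
One caveat on the two AMD failure tests above: they attach a plain Exception to raise_for_status, so the error is caught by the implementation's generic except Exception branch rather than the requests.exceptions.HTTPError branch; both paths end in SystemExit, so the tests still pass, but the 401/404-specific messages go unexercised. A sketch of a variant that would drive the status-code branch using the real exception type (this test is an illustration, not part of the commit):

    from unittest.mock import Mock, patch

    import pytest
    import requests

    from intelliperf.core.llm import validate_llm_credentials


    def test_validate_credentials_amd_http_error_branch():
        """Drive the HTTPError handler and its 401-specific message."""
        with patch("requests.post") as mock_post, patch("dspy.LM"), patch("dspy.configure"):
            mock_response = Mock(status_code=401)
            # requests attaches the response to the HTTPError it raises, letting the
            # handler read e.response.status_code and emit the 401-specific hint.
            mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
                "401 Client Error", response=mock_response
            )
            mock_post.return_value = mock_response

            with pytest.raises(SystemExit):
                validate_llm_credentials(
                    api_key="bad-key", model="test-model", provider="https://llm-api.amd.com/azure"
                )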
