
Commit 601ab48
Authored by codex

Merge pull request #21 from Krafman/j2k9s2-codex/add-support-for-mistral-api

Add Mistral LLM provider

2 parents bc31d2f + 716fa09

File tree: 5 files changed (+237 −3 lines)

README.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -106,11 +106,12 @@ python verify_connection.py
 A successful run will print "Successfully connected to device!" followed by a dictionary of your device's information.
 
 4. LLM Configuration
-   API keys for LLM providers (OpenAI, Gemini, Together.ai) are configured in `src/llm_controller/llm_interface.py`.
+   API keys for LLM providers (OpenAI, Gemini, Together.ai, Mistral) are configured in `src/llm_controller/llm_interface.py`.
    The system prioritizes loading API keys from environment variables:
    - `OPENAI_API_KEY` for OpenAI
    - `GEMINI_API_KEY` for Gemini
    - `TOGETHER_API_KEY` for Together.ai
+   - `MISTRAL_API_KEY` for Mistral
    - `ANTHROPIC_API_KEY` for Anthropic (if ever fully implemented)
 
 If environment variables are not set, it falls back to `config/llm_config.yml`.
```
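The README change describes a two-step key lookup: environment variable first, then `config/llm_config.yml`. A minimal standalone sketch of that order, assuming PyYAML and a hypothetical `resolve_api_key` helper (this is not the project's actual `ConfigLoader`):

```python
import os

import yaml  # PyYAML; assumed available in the project's environment


def resolve_api_key(provider: str, config_path: str = "config/llm_config.yml") -> str:
    """Illustrative helper: prefer the env var, fall back to the YAML file."""
    key = os.getenv(f"{provider.upper()}_API_KEY")  # e.g. MISTRAL_API_KEY
    if key:
        return key
    with open(config_path) as f:
        config = yaml.safe_load(f) or {}
    return config.get(provider, {}).get("api_key", "")


if __name__ == "__main__":
    print(resolve_api_key("mistral"))
```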

config/llm_config.yml

Lines changed: 19 additions & 0 deletions

```diff
@@ -26,6 +26,25 @@ openai:
   # API request timeout in seconds.
   request_timeout: 60 # seconds
 
+# --- Mistral Specific Configuration ---
+mistral:
+  # API key for Mistral
+  # IMPORTANT: Load from an environment variable.
+  # The LLMInterface will check os.getenv("MISTRAL_API_KEY") first.
+  api_key: "YOUR_MISTRAL_API_KEY_PLACEHOLDER" # Fallback if env var not set.
+
+  # Default model, e.g., "open-mistral-7b", "open-mixtral-8x7b"
+  default_model: "open-mistral-7b"
+
+  # Default temperature for model responses.
+  temperature: 0.7
+
+  # Default maximum number of tokens to generate.
+  max_tokens: 1024
+
+  # API request timeout in seconds.
+  request_timeout: 60 # seconds
+
 # --- Together.ai Specific Configuration ---
 together:
   # API key for Together.ai
```
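The new `mistral` block mirrors the existing provider sections. A short sketch of how these values would be consumed, using the same `.get(..., default)` fallbacks that the `llm_interface.py` diff below applies (the PyYAML loading here is an assumption, standing in for the project's `ConfigLoader`):

```python
import yaml

with open("config/llm_config.yml") as f:
    llm_configs = yaml.safe_load(f)

# Defaults below match the ones used in the llm_interface.py diff further down.
mistral_settings = llm_configs.get("mistral", {})
model_name = mistral_settings.get("default_model", "open-mistral-7b")
temperature = mistral_settings.get("temperature", 0.7)
max_tokens = mistral_settings.get("max_tokens", 1024)
request_timeout = mistral_settings.get("request_timeout", 60.0)
print(model_name, temperature, max_tokens, request_timeout)
```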

docs/api_references.md

Lines changed: 17 additions & 0 deletions

````diff
@@ -226,6 +226,23 @@ This section will cover details for the LLM providers supported by the controlle
 
 *(Add sections for other LLM providers like Anthropic Claude if they are considered for integration.)*
 
+### 2.5. Mistral API
+
+* **Official Documentation Link**: `https://docs.mistral.ai`
+* **Authentication**: API Key passed via the `MISTRAL_API_KEY` environment variable or in `llm_config.yml`.
+* **Endpoint**: `POST https://api.mistral.ai/v1/chat/completions`
+* **Example Request Body**:
+  ```json
+  {
+    "model": "open-mistral-7b",
+    "messages": [
+      {"role": "user", "content": "Hello"}
+    ],
+    "response_format": {"type": "json_object"}
+  }
+  ```
+* **Notes**: The API closely mirrors OpenAI's chat completion format. The controller sends requests asynchronously with `httpx`.
+
 ## 3. Other Relevant APIs/Libraries
 
 * **openatx/uiautomator2**:
````
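For reference, a minimal async call matching the documented endpoint and request body. This is a standalone sketch, not the controller's `LLMInterface`, and it assumes `MISTRAL_API_KEY` is exported in the environment:

```python
import asyncio
import json
import os

import httpx


async def main() -> None:
    payload = {
        "model": "open-mistral-7b",
        "messages": [{"role": "user", "content": "Hello"}],
        "response_format": {"type": "json_object"},
    }
    headers = {
        "Authorization": f"Bearer {os.environ['MISTRAL_API_KEY']}",  # assumes the env var is set
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            "https://api.mistral.ai/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60.0,
        )
        resp.raise_for_status()
        print(json.dumps(resp.json(), indent=2))


if __name__ == "__main__":
    asyncio.run(main())
```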

src/llm_controller/llm_interface.py

Lines changed: 75 additions & 2 deletions

```diff
@@ -13,7 +13,7 @@
 class LLMInterface:
     """
     Handles communication with the Large Language Model (LLM).
-    Supports Gemini, OpenAI, Together.ai, and AWS Bedrock providers.
+    Supports Gemini, OpenAI, Together.ai, Mistral, and AWS Bedrock providers.
     """
     def __init__(self, config_loader: ConfigLoader):
         """
@@ -103,6 +103,27 @@ def __init__(self, config_loader: ConfigLoader):
                 self.together_client = AsyncTogether(api_key=self.api_key)
             else:
                 self.together_client = None
+        elif self.provider == "mistral":
+            mistral_settings = llm_configs.get("mistral", {})
+            self.api_key = os.getenv("MISTRAL_API_KEY")
+            if self.api_key:
+                logger.info("Loaded Mistral API key from MISTRAL_API_KEY environment variable.")
+            else:
+                self.api_key = mistral_settings.get("api_key")
+                if self.api_key and self.api_key != "YOUR_MISTRAL_API_KEY_PLACEHOLDER":
+                    logger.info("Loaded Mistral API key from llm_config.yml.")
+                    logger.warning("For production, it's recommended to set the MISTRAL_API_KEY environment variable.")
+                elif self.api_key == "YOUR_MISTRAL_API_KEY_PLACEHOLDER" or not self.api_key:
+                    self.api_key = ""
+                    logger.warning("Using empty API key for Mistral. Ensure MISTRAL_API_KEY env var is set or llm_config.yml has a valid key.")
+            if not mistral_settings and not self.api_key:
+                raise ConfigError("Mistral configuration missing and MISTRAL_API_KEY environment variable not set.")
+
+            self.model_name = mistral_settings.get("default_model", "open-mistral-7b")
+            self.api_url = "https://api.mistral.ai/v1/chat/completions"
+            self.temperature = mistral_settings.get("temperature", 0.7)
+            self.max_tokens = mistral_settings.get("max_tokens", 1024)
+            self.mistral_settings = mistral_settings
         elif self.provider == "anthropic": # Example, not fully implemented
             anthropic_settings = llm_configs.get("anthropic", {})
             self.api_key = os.getenv("ANTHROPIC_API_KEY")
@@ -329,6 +350,55 @@ async def get_llm_action_json(self, messages: list[dict]) -> dict:
                 logger.error(f"Failed to parse Together LLM response as JSON: {e}")
                 logger.error(f"Together LLM response string was: {action_json_str}")
                 raise LLMInterfaceError(f"Together LLM response was not valid JSON: {action_json_str}")
+        elif self.provider == "mistral":
+            payload = {
+                "model": self.model_name,
+                "messages": messages,
+                "response_format": {"type": "json_object"}
+            }
+            if hasattr(self, 'temperature'):
+                payload["temperature"] = self.temperature
+            if hasattr(self, 'max_tokens'):
+                payload["max_tokens"] = self.max_tokens
+
+            headers = {
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json"
+            }
+
+            logger.debug(f"Sending request to Mistral API: {self.api_url}")
+            logger.debug(f"Payload (first 200 chars of messages): {json.dumps(messages, indent=2)[:200]}")
+
+            try:
+                async with httpx.AsyncClient() as client:
+                    timeout = self.mistral_settings.get("request_timeout", 60.0)
+                    response = await client.post(self.api_url, headers=headers, json=payload, timeout=timeout)
+
+                    response.raise_for_status()
+                    result = response.json()
+                    logger.debug(f"Raw Mistral API response: {json.dumps(result, indent=2)}")
+
+            except httpx.HTTPStatusError as e:
+                error_content = e.response.text
+                logger.error(f"Mistral API request failed with status {e.response.status_code}: {error_content}")
+                raise LLMInterfaceError(f"Mistral API request failed: {e.response.status_code} - {error_content}")
+            except Exception as e:
+                logger.error(f"Error during Mistral API call: {e}")
+                raise LLMInterfaceError(f"Error during Mistral API call: {e}")
+
+            if result.get("choices") and result["choices"][0].get("message") and result["choices"][0]["message"].get("content"):
+                action_json_str = result["choices"][0]["message"]["content"]
+                logger.info("Successfully received response from Mistral LLM.")
+                logger.debug(f"Mistral LLM response text (potential JSON): {action_json_str}")
+                try:
+                    return json.loads(action_json_str)
+                except json.JSONDecodeError as e:
+                    logger.error(f"Failed to parse Mistral LLM response as JSON: {e}")
+                    logger.error(f"Mistral LLM response string was: {action_json_str}")
+                    raise LLMInterfaceError(f"Mistral LLM response was not valid JSON: {action_json_str}")
+            else:
+                logger.error(f"Unexpected Mistral LLM response structure: {result}")
+                raise LLMInterfaceError(f"Unexpected Mistral LLM response structure. Full response: {result}")
         elif self.provider == "bedrock":
             logger.warning("Bedrock get_llm_action_json called, but using mock response for now.")
             mock_response_str = """
@@ -364,13 +434,16 @@ async def get_llm_action_json(self, messages: list[dict]) -> dict:
     if not dummy_llm_config_file.exists():
         with open(dummy_llm_config_file, "w") as f:
             f.write("""
-llm_provider: "gemini" # or "openai" or "bedrock"
+llm_provider: "gemini" # or "openai" or "mistral" or "bedrock"
 gemini:
   api_key: "YOUR_GEMINI_API_KEY_PLACEHOLDER"
   default_model: "gemini-2.0-flash"
 openai:
   api_key: "YOUR_OPENAI_API_KEY_PLACEHOLDER"
   default_model: "gpt-4o-mini"
+mistral:
+  api_key: "YOUR_MISTRAL_API_KEY_PLACEHOLDER"
+  default_model: "open-mistral-7b"
 bedrock:
   region_name: "us-east-1"
   profile_name: ""
```
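A hypothetical end-to-end usage sketch for the new provider branch. The `ConfigLoader` import path and constructor are assumptions (only the class name appears in the diff); the message content is illustrative:

```python
import asyncio

from llm_controller.config_loader import ConfigLoader  # hypothetical module path
from llm_controller.llm_interface import LLMInterface


async def main() -> None:
    # Assumes llm_config.yml sets llm_provider: "mistral" and MISTRAL_API_KEY is exported.
    interface = LLMInterface(config_loader=ConfigLoader())
    action = await interface.get_llm_action_json(
        [{"role": "user", "content": "Tap the Settings icon. Reply in JSON."}]
    )
    print(action)  # e.g. {"action": "tap", ...}


if __name__ == "__main__":
    asyncio.run(main())
```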

tests/unit/test_llm_interface.py

Lines changed: 124 additions & 0 deletions

```diff
@@ -72,6 +72,16 @@ def setUp(self):
                 "request_timeout": 30.0
             }
         }
+        self.mistral_config_data = {
+            "llm_provider": "mistral",
+            "mistral": {
+                "api_key": "test_mistral_api_key",
+                "default_model": "open-mistral-7b",
+                "temperature": 0.5,
+                "max_tokens": 512,
+                "request_timeout": 30.0
+            }
+        }
         self.bedrock_config_data = {
             "llm_provider": "bedrock",
             "bedrock": {
@@ -273,6 +283,43 @@ def test_init_together_placeholder_api_key_becomes_empty(self):
         self.assertEqual(llm_interface.api_key, "")
         self.assertTrue(any("Using empty API key for Together.ai" in msg for msg in log_watcher.output))
 
+    @patch.dict(os.environ, {}, clear=True)
+    def test_init_mistral_api_key_from_env(self):
+        """Test Mistral API key loaded from environment variable."""
+        os.environ["MISTRAL_API_KEY"] = "env_mistral_key"
+        self._create_llm_config_file(self.mistral_config_data)
+        llm_interface = LLMInterface(config_loader=self.mock_config_loader)
+        self.assertEqual(llm_interface.provider, "mistral")
+        self.assertEqual(llm_interface.api_key, "env_mistral_key")
+        self.assertEqual(llm_interface.model_name, self.mistral_config_data["mistral"]["default_model"])
+
+    @patch.dict(os.environ, {}, clear=True)
+    def test_init_mistral_api_key_from_yaml_if_not_in_env(self):
+        """Test Mistral API key loaded from YAML when not in environment."""
+        self._create_llm_config_file(self.mistral_config_data)
+        llm_interface = LLMInterface(config_loader=self.mock_config_loader)
+        self.assertEqual(llm_interface.provider, "mistral")
+        self.assertEqual(llm_interface.api_key, "test_mistral_api_key")
+        self.assertEqual(llm_interface.model_name, self.mistral_config_data["mistral"]["default_model"])
+
+    @patch.dict(os.environ, {}, clear=True)
+    def test_init_mistral_config_section_missing_and_env_var_unset_raises_error(self):
+        """Test ConfigError if Mistral provider section is missing and env var not set."""
+        config_missing = {"llm_provider": "mistral"}
+        self._create_llm_config_file(config_missing)
+        with self.assertRaisesRegex(ConfigError, "Mistral configuration missing and MISTRAL_API_KEY environment variable not set."):
+            LLMInterface(config_loader=self.mock_config_loader)
+
+    @patch.dict(os.environ, {}, clear=True)
+    def test_init_mistral_placeholder_api_key_becomes_empty(self):
+        """Test placeholder Mistral API key becomes empty when env var not set."""
+        config_placeholder = {"llm_provider": "mistral", "mistral": {"api_key": "YOUR_MISTRAL_API_KEY_PLACEHOLDER", "default_model": "m"}}
+        self._create_llm_config_file(config_placeholder)
+        with self.assertLogs('llm_controller.llm_interface', level='WARNING') as log_watcher:
+            llm_interface = LLMInterface(config_loader=self.mock_config_loader)
+        self.assertEqual(llm_interface.api_key, "")
+        self.assertTrue(any("Using empty API key for Mistral" in msg for msg in log_watcher.output))
+
     # --- End API Key Handling Tests ---
 
     # --- Gemini API Call Tests ---
@@ -519,6 +566,83 @@ async def test_get_llm_action_json_together_malformed_json(self, MockTogether):
             await llm_interface.get_llm_action_json([{"role": "user", "content": "tap"}])
     # --- End Together API Call Tests ---
 
+    # --- Mistral API Call Tests ---
+    @patch('llm_controller.llm_interface.httpx.AsyncClient')
+    async def test_get_llm_action_json_mistral_success(self, MockAsyncClient):
+        """Test successful Mistral API call and JSON response parsing."""
+        self._create_llm_config_file(self.mistral_config_data)
+        llm_interface = LLMInterface(config_loader=self.mock_config_loader)
+
+        mock_api_resp_structure = {
+            "choices": [
+                {"message": {"content": '{"action": "tap"}'}}
+            ]
+        }
+
+        mock_response = MagicMock(spec=httpx.Response)
+        mock_response.status_code = 200
+        mock_response.json.return_value = mock_api_resp_structure
+
+        mock_client_instance = MockAsyncClient.return_value.__aenter__.return_value
+        mock_client_instance.post = AsyncMock(return_value=mock_response)
+
+        result_json = await llm_interface.get_llm_action_json([{"role": "user", "content": "tap"}])
+        self.assertEqual(result_json, {"action": "tap"})
+        mock_client_instance.post.assert_called_once()
+
+    @patch('llm_controller.llm_interface.httpx.AsyncClient')
+    async def test_get_llm_action_json_mistral_http_status_error(self, MockAsyncClient):
+        """Test Mistral API HTTPStatusError handling."""
+        self._create_llm_config_file(self.mistral_config_data)
+        llm_interface = LLMInterface(config_loader=self.mock_config_loader)
+
+        mock_http_response = MagicMock(spec=httpx.Response)
+        mock_http_response.status_code = 429
+        mock_http_response.text = "Rate limited"
+
+        mock_client_instance = MockAsyncClient.return_value.__aenter__.return_value
+        mock_client_instance.post = AsyncMock(side_effect=httpx.HTTPStatusError(
+            "Too Many Requests", request=MagicMock(), response=mock_http_response
+        ))
+
+        with self.assertRaisesRegex(LLMInterfaceError, "Mistral API request failed: 429 - Rate limited"):
+            await llm_interface.get_llm_action_json([{"role": "user", "content": "tap"}])
+
+    @patch('llm_controller.llm_interface.httpx.AsyncClient')
+    async def test_get_llm_action_json_mistral_request_error(self, MockAsyncClient):
+        """Test Mistral API request error handling."""
+        self._create_llm_config_file(self.mistral_config_data)
+        llm_interface = LLMInterface(config_loader=self.mock_config_loader)
+
+        mock_client_instance = MockAsyncClient.return_value.__aenter__.return_value
+        mock_client_instance.post = AsyncMock(side_effect=httpx.ConnectError("conn fail"))
+
+        with self.assertRaisesRegex(LLMInterfaceError, "Error during Mistral API call: conn fail"):
+            await llm_interface.get_llm_action_json([{"role": "user", "content": "tap"}])
+
+    @patch('llm_controller.llm_interface.httpx.AsyncClient')
+    async def test_get_llm_action_json_mistral_malformed_json(self, MockAsyncClient):
+        """Test Mistral API returning malformed JSON."""
+        self._create_llm_config_file(self.mistral_config_data)
+        llm_interface = LLMInterface(config_loader=self.mock_config_loader)
+
+        mock_api_resp_structure = {
+            "choices": [
+                {"message": {"content": 'not json'}}
+            ]
+        }
+
+        mock_response = MagicMock(spec=httpx.Response)
+        mock_response.status_code = 200
+        mock_response.json.return_value = mock_api_resp_structure
+
+        mock_client_instance = MockAsyncClient.return_value.__aenter__.return_value
+        mock_client_instance.post = AsyncMock(return_value=mock_response)
+
+        with self.assertRaisesRegex(LLMInterfaceError, "Mistral LLM response was not valid JSON"):
+            await llm_interface.get_llm_action_json([{"role": "user", "content": "tap"}])
+    # --- End Mistral API Call Tests ---
+
     # --- Bedrock API Call Tests (current mock) ---
     async def test_get_llm_action_json_bedrock_mock_response(self):
         """Test Bedrock provider path (which currently returns a mock response)."""
```
