diff --git a/.cursorrules b/.cursorrules index 3df2128..7d5a57c 100644 --- a/.cursorrules +++ b/.cursorrules @@ -22,6 +22,7 @@ venv/bin/python ./tools/llm_api.py --prompt "What is the capital of France?" --p The LLM API supports multiple providers: - OpenAI (default, model: gpt-4o) +- Azure OpenAI (model: configured via AZURE_OPENAI_MODEL_DEPLOYMENT in .env file, defaults to gpt-4o-ms) - DeepSeek (model: deepseek-chat) - Anthropic (model: claude-3-sonnet-20240229) - Gemini (model: gemini-pro) diff --git a/.env.example b/.env.example index 99ab0de..12ff991 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,6 @@ OPENAI_API_KEY=your_openai_api_key_here ANTHROPIC_API_KEY=your_anthropic_api_key_here DEEPSEEK_API_KEY=your_deepseek_api_key_here -GOOGLE_API_KEY=your_google_api_key_here \ No newline at end of file +GOOGLE_API_KEY=your_google_api_key_here +AZURE_OPENAI_API_KEY=your_azure_openai_api_key_here +AZURE_OPENAI_MODEL_DEPLOYMENT=gpt-4o-ms \ No newline at end of file diff --git a/tests/test_llm_api.py b/tests/test_llm_api.py index 60b1ed9..61d08e2 100644 --- a/tests/test_llm_api.py +++ b/tests/test_llm_api.py @@ -117,10 +117,22 @@ def setUp(self): 'OPENAI_API_KEY': 'test-openai-key', 'DEEPSEEK_API_KEY': 'test-deepseek-key', 'ANTHROPIC_API_KEY': 'test-anthropic-key', - 'GOOGLE_API_KEY': 'test-google-key' + 'GOOGLE_API_KEY': 'test-google-key', + 'AZURE_OPENAI_API_KEY': 'test-azure-key', + 'AZURE_OPENAI_MODEL_DEPLOYMENT': 'test-model-deployment' }) self.env_patcher.start() - + + # Set up Azure OpenAI mock + self.mock_azure_response = MagicMock() + self.mock_azure_choice = MagicMock() + self.mock_azure_message = MagicMock() + self.mock_azure_message.content = "Test Azure OpenAI response" + self.mock_azure_choice.message = self.mock_azure_message + self.mock_azure_response.choices = [self.mock_azure_choice] + self.mock_azure_client = MagicMock() + self.mock_azure_client.chat.completions.create.return_value = self.mock_azure_response + def tearDown(self): self.env_patcher.stop() @@ -132,6 +144,18 @@ def test_create_openai_client(self, mock_openai): mock_openai.assert_called_once_with(api_key='test-openai-key') self.assertEqual(client, self.mock_openai_client) + @unittest.skipIf(skip_llm_tests, skip_message) + @patch('tools.llm_api.AzureOpenAI') + def test_create_azure_client(self, mock_azure): + mock_azure.return_value = self.mock_azure_client + client = create_llm_client("azure") + mock_azure.assert_called_once_with( + api_key='test-azure-key', + api_version="2024-02-15-preview", + azure_endpoint="https://msopenai.openai.azure.com" + ) + self.assertEqual(client, self.mock_azure_client) + @unittest.skipIf(skip_llm_tests, skip_message) @patch('tools.llm_api.OpenAI') def test_create_deepseek_client(self, mock_openai): @@ -186,6 +210,18 @@ def test_query_openai(self, mock_create_client): temperature=0.7 ) + @unittest.skipIf(skip_llm_tests, skip_message) + @patch('tools.llm_api.create_llm_client') + def test_query_azure(self, mock_create_client): + mock_create_client.return_value = self.mock_azure_client + response = query_llm("Test prompt", provider="azure") + self.assertEqual(response, "Test Azure OpenAI response") + self.mock_azure_client.chat.completions.create.assert_called_once_with( + model=os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms'), + messages=[{"role": "user", "content": "Test prompt"}], + temperature=0.7 + ) + @unittest.skipIf(skip_llm_tests, skip_message) @patch('tools.llm_api.create_llm_client') def test_query_deepseek(self, mock_create_client): diff --git a/tools/llm_api.py b/tools/llm_api.py index 6fdf4dd..235746c 100644 --- a/tools/llm_api.py +++ b/tools/llm_api.py @@ -1,7 +1,7 @@ #!/usr/bin/env /workspace/tmp_windsurf/venv/bin/python3 import google.generativeai as genai -from openai import OpenAI +from openai import OpenAI, AzureOpenAI from anthropic import Anthropic import argparse import os @@ -51,6 +51,15 @@ def create_llm_client(provider="openai"): return OpenAI( api_key=api_key ) + elif provider == "azure": + api_key = os.getenv('AZURE_OPENAI_API_KEY') + if not api_key: + raise ValueError("AZURE_OPENAI_API_KEY not found in environment variables") + return AzureOpenAI( + api_key=api_key, + api_version="2024-02-15-preview", + azure_endpoint="https://msopenai.openai.azure.com" + ) elif provider == "deepseek": api_key = os.getenv('DEEPSEEK_API_KEY') if not api_key: @@ -89,6 +98,8 @@ def query_llm(prompt, client=None, model=None, provider="openai"): if model is None: if provider == "openai": model = "gpt-4o" + elif provider == "azure": + model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms') # Get from env with fallback elif provider == "deepseek": model = "deepseek-chat" elif provider == "anthropic": @@ -98,7 +109,7 @@ def query_llm(prompt, client=None, model=None, provider="openai"): elif provider == "local": model = "Qwen/Qwen2.5-32B-Instruct-AWQ" - if provider == "openai" or provider == "local" or provider == "deepseek": + if provider in ["openai", "local", "deepseek", "azure"]: response = client.chat.completions.create( model=model, messages=[ @@ -127,7 +138,7 @@ def query_llm(prompt, client=None, model=None, provider="openai"): def main(): parser = argparse.ArgumentParser(description='Query an LLM with a prompt') parser.add_argument('--prompt', type=str, help='The prompt to send to the LLM', required=True) - parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek'], default='openai', help='The API provider to use') + parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek','azure'], default='openai', help='The API provider to use') parser.add_argument('--model', type=str, help='The model to use (default depends on provider)') args = parser.parse_args() @@ -140,6 +151,8 @@ def main(): args.model = "claude-3-5-sonnet-20241022" elif args.provider == 'gemini': args.model = "gemini-2.0-flash-exp" + elif args.provider == 'azure': + args.model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms') # Get from env with fallback client = create_llm_client(args.provider) response = query_llm(args.prompt, client, model=args.model, provider=args.provider)