1 change: 1 addition & 0 deletions .cursorrules
@@ -22,6 +22,7 @@ venv/bin/python ./tools/llm_api.py --prompt "What is the capital of France?" --p

The LLM API supports multiple providers:
- OpenAI (default, model: gpt-4o)
- Azure OpenAI (model: configured via AZURE_OPENAI_MODEL_DEPLOYMENT in .env file, defaults to gpt-4o-ms)
- DeepSeek (model: deepseek-chat)
- Anthropic (model: claude-3-sonnet-20240229)
- Gemini (model: gemini-pro)
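The new provider can be exercised from the CLI the same way as the example above; when no --model flag is given, the deployment name is resolved from AZURE_OPENAI_MODEL_DEPLOYMENT (falling back to gpt-4o-ms):

venv/bin/python ./tools/llm_api.py --prompt "What is the capital of France?" --provider azure
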
4 changes: 3 additions & 1 deletion .env.example
@@ -1,4 +1,6 @@
OPENAI_API_KEY=your_openai_api_key_here
ANTHROPIC_API_KEY=your_anthropic_api_key_here
DEEPSEEK_API_KEY=your_deepseek_api_key_here
GOOGLE_API_KEY=your_google_api_key_here
GOOGLE_API_KEY=your_google_api_key_here
AZURE_OPENAI_API_KEY=your_azure_openai_api_key_here
AZURE_OPENAI_MODEL_DEPLOYMENT=gpt-4o-ms
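
Azure OpenAI routes requests to a named deployment rather than a raw model name, which is why the deployment is configured here instead of hardcoded per provider. A minimal sketch of how these two variables are consumed, assuming the project loads .env via python-dotenv (the .cursorrules note about the .env file suggests this, but the loading mechanism is not shown in the diff):

```python
import os
from dotenv import load_dotenv  # assumption: the repo loads .env this way

load_dotenv()
api_key = os.getenv('AZURE_OPENAI_API_KEY')  # required; create_llm_client raises ValueError if unset
deployment = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms')  # optional, with fallback
```
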
40 changes: 38 additions & 2 deletions tests/test_llm_api.py
@@ -117,10 +117,22 @@ def setUp(self):
'OPENAI_API_KEY': 'test-openai-key',
'DEEPSEEK_API_KEY': 'test-deepseek-key',
'ANTHROPIC_API_KEY': 'test-anthropic-key',
'GOOGLE_API_KEY': 'test-google-key'
'GOOGLE_API_KEY': 'test-google-key',
'AZURE_OPENAI_API_KEY': 'test-azure-key',
'AZURE_OPENAI_MODEL_DEPLOYMENT': 'test-model-deployment'
})
self.env_patcher.start()


# Set up Azure OpenAI mock
self.mock_azure_response = MagicMock()
self.mock_azure_choice = MagicMock()
self.mock_azure_message = MagicMock()
self.mock_azure_message.content = "Test Azure OpenAI response"
self.mock_azure_choice.message = self.mock_azure_message
self.mock_azure_response.choices = [self.mock_azure_choice]
self.mock_azure_client = MagicMock()
self.mock_azure_client.chat.completions.create.return_value = self.mock_azure_response

def tearDown(self):
self.env_patcher.stop()

@@ -132,6 +144,18 @@ def test_create_openai_client(self, mock_openai):
mock_openai.assert_called_once_with(api_key='test-openai-key')
self.assertEqual(client, self.mock_openai_client)

@unittest.skipIf(skip_llm_tests, skip_message)
@patch('tools.llm_api.AzureOpenAI')
def test_create_azure_client(self, mock_azure):
mock_azure.return_value = self.mock_azure_client
client = create_llm_client("azure")
mock_azure.assert_called_once_with(
api_key='test-azure-key',
api_version="2024-02-15-preview",
azure_endpoint="https://msopenai.openai.azure.com"
)
self.assertEqual(client, self.mock_azure_client)

@unittest.skipIf(skip_llm_tests, skip_message)
@patch('tools.llm_api.OpenAI')
def test_create_deepseek_client(self, mock_openai):
@@ -186,6 +210,18 @@ def test_query_openai(self, mock_create_client):
temperature=0.7
)

@unittest.skipIf(skip_llm_tests, skip_message)
@patch('tools.llm_api.create_llm_client')
def test_query_azure(self, mock_create_client):
mock_create_client.return_value = self.mock_azure_client
response = query_llm("Test prompt", provider="azure")
self.assertEqual(response, "Test Azure OpenAI response")
self.mock_azure_client.chat.completions.create.assert_called_once_with(
model=os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms'),
messages=[{"role": "user", "content": "Test prompt"}],
temperature=0.7
)

@unittest.skipIf(skip_llm_tests, skip_message)
@patch('tools.llm_api.create_llm_client')
def test_query_deepseek(self, mock_create_client):
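The Azure tests reuse the pattern of the existing provider tests: patch the symbol where tools.llm_api looks it up (not where the openai package defines it) and stub the OpenAI-style response chain. A standalone sketch of that chain, with illustrative names:

```python
from unittest.mock import MagicMock, patch

# Fake an OpenAI-style response: response.choices[0].message.content
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Test Azure OpenAI response"

mock_client = MagicMock()
mock_client.chat.completions.create.return_value = mock_response

# Patch the name in the module under test, not in the openai package,
# so that create_llm_client("azure") receives the mock
with patch('tools.llm_api.AzureOpenAI', return_value=mock_client):
    ...
```
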
19 changes: 16 additions & 3 deletions tools/llm_api.py
@@ -1,7 +1,7 @@
#!/usr/bin/env /workspace/tmp_windsurf/venv/bin/python3

import google.generativeai as genai
from openai import OpenAI
from openai import OpenAI, AzureOpenAI
from anthropic import Anthropic
import argparse
import os
@@ -51,6 +51,15 @@ def create_llm_client(provider="openai"):
return OpenAI(
api_key=api_key
)
elif provider == "azure":
api_key = os.getenv('AZURE_OPENAI_API_KEY')
if not api_key:
raise ValueError("AZURE_OPENAI_API_KEY not found in environment variables")
return AzureOpenAI(
api_key=api_key,
api_version="2024-02-15-preview",
azure_endpoint="https://msopenai.openai.azure.com"
)
elif provider == "deepseek":
api_key = os.getenv('DEEPSEEK_API_KEY')
if not api_key:
@@ -89,6 +98,8 @@ def query_llm(prompt, client=None, model=None, provider="openai"):
if model is None:
if provider == "openai":
model = "gpt-4o"
elif provider == "azure":
model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms') # Get from env with fallback
elif provider == "deepseek":
model = "deepseek-chat"
elif provider == "anthropic":
@@ -98,7 +109,7 @@ def query_llm(prompt, client=None, model=None, provider="openai"):
elif provider == "local":
model = "Qwen/Qwen2.5-32B-Instruct-AWQ"

if provider == "openai" or provider == "local" or provider == "deepseek":
if provider in ["openai", "local", "deepseek", "azure"]:
response = client.chat.completions.create(
model=model,
messages=[
@@ -127,7 +138,7 @@ def query_llm(prompt, client=None, model=None, provider="openai"):
def main():
parser = argparse.ArgumentParser(description='Query an LLM with a prompt')
parser.add_argument('--prompt', type=str, help='The prompt to send to the LLM', required=True)
parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek'], default='openai', help='The API provider to use')
parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek','azure'], default='openai', help='The API provider to use')
parser.add_argument('--model', type=str, help='The model to use (default depends on provider)')
args = parser.parse_args()

@@ -140,6 +151,8 @@ def main():
args.model = "claude-3-5-sonnet-20241022"
elif args.provider == 'gemini':
args.model = "gemini-2.0-flash-exp"
elif args.provider == 'azure':
args.model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms') # Get from env with fallback

client = create_llm_client(args.provider)
response = query_llm(args.prompt, client, model=args.model, provider=args.provider)
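End to end, the new provider is used from Python like the existing ones; a minimal sketch built on the two functions this PR touches (assumes the Azure variables from .env.example are set in the environment):

```python
from tools.llm_api import create_llm_client, query_llm

# Reads AZURE_OPENAI_API_KEY; raises ValueError if it is missing
client = create_llm_client("azure")

# With model=None, query_llm resolves the deployment from
# AZURE_OPENAI_MODEL_DEPLOYMENT, falling back to "gpt-4o-ms"
print(query_llm("What is the capital of France?", client=client, provider="azure"))
```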