Skip to content

Commit e957174

Browse files
kflililili-ms and Li Li authored
feat(llm): add Azure OpenAI integration (#23)
- Add Azure OpenAI as a new LLM provider option
- Configure Azure model deployment via AZURE_OPENAI_MODEL_DEPLOYMENT env var
- Add unit tests for Azure OpenAI client creation and query functionality
- Update environment example with Azure OpenAI credentials
- Update CLI help text to include Azure provider option
- Set default Azure model to 'gpt-4o-ms' with configurable fallback

Technical details:
- Integrate AzureOpenAI client with api version '2024-02-15-preview'
- Add Azure endpoint configuration
- Extend test coverage for Azure-specific scenarios

Co-authored-by: Li Li <lili5@microsoft.com>
1 parent 9ea2811 commit e957174

File tree

4 files changed

+58
-6
lines changed

4 files changed

+58
-6
lines changed

.cursorrules

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ venv/bin/python ./tools/llm_api.py --prompt "What is the capital of France?" --p
2222

2323
The LLM API supports multiple providers:
2424
- OpenAI (default, model: gpt-4o)
25+
- Azure OpenAI (model: configured via AZURE_OPENAI_MODEL_DEPLOYMENT in .env file, defaults to gpt-4o-ms)
2526
- DeepSeek (model: deepseek-chat)
2627
- Anthropic (model: claude-3-sonnet-20240229)
2728
- Gemini (model: gemini-pro)

.env.example

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
OPENAI_API_KEY=your_openai_api_key_here
22
ANTHROPIC_API_KEY=your_anthropic_api_key_here
33
DEEPSEEK_API_KEY=your_deepseek_api_key_here
4-
GOOGLE_API_KEY=your_google_api_key_here
4+
GOOGLE_API_KEY=your_google_api_key_here
5+
AZURE_OPENAI_API_KEY=your_azure_openai_api_key_here
6+
AZURE_OPENAI_MODEL_DEPLOYMENT=gpt-4o-ms

tests/test_llm_api.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,22 @@ def setUp(self):
117117
'OPENAI_API_KEY': 'test-openai-key',
118118
'DEEPSEEK_API_KEY': 'test-deepseek-key',
119119
'ANTHROPIC_API_KEY': 'test-anthropic-key',
120-
'GOOGLE_API_KEY': 'test-google-key'
120+
'GOOGLE_API_KEY': 'test-google-key',
121+
'AZURE_OPENAI_API_KEY': 'test-azure-key',
122+
'AZURE_OPENAI_MODEL_DEPLOYMENT': 'test-model-deployment'
121123
})
122124
self.env_patcher.start()
123-
125+
126+
# Set up Azure OpenAI mock
127+
self.mock_azure_response = MagicMock()
128+
self.mock_azure_choice = MagicMock()
129+
self.mock_azure_message = MagicMock()
130+
self.mock_azure_message.content = "Test Azure OpenAI response"
131+
self.mock_azure_choice.message = self.mock_azure_message
132+
self.mock_azure_response.choices = [self.mock_azure_choice]
133+
self.mock_azure_client = MagicMock()
134+
self.mock_azure_client.chat.completions.create.return_value = self.mock_azure_response
135+
124136
def tearDown(self):
125137
self.env_patcher.stop()
126138

@@ -132,6 +144,18 @@ def test_create_openai_client(self, mock_openai):
132144
mock_openai.assert_called_once_with(api_key='test-openai-key')
133145
self.assertEqual(client, self.mock_openai_client)
134146

147+
@unittest.skipIf(skip_llm_tests, skip_message)
148+
@patch('tools.llm_api.AzureOpenAI')
149+
def test_create_azure_client(self, mock_azure):
150+
mock_azure.return_value = self.mock_azure_client
151+
client = create_llm_client("azure")
152+
mock_azure.assert_called_once_with(
153+
api_key='test-azure-key',
154+
api_version="2024-02-15-preview",
155+
azure_endpoint="https://msopenai.openai.azure.com"
156+
)
157+
self.assertEqual(client, self.mock_azure_client)
158+
135159
@unittest.skipIf(skip_llm_tests, skip_message)
136160
@patch('tools.llm_api.OpenAI')
137161
def test_create_deepseek_client(self, mock_openai):
@@ -186,6 +210,18 @@ def test_query_openai(self, mock_create_client):
186210
temperature=0.7
187211
)
188212

213+
@unittest.skipIf(skip_llm_tests, skip_message)
214+
@patch('tools.llm_api.create_llm_client')
215+
def test_query_azure(self, mock_create_client):
216+
mock_create_client.return_value = self.mock_azure_client
217+
response = query_llm("Test prompt", provider="azure")
218+
self.assertEqual(response, "Test Azure OpenAI response")
219+
self.mock_azure_client.chat.completions.create.assert_called_once_with(
220+
model=os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms'),
221+
messages=[{"role": "user", "content": "Test prompt"}],
222+
temperature=0.7
223+
)
224+
189225
@unittest.skipIf(skip_llm_tests, skip_message)
190226
@patch('tools.llm_api.create_llm_client')
191227
def test_query_deepseek(self, mock_create_client):

tools/llm_api.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env /workspace/tmp_windsurf/venv/bin/python3
22

33
import google.generativeai as genai
4-
from openai import OpenAI
4+
from openai import OpenAI, AzureOpenAI
55
from anthropic import Anthropic
66
import argparse
77
import os
@@ -51,6 +51,15 @@ def create_llm_client(provider="openai"):
5151
return OpenAI(
5252
api_key=api_key
5353
)
54+
elif provider == "azure":
55+
api_key = os.getenv('AZURE_OPENAI_API_KEY')
56+
if not api_key:
57+
raise ValueError("AZURE_OPENAI_API_KEY not found in environment variables")
58+
return AzureOpenAI(
59+
api_key=api_key,
60+
api_version="2024-02-15-preview",
61+
azure_endpoint="https://msopenai.openai.azure.com"
62+
)
5463
elif provider == "deepseek":
5564
api_key = os.getenv('DEEPSEEK_API_KEY')
5665
if not api_key:
@@ -89,6 +98,8 @@ def query_llm(prompt, client=None, model=None, provider="openai"):
8998
if model is None:
9099
if provider == "openai":
91100
model = "gpt-4o"
101+
elif provider == "azure":
102+
model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms') # Get from env with fallback
92103
elif provider == "deepseek":
93104
model = "deepseek-chat"
94105
elif provider == "anthropic":
@@ -98,7 +109,7 @@ def query_llm(prompt, client=None, model=None, provider="openai"):
98109
elif provider == "local":
99110
model = "Qwen/Qwen2.5-32B-Instruct-AWQ"
100111

101-
if provider == "openai" or provider == "local" or provider == "deepseek":
112+
if provider in ["openai", "local", "deepseek", "azure"]:
102113
response = client.chat.completions.create(
103114
model=model,
104115
messages=[
@@ -127,7 +138,7 @@ def query_llm(prompt, client=None, model=None, provider="openai"):
127138
def main():
128139
parser = argparse.ArgumentParser(description='Query an LLM with a prompt')
129140
parser.add_argument('--prompt', type=str, help='The prompt to send to the LLM', required=True)
130-
parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek'], default='openai', help='The API provider to use')
141+
parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek','azure'], default='openai', help='The API provider to use')
131142
parser.add_argument('--model', type=str, help='The model to use (default depends on provider)')
132143
args = parser.parse_args()
133144

@@ -140,6 +151,8 @@ def main():
140151
args.model = "claude-3-5-sonnet-20241022"
141152
elif args.provider == 'gemini':
142153
args.model = "gemini-2.0-flash-exp"
154+
elif args.provider == 'azure':
155+
args.model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms') # Get from env with fallback
143156

144157
client = create_llm_client(args.provider)
145158
response = query_llm(args.prompt, client, model=args.model, provider=args.provider)

0 commit comments

Comments (0)