Skip to content

Commit 9496396

Browse files
committed
adding tests for Azure OpenAI provider
1 parent 90cbc4e commit 9496396

File tree

2 files changed

+126
-37
lines changed

2 files changed

+126
-37
lines changed

python_gpt_po/services/model_manager.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
class ModelManager:
1515
"""Class to manage models from different providers."""
1616

17-
# pylint: disable=too-many-branches
1817
@staticmethod
1918
def get_available_models(provider_clients: ProviderClients, provider: ModelProvider) -> List[str]:
2019
"""Retrieve available models from a specific provider."""
@@ -73,17 +72,24 @@ def get_available_models(provider_clients: ProviderClients, provider: ModelProvi
7372
logging.error("DeepSeek API key not set")
7473

7574
elif provider == ModelProvider.AZURE_OPENAI:
76-
if provider_clients.azure_openai_client:
77-
response = provider_clients.azure_openai_client.models.list()
78-
models = [model.id for model in response.data]
79-
else:
80-
logging.error("Azure OpenAI client not initialized")
75+
return ModelManager._get_azure_openai_models(provider_clients)
8176

8277
except Exception as e:
8378
logging.error("Error fetching models from %s: %s", provider.value, str(e))
8479

8580
return models
86-
# pylint: enable=too-many-branches
81+
82+
83+
@staticmethod
84+
def _get_azure_openai_models(provider_clients: ProviderClients) -> List[str]:
85+
"""Retrieve models from Azure OpenAI."""
86+
if provider_clients.azure_openai_client:
87+
response = provider_clients.azure_openai_client.models.list()
88+
return [model.id for model in response.data]
89+
90+
logging.error("Azure OpenAI client not initialized")
91+
return []
92+
8793

8894
@staticmethod
8995
def validate_model(provider_clients: ProviderClients, provider: ModelProvider, model: str) -> bool:
@@ -119,7 +125,7 @@ def get_default_model(provider: ModelProvider) -> str:
119125
ModelProvider.OPENAI: "gpt-4o-mini",
120126
ModelProvider.ANTHROPIC: "claude-3-5-haiku-latest",
121127
ModelProvider.DEEPSEEK: "deepseek-chat",
122-
ModelProvider.AZURE_OPENAI: "",
128+
ModelProvider.AZURE_OPENAI: "gpt-35-turbo",
123129
}
124130
return default_models.get(provider, "")
125131

python_gpt_po/tests/test_multi_provider.py

Lines changed: 112 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import logging
6+
import os
67
from unittest.mock import MagicMock, patch
78

89
import pytest
@@ -47,7 +48,7 @@
4748
"""
4849

4950
# Sample model responses for different providers
50-
OPENAI_MODELS_RESPONSE = {
51+
OPENAI_MODELS_RESPONSE: dict[str, object] = {
5152
"data": [
5253
{"id": "gpt-4"},
5354
{"id": "gpt-4-turbo"},
@@ -57,7 +58,9 @@
5758
"object": "list"
5859
}
5960

60-
ANTHROPIC_MODELS_RESPONSE = {
61+
AZURE_OPENAI_MODELS_RESPONSE: dict[str, object] = OPENAI_MODELS_RESPONSE
62+
63+
ANTHROPIC_MODELS_RESPONSE: dict[str, object] = {
6164
"data": [
6265
{"type": "model", "id": "claude-3-7-sonnet-20250219", "display_name": "Claude 3.7 Sonnet", "created_at": "2025-02-19T00:00:00Z"},
6366
{"type": "model", "id": "claude-3-5-sonnet-20241022", "display_name": "Claude 3.5 Sonnet", "created_at": "2024-10-22T00:00:00Z"},
@@ -69,15 +72,15 @@
6972
"last_id": "claude-3-opus-20240229"
7073
}
7174

72-
DEEPSEEK_MODELS_RESPONSE = {
75+
DEEPSEEK_MODELS_RESPONSE: dict[str, object] = {
7376
"data": [
7477
{"id": "deepseek-chat"},
7578
{"id": "deepseek-coder"}
7679
]
7780
}
7881

7982
# Translation responses for different providers
80-
OPENAI_TRANSLATION_RESPONSE = {
83+
OPENAI_TRANSLATION_RESPONSE: dict[str, object] = {
8184
"choices": [
8285
{
8386
"message": {
@@ -87,15 +90,17 @@
8790
]
8891
}
8992

90-
ANTHROPIC_TRANSLATION_RESPONSE = {
93+
AZURE_OPENAI_TRANSLATION_RESPONSE: dict[str, object] = OPENAI_TRANSLATION_RESPONSE
94+
95+
ANTHROPIC_TRANSLATION_RESPONSE: dict[str, object] = {
9196
"content": [
9297
{
9398
"text": '["Bonjour", "Monde", "Bienvenue dans notre application", "Au revoir"]'
9499
}
95100
]
96101
}
97102

98-
DEEPSEEK_TRANSLATION_RESPONSE = {
103+
DEEPSEEK_TRANSLATION_RESPONSE: dict[str, object] = {
99104
"choices": [
100105
{
101106
"message": {
@@ -107,28 +112,30 @@
107112

108113

109114
@pytest.fixture
110-
def temp_po_file(tmp_path):
115+
def temp_po_file(tmp_path: str) -> str:
111116
"""Create a temporary PO file for testing."""
112-
po_file_path = tmp_path / "test.po"
117+
po_file_path = os.path.join(tmp_path, "test.po")
113118
with open(po_file_path, "w", encoding="utf-8") as f:
114119
f.write(SAMPLE_PO_CONTENT)
115120
return str(po_file_path)
116121

117122

118123
@pytest.fixture
119-
def mock_provider_clients():
124+
def mock_provider_clients() -> ProviderClients:
120125
"""Mock provider clients for testing."""
121126
clients = ProviderClients()
122127
clients.openai_client = MagicMock()
123128
clients.anthropic_client = MagicMock()
124129
clients.anthropic_client.api_key = "sk-ant-mock-key"
125130
clients.deepseek_api_key = "sk-deepseek-mock-key"
126131
clients.deepseek_base_url = "https://api.deepseek.com/v1"
132+
clients.azure_openai_client = MagicMock()
133+
clients.azure_openai_client.api_key = "sk-aoi-mock-key"
127134
return clients
128135

129136

130137
@pytest.fixture
131-
def translation_config_openai(mock_provider_clients):
138+
def translation_config_openai(mock_provider_clients: ProviderClients) -> TranslationConfig:
132139
"""Create an OpenAI translation config for testing."""
133140
return TranslationConfig(
134141
provider_clients=mock_provider_clients,
@@ -141,7 +148,20 @@ def translation_config_openai(mock_provider_clients):
141148

142149

143150
@pytest.fixture
144-
def translation_config_anthropic(mock_provider_clients):
151+
def translation_config_azure_openai(mock_provider_clients: ProviderClients) -> TranslationConfig:
152+
"""Create an OpenAI translation config for testing."""
153+
return TranslationConfig(
154+
provider_clients=mock_provider_clients,
155+
provider=ModelProvider.AZURE_OPENAI,
156+
model="gpt-3.5-turbo",
157+
bulk_mode=True,
158+
fuzzy=False,
159+
folder_language=False
160+
)
161+
162+
163+
@pytest.fixture
164+
def translation_config_anthropic(mock_provider_clients: ProviderClients) -> TranslationConfig:
145165
"""Create an Anthropic translation config for testing."""
146166
return TranslationConfig(
147167
provider_clients=mock_provider_clients,
@@ -154,7 +174,7 @@ def translation_config_anthropic(mock_provider_clients):
154174

155175

156176
@pytest.fixture
157-
def translation_config_deepseek(mock_provider_clients):
177+
def translation_config_deepseek(mock_provider_clients: ProviderClients) -> TranslationConfig:
158178
"""Create a DeepSeek translation config for testing."""
159179
return TranslationConfig(
160180
provider_clients=mock_provider_clients,
@@ -167,25 +187,31 @@ def translation_config_deepseek(mock_provider_clients):
167187

168188

169189
@pytest.fixture
170-
def translation_service_openai(translation_config_openai):
190+
def translation_service_openai(translation_config_openai: TranslationConfig) -> TranslationService:
171191
"""Create an OpenAI translation service for testing."""
172192
return TranslationService(config=translation_config_openai)
173193

174194

175195
@pytest.fixture
176-
def translation_service_anthropic(translation_config_anthropic):
196+
def translation_service_azure_openai(translation_config_azure_openai: TranslationConfig) -> TranslationService:
197+
"""Create an Azure OpenAI translation service for testing."""
198+
return TranslationService(config=translation_config_azure_openai)
199+
200+
201+
@pytest.fixture
202+
def translation_service_anthropic(translation_config_anthropic: TranslationConfig) -> TranslationService:
177203
"""Create an Anthropic translation service for testing."""
178204
return TranslationService(config=translation_config_anthropic)
179205

180206

181207
@pytest.fixture
182-
def translation_service_deepseek(translation_config_deepseek):
208+
def translation_service_deepseek(translation_config_deepseek: TranslationConfig) -> TranslationService:
183209
"""Create a DeepSeek translation service for testing."""
184210
return TranslationService(config=translation_config_deepseek)
185211

186212

187213
@patch('requests.get')
188-
def test_get_openai_models(mock_get, mock_provider_clients):
214+
def test_get_openai_models(mock_get, mock_provider_clients: ProviderClients):
189215
"""Test getting OpenAI models."""
190216
# Setup mock response
191217
mock_response = MagicMock()
@@ -206,8 +232,30 @@ def test_get_openai_models(mock_get, mock_provider_clients):
206232
assert "gpt-4" in models
207233

208234

235+
@patch('requests.get')
236+
def test_get_ayure_openai_models(mock_get, mock_provider_clients: ProviderClients):
237+
"""Test getting OpenAI models."""
238+
# Setup mock response
239+
mock_response = MagicMock()
240+
mock_response.json.return_value = AZURE_OPENAI_MODELS_RESPONSE
241+
mock_response.raise_for_status = MagicMock()
242+
mock_get.return_value = mock_response
243+
244+
# Mock the OpenAI client's models.list method
245+
models_list_mock = MagicMock()
246+
models_list_mock.data = [MagicMock(id="gpt-4"), MagicMock(id="gpt-3.5-turbo")]
247+
mock_provider_clients.azure_openai_client.models.list.return_value = models_list_mock
248+
249+
# Call the function
250+
model_manager = ModelManager()
251+
models = model_manager.get_available_models(mock_provider_clients, ModelProvider.AZURE_OPENAI)
252+
253+
# Assert models are returned correctly
254+
assert "gpt-3.5-turbo" in models
255+
256+
209257
@responses.activate
210-
def test_get_anthropic_models(mock_provider_clients):
258+
def test_get_anthropic_models(mock_provider_clients: ProviderClients):
211259
"""Test getting Anthropic models."""
212260
# Setup mock response
213261
responses.add(
@@ -227,7 +275,7 @@ def test_get_anthropic_models(mock_provider_clients):
227275

228276

229277
@responses.activate
230-
def test_get_deepseek_models(mock_provider_clients):
278+
def test_get_deepseek_models(mock_provider_clients: ProviderClients):
231279
"""Test getting DeepSeek models."""
232280
# Setup mock response
233281
responses.add(
@@ -247,7 +295,7 @@ def test_get_deepseek_models(mock_provider_clients):
247295

248296

249297
@patch('python_gpt_po.services.translation_service.requests.post')
250-
def test_translate_bulk_openai(mock_post, translation_service_openai):
298+
def test_translate_bulk_openai(mock_post, translation_service_openai: TranslationService):
251299
"""Test bulk translation with OpenAI."""
252300
# Setup mock response
253301
mock_response = MagicMock()
@@ -267,7 +315,27 @@ def test_translate_bulk_openai(mock_post, translation_service_openai):
267315

268316

269317
@patch('python_gpt_po.services.translation_service.requests.post')
270-
def test_translate_bulk_anthropic(mock_post, translation_service_anthropic):
318+
def test_translate_bulk_azure_openai(mock_post, translation_service_azure_openai: TranslationService):
319+
"""Test bulk translation with OpenAI."""
320+
# Setup mock response
321+
mock_response = MagicMock()
322+
mock_response.json.return_value = AZURE_OPENAI_TRANSLATION_RESPONSE
323+
mock_post.return_value = mock_response
324+
325+
# Call function
326+
translation_service_azure_openai.config.provider_clients.azure_openai_client.chat.completions.create.return_value = MagicMock(
327+
choices=[MagicMock(message=MagicMock(content='["Bonjour", "Monde", "Bienvenue dans notre application", "Au revoir"]'))]
328+
)
329+
330+
texts = ["Hello", "World", "Welcome to our application", "Goodbye"]
331+
translations = translation_service_azure_openai.translate_bulk(texts, "fr", "test.po")
332+
333+
# Assert translations are correct
334+
assert translations == ["Bonjour", "Monde", "Bienvenue dans notre application", "Au revoir"]
335+
336+
337+
@patch('python_gpt_po.services.translation_service.requests.post')
338+
def test_translate_bulk_anthropic(mock_post, translation_service_anthropic: TranslationService):
271339
"""Test bulk translation with Anthropic."""
272340
# Setup mock client response
273341
translation_service_anthropic.config.provider_clients.anthropic_client.messages.create.return_value = MagicMock(
@@ -282,7 +350,7 @@ def test_translate_bulk_anthropic(mock_post, translation_service_anthropic):
282350

283351

284352
@responses.activate
285-
def test_translate_bulk_deepseek(translation_service_deepseek):
353+
def test_translate_bulk_deepseek(translation_service_deepseek: TranslationService):
286354
"""Test bulk translation with DeepSeek."""
287355
# Setup mock response
288356
responses.add(
@@ -307,7 +375,7 @@ def test_translate_bulk_deepseek(translation_service_deepseek):
307375
assert translations == ["Bonjour", "Monde", "Bienvenue dans notre application", "Au revoir"]
308376

309377

310-
def test_clean_json_response(translation_service_deepseek):
378+
def test_clean_json_response(translation_service_deepseek: TranslationService):
311379
"""Test cleaning JSON responses from different formats."""
312380
# Test markdown code block format
313381
markdown_json = "```json\n[\"Bonjour\", \"Monde\"]\n```"
@@ -326,9 +394,12 @@ def test_clean_json_response(translation_service_deepseek):
326394

327395

328396
@patch('polib.pofile')
329-
def test_process_po_file_all_providers(mock_pofile, translation_service_openai,
330-
translation_service_anthropic,
331-
translation_service_deepseek, temp_po_file):
397+
def test_process_po_file_all_providers(mock_pofile,
398+
translation_service_openai: TranslationService,
399+
translation_service_anthropic: TranslationService,
400+
translation_service_deepseek: TranslationService,
401+
translation_service_azure_openai: TranslationService,
402+
temp_po_file: str):
332403
"""Test processing a PO file with all providers."""
333404
# Create a mock PO file
334405
mock_po = MagicMock()
@@ -346,7 +417,10 @@ def test_process_po_file_all_providers(mock_pofile, translation_service_openai,
346417
mock_pofile.return_value = mock_po
347418

348419
# Setup translation method mocks for each service
349-
for i, service in enumerate([translation_service_openai, translation_service_anthropic, translation_service_deepseek]):
420+
for i, service in enumerate([translation_service_openai,
421+
translation_service_anthropic,
422+
translation_service_deepseek,
423+
translation_service_azure_openai]):
350424
# Create a fresh mock for each service
351425
mock_po_new = MagicMock()
352426
mock_po_new.__iter__.return_value = mock_entries
@@ -365,8 +439,9 @@ def test_process_po_file_all_providers(mock_pofile, translation_service_openai,
365439
service.get_translations.assert_called_once()
366440
mock_po_new.save.assert_called_once()
367441

442+
368443
@patch('python_gpt_po.services.po_file_handler.POFileHandler.disable_fuzzy_translations')
369-
def test_fuzzy_flag_handling(mock_disable_fuzzy, translation_service_openai, temp_po_file):
444+
def test_fuzzy_flag_handling(mock_disable_fuzzy, translation_service_openai: TranslationService, temp_po_file):
370445
"""Test handling of fuzzy translations."""
371446
# Enable fuzzy flag
372447
translation_service_openai.config.fuzzy = True
@@ -388,7 +463,10 @@ def test_fuzzy_flag_handling(mock_disable_fuzzy, translation_service_openai, tem
388463

389464

390465
def test_validation_model_connection_all_providers(
391-
translation_service_openai, translation_service_anthropic, translation_service_deepseek
466+
translation_service_openai: TranslationService,
467+
translation_service_anthropic: TranslationService,
468+
translation_service_deepseek: TranslationService,
469+
translation_service_azure_openai: TranslationService
392470
):
393471
"""Test validating connection to all providers."""
394472
# Configure OpenAI mock
@@ -398,6 +476,10 @@ def test_validation_model_connection_all_providers(
398476
translation_service_anthropic.config.provider_clients.anthropic_client.messages.create.return_value = MagicMock()
399477

400478
# Configure DeepSeek mock
479+
480+
# Configure Azure OpenAI mock
481+
translation_service_azure_openai.config.provider_clients.azure_openai_client.chat.completions.create.return_value = MagicMock()
482+
401483
with patch('requests.post') as mock_post:
402484
mock_response = MagicMock()
403485
mock_response.raise_for_status = MagicMock()
@@ -407,11 +489,12 @@ def test_validation_model_connection_all_providers(
407489
assert translation_service_openai.validate_provider_connection() is True
408490
assert translation_service_anthropic.validate_provider_connection() is True
409491
assert translation_service_deepseek.validate_provider_connection() is True
492+
assert translation_service_azure_openai.validate_provider_connection() is True
410493

411494

412495
@patch('os.walk')
413496
@patch('polib.pofile')
414-
def test_scan_and_process_po_files(mock_pofile, mock_walk, translation_service_openai):
497+
def test_scan_and_process_po_files(mock_pofile, mock_walk, translation_service_openai: TranslationService):
415498
"""Test scanning and processing PO files."""
416499
# Setup mock directory structure
417500
mock_walk.return_value = [

0 commit comments

Comments
 (0)