Skip to content

Commit 98c7c84

Browse files
committed
Fixed tests
1 parent 0791b99 commit 98c7c84

File tree

3 files changed

+52
-33
lines changed

3 files changed

+52
-33
lines changed

src/utils/azureopenai/client.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
from dataclasses import dataclass
32
from typing import Any, Dict, List, Optional
43
import httpx
@@ -31,7 +30,7 @@ async def make_request(
3130
top_p: float = 1.0,
3231
tools: Optional[Any] = None
3332
):
34-
if len(messages) == 1 and messages[0].raw_messages:
33+
if len(messages) == 1 and hasattr(messages[0], "raw_messages"):
3534
message_data = messages[0].raw_messages
3635
else:
3736
message_data = []
@@ -82,4 +81,17 @@ async def make_request(
8281

8382
raise Exception(f"API error ({response.status_code}): {error_msg}")
8483

85-
return response.json()
84+
return response.json()
85+
86+
async def get_completion(
87+
self,
88+
messages: List[Dict[str, Any]],
89+
**kwargs
90+
) -> str:
91+
"""Get just the completion text from a chat request."""
92+
response = await self.make_request(messages, **kwargs)
93+
94+
if not response.get("choices") or len(response["choices"]) == 0:
95+
raise Exception("No completion choices returned from API")
96+
97+
return response["choices"][0]["message"]["content"]

tests/utils/azureopenai/test_azureopenai_chat.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from utils.azureopenai.chat import Chat
3-
from unittest.mock import MagicMock
3+
from unittest.mock import MagicMock, AsyncMock
44
import os
55

66

@@ -16,10 +16,11 @@ def test_chat_create_no_env(monkeypatch):
1616
Chat.create()
1717

1818

19-
def test_chat_send_messages(monkeypatch):
19+
@pytest.mark.asyncio
20+
async def test_chat_send_messages(monkeypatch):
2021
chat = Chat(MagicMock())
21-
chat.client.make_request.return_value = {"choices": [{"message": {"content": "result"}}]}
22-
out = chat.send_messages([{"role": "user", "content": "hi"}])
22+
chat.client.make_request = AsyncMock(return_value={"choices": [{"message": {"content": "result"}}]})
23+
out = await chat.send_messages([{"role": "user", "content": "hi"}])
2324
assert out["choices"][0]["message"]["content"] == "result"
2425
chat.client.make_request.assert_called()
2526

tests/utils/azureopenai/test_azureopenai_client.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,74 @@
11
import pytest
22
from utils.azureopenai.client import Client
3-
from unittest.mock import MagicMock
3+
from unittest.mock import MagicMock, AsyncMock
44
from types import SimpleNamespace
55

66

7-
def test_client_make_request_success(monkeypatch):
7+
@pytest.mark.asyncio
8+
async def test_client_make_request_success(monkeypatch):
89
client = Client(api_key="k", endpoint="http://x", timeout=1)
9-
mock_post = MagicMock()
10+
mock_post = AsyncMock()
1011
mock_resp = MagicMock()
1112
mock_resp.status_code = 200
1213
mock_resp.json.return_value = {"choices": [{"message": {"content": "hi"}}]}
1314
mock_post.return_value = mock_resp
1415
monkeypatch.setattr(client.http_client, "post", mock_post)
15-
# Use a message with raw_messages attribute to avoid AttributeError
16-
msg = SimpleNamespace(role="user", content="hi", raw_messages=[{"role": "user", "content": "hi"}])
17-
out = client.make_request([msg])
16+
17+
# Test with a regular dict instead of SimpleNamespace
18+
out = await client.make_request([{"role": "user", "content": "hi"}])
1819
assert out["choices"][0]["message"]["content"] == "hi"
1920
mock_post.assert_called()
2021

2122

22-
def test_client_make_request_error(monkeypatch):
23+
@pytest.mark.asyncio
24+
async def test_client_make_request_error(monkeypatch):
2325
client = Client(api_key="k", endpoint="http://x", timeout=1)
24-
mock_post = MagicMock()
26+
mock_post = AsyncMock()
2527
mock_resp = MagicMock()
2628
mock_resp.status_code = 400
2729
mock_resp.json.return_value = {"error": {"message": "fail"}}
2830
mock_post.return_value = mock_resp
2931
monkeypatch.setattr(client.http_client, "post", mock_post)
30-
# Use a message with raw_messages attribute to avoid AttributeError
31-
msg = SimpleNamespace(role="user", content="hi", raw_messages=[{"role": "user", "content": "hi"}])
32+
33+
# Test with a regular dict instead of SimpleNamespace
3234
with pytest.raises(Exception) as exc:
33-
client.make_request([msg])
35+
await client.make_request([{"role": "user", "content": "hi"}])
3436
assert "API error" in str(exc.value)
3537

3638

37-
def test_client_get_completion(monkeypatch):
39+
@pytest.mark.asyncio
40+
async def test_client_get_completion(monkeypatch):
3841
client = Client(api_key="k", endpoint="http://x", timeout=1)
39-
mock_make = MagicMock()
40-
mock_make.return_value = {"choices": [{"message": {"content": "hi"}}]}
41-
client.make_request = mock_make
42-
out = client.get_completion([{"role": "user", "content": "hi"}])
42+
# Add the get_completion method implementation for testing
43+
async def mock_make_request(*args, **kwargs):
44+
return {"choices": [{"message": {"content": "hi"}}]}
45+
client.make_request = mock_make_request
46+
out = await client.get_completion([{"role": "user", "content": "hi"}])
4347
assert out == "hi"
4448

4549

46-
def test_client_get_completion_no_choices(monkeypatch):
50+
@pytest.mark.asyncio
51+
async def test_client_get_completion_no_choices(monkeypatch):
4752
client = Client(api_key="k", endpoint="http://x", timeout=1)
48-
mock_make = MagicMock()
49-
mock_make.return_value = {"choices": []}
50-
client.make_request = mock_make
53+
# Add the get_completion method implementation for testing
54+
async def mock_make_request(*args, **kwargs):
55+
return {"choices": []}
56+
client.make_request = mock_make_request
5157
with pytest.raises(Exception):
52-
client.get_completion([{"role": "user", "content": "hi"}])
58+
await client.get_completion([{"role": "user", "content": "hi"}])
5359

5460

55-
def test_client_make_request_with_tools(monkeypatch):
61+
@pytest.mark.asyncio
62+
async def test_client_make_request_with_tools(monkeypatch):
5663
client = Client(api_key="k", endpoint="http://x", timeout=1)
57-
mock_post = MagicMock()
64+
mock_post = AsyncMock()
5865
mock_resp = MagicMock()
5966
mock_resp.status_code = 200
6067
mock_resp.json.return_value = {"choices": [{"message": {"content": "hi"}}]}
6168
mock_post.return_value = mock_resp
6269
monkeypatch.setattr(client.http_client, "post", mock_post)
63-
# Use a message with raw_messages attribute to avoid AttributeError
64-
msg = SimpleNamespace(role="user", content="hi", raw_messages=[{"role": "user", "content": "hi"}])
65-
out = client.make_request([msg], tools=[{"type": "tool"}])
70+
71+
out = await client.make_request([{"role": "user", "content": "hi"}], tools=[{"type": "tool"}])
6672
assert out["choices"][0]["message"]["content"] == "hi"
6773
mock_post.assert_called()
6874

0 commit comments

Comments
 (0)