Skip to content

Commit 9e4b905

Browse files
committed
feat: tests for generator
1 parent ae65a55 commit 9e4b905

File tree

3 files changed

+53
-21
lines changed

3 files changed

+53
-21
lines changed

tests/generation/utterances/test_basic_synthesizer.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from unittest.mock import AsyncMock, Mock
22

3-
import pytest
4-
53
from autointent.generation.utterances import SynthesizerChatTemplate, UtteranceGenerator
64

75

@@ -53,23 +51,6 @@ def test_on_dataset(dataset):
5351
assert n_before + len(new_samples) == n_after
5452
assert len(new_samples) == len(dataset.intents)
5553

56-
@pytest.mark.asyncio
57-
async def test_default_chat_template_async(dataset):
58-
template = SynthesizerChatTemplate(dataset, split="train_0")
59-
prompt = template(dataset.intents[0], n_examples=1)
60-
for msg in prompt:
61-
assert not has_unfilled_fields(msg["content"])
62-
assert "extra_instructions" not in prompt
63-
64-
65-
@pytest.mark.asyncio
66-
async def test_extra_instructions_async(dataset):
67-
template = SynthesizerChatTemplate(dataset, split="train_0", extra_instructions="football")
68-
prompt = template(dataset.intents[0], n_examples=1)[0]["content"]
69-
assert "extra_instructions" not in prompt
70-
assert "football" in prompt
71-
72-
7354
def test_on_dataset_async(dataset):
7455
mock_llm = AsyncMock()
7556
mock_llm.get_chat_completion_async.return_value = "1. LLM answer"

tests/generation/utterances/test_evolver.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import Mock
1+
from unittest.mock import AsyncMock, Mock
22

33
from autointent.generation.utterances import AbstractEvolution, UtteranceEvolver
44

@@ -32,3 +32,27 @@ def test_on_dataset(dataset):
3232

3333
assert n_before + len(new_samples) == n_after
3434
assert len(new_samples) == n_before
35+
36+
def test_on_dataset_evolver_async(dataset):
37+
mock_llm = AsyncMock()
38+
mock_llm.get_chat_completion_async.return_value = "LLM answer"
39+
40+
split_name = "train_0"
41+
42+
template = AbstractEvolution()
43+
augmenter = UtteranceEvolver(mock_llm, [template], async_mode=True)
44+
45+
n_before = len(dataset[split_name])
46+
new_samples = augmenter.augment(dataset, split_name=split_name, n_evolutions=1, update_split=False)
47+
n_after = len(dataset[split_name])
48+
49+
assert n_before == n_after
50+
assert len(new_samples) == n_before
51+
assert all(sample.utterance == "LLM answer" for sample in new_samples)
52+
53+
n_before = len(dataset[split_name])
54+
new_samples = augmenter.augment(dataset, split_name=split_name, n_evolutions=1, update_split=True)
55+
n_after = len(dataset[split_name])
56+
57+
assert n_before + len(new_samples) == n_after
58+
assert len(new_samples) == n_before

tests/generation/utterances/test_generator.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import MagicMock, patch
1+
from unittest.mock import AsyncMock, MagicMock, patch
22

33
import pytest
44

@@ -36,3 +36,30 @@ def test_get_chat_completion(mock_openai_client):
3636

3737
assert response == "Test response"
3838
mock_openai_client.return_value.chat.completions.create.assert_called_once()
39+
40+
41+
@pytest.mark.asyncio
42+
async def test_get_chat_completion_async():
43+
test_messages = [Message(role="user", content="Hello, how are you?")]
44+
45+
mock_response = MagicMock()
46+
mock_response.choices = [MagicMock(message=MagicMock(content="I'm fine, thank you!"))]
47+
48+
with patch("openai.AsyncOpenAI") as mock_async_openai:
49+
mock_instance = mock_async_openai.return_value
50+
mock_instance.chat.completions.create = AsyncMock(return_value=mock_response)
51+
52+
generator = Generator()
53+
54+
result = await generator.get_chat_completion_async(test_messages)
55+
56+
assert result == "I'm fine, thank you!"
57+
58+
mock_instance.chat.completions.create.assert_awaited_once_with(
59+
messages=test_messages,
60+
model=generator.model_name,
61+
max_tokens=150,
62+
n=1,
63+
stop=None,
64+
temperature=0.7,
65+
)

0 commit comments

Comments
 (0)