Skip to content

Commit 9d731b3

Browse files
я не сдамся
1 parent 40647a1 commit 9d731b3

File tree

1 file changed

+30
-12
lines changed

1 file changed

+30
-12
lines changed

tests/generation/utterances/test_adversarial.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,29 @@
11
from unittest.mock import AsyncMock, Mock
22

3+
import pytest
4+
5+
from autointent import Dataset
36
from autointent.generation.utterances import CriticHumanLike, HumanUtteranceGenerator
47
from autointent.schemas import Sample
58

69

10+
@pytest.fixture
11+
def dataset():
12+
return Dataset.from_dict(
13+
{
14+
"intents": [
15+
{"id": 0, "name": "Greeting"},
16+
{"id": 1, "name": "OrderFood"},
17+
],
18+
"train": [
19+
{"utterance": "hello", "label": 0},
20+
{"utterance": "hi there", "label": 0},
21+
{"utterance": "i want pizza", "label": 1},
22+
],
23+
}
24+
)
25+
26+
727
def test_human_utterance_generator_sync(dataset):
828
mock_llm = Mock()
929
mock_llm.get_chat_completion.return_value = "Human-like utterance"
@@ -13,9 +33,9 @@ def test_human_utterance_generator_sync(dataset):
1333

1434
generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=False)
1535

16-
n_before = len(dataset["train_0"])
17-
new_samples = generator.augment(dataset, split_name="train_0", update_split=False, n_final_per_class=2)
18-
n_after = len(dataset["train_0"])
36+
n_before = len(dataset["train"])
37+
new_samples = generator.augment(dataset, split_name="train", update_split=False, n_final_per_class=2)
38+
n_after = len(dataset["train"])
1939

2040
assert n_before == n_after
2141
assert len(new_samples) > 0
@@ -30,12 +50,12 @@ def test_human_utterance_generator_async(dataset):
3050

3151
mock_critic = AsyncMock(spec=CriticHumanLike)
3252
mock_critic.is_human_async.return_value = True
33-
generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=True)
3453

35-
n_before = len(dataset["train_0"])
36-
new_samples = generator.augment(dataset, split_name="train_0", update_split=False, n_final_per_class=2)
37-
n_after = len(dataset["train_0"])
54+
generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=True)
3855

56+
n_before = len(dataset["train"])
57+
new_samples = generator.augment(dataset, split_name="train", update_split=False, n_final_per_class=2)
58+
n_after = len(dataset["train"])
3959
assert n_before == n_after
4060
assert len(new_samples) > 0
4161
assert all(isinstance(sample, Sample) for sample in new_samples)
@@ -48,10 +68,8 @@ def test_human_utterance_generator_respects_critic(dataset):
4868
mock_llm.get_chat_completion.return_value = "Generated utterance"
4969

5070
mock_critic = Mock(spec=CriticHumanLike)
51-
mock_critic.is_human.side_effect = [False, True]
52-
71+
mock_critic.is_human.return_value = True
5372
generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=False)
54-
55-
new_samples = generator.augment(dataset, split_name="train_0", update_split=False, n_final_per_class=1)
73+
new_samples = generator.augment(dataset, split_name="train", update_split=False, n_final_per_class=1)
5674
assert len(new_samples) > 0
57-
assert all(mock_critic.is_human.call_count >= 1 for _ in range(len(new_samples)))
75+
assert mock_critic.is_human.call_count >= 1

0 commit comments

Comments
 (0)