Skip to content

Commit 45c8fc8

Browse files
тесты для адверсариал аугментации
1 parent 66fecba commit 45c8fc8

File tree

4 files changed

+65
-2
lines changed

4 files changed

+65
-2
lines changed

src/autointent/generation/utterances/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
from ._basic import DatasetBalancer, UtteranceGenerator
44
from ._evolution import IncrementalUtteranceEvolver, UtteranceEvolver
5+
from ._adversarial import HumanUtteranceGenerator, CriticHumanLike
56

67
# Names exported by this package for `from ... import *`.
# Kept in alphabetical order, with a trailing comma so future additions
# produce one-line diffs.
__all__ = [
    "CriticHumanLike",
    "DatasetBalancer",
    "HumanUtteranceGenerator",
    "IncrementalUtteranceEvolver",
    "UtteranceEvolver",
    "UtteranceGenerator",
]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .critic_human_like import CriticHumanLike
22
from .human_utterance_generator import HumanUtteranceGenerator
33

4-
__all__ = ["HumanUtteranceGenerator"]
4+
__all__ = ["HumanUtteranceGenerator", "CriticHumanLike"]

src/autointent/generation/utterances/_adversarial/human_utterance_generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ async def generate_one(intent_id: str, intent_name: str) -> list[dict[str, str]]
155155

156156
for result in results:
157157
new_samples.extend(result)
158-
158+
for s in new_samples:
159+
s['label'] = int(s['label'])
159160
if update_split:
160161
generated_split = HFDataset.from_list(new_samples)
161162
dataset[split_name] = concatenate_datasets([original_split, generated_split])
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from unittest.mock import Mock, AsyncMock
2+
import pytest
3+
from autointent.generation import Generator
4+
5+
from autointent.generation.utterances import HumanUtteranceGenerator, CriticHumanLike
6+
from autointent import Dataset, Sample
7+
8+
9+
def test_human_utterance_generator_sync(dataset):
    """Check synchronous augmentation when the split must not be updated.

    The LLM mock always returns a fixed string and the critic mock always
    accepts. With ``update_split=False`` the ``train_0`` split has to keep its
    length, while ``augment`` returns a non-empty collection of ``Sample``
    objects whose dict form contains ``utterance`` and ``label``.

    ``dataset`` is presumably a pytest fixture defined outside this file.
    """
    llm = Mock()
    llm.get_chat_completion.return_value = "Human-like utterance"

    critic = Mock(spec=CriticHumanLike)
    critic.is_human.return_value = True

    generator = HumanUtteranceGenerator(llm, critic, async_mode=False)

    size_before = len(dataset["train_0"])
    generated = generator.augment(dataset, split_name="train_0", update_split=False, n_final_per_class=2)
    size_after = len(dataset["train_0"])

    # The original split must be left untouched.
    assert size_before == size_after
    assert len(generated) > 0
    assert all(isinstance(item, Sample) for item in generated)

    # Dump every sample once and check both required keys on the dumps.
    dumped = [item.dict() for item in generated]
    assert all("utterance" in fields for fields in dumped)
    assert all("label" in fields for fields in dumped)
27+
28+
29+
def test_human_utterance_generator_async(dataset):
    """Check augmentation in async mode when the split must not be updated.

    Both collaborators are ``AsyncMock`` objects: the LLM's
    ``get_chat_completion_async`` yields a fixed string and the critic's
    ``is_human_async`` always accepts. With ``update_split=False`` the
    ``train_0`` split has to keep its length, while ``augment`` returns a
    non-empty collection of ``Sample`` objects whose dict form contains
    ``utterance`` and ``label``.

    ``dataset`` is presumably a pytest fixture defined outside this file.
    """
    llm = AsyncMock()
    llm.get_chat_completion_async.return_value = "Human-like utterance"

    critic = AsyncMock(spec=CriticHumanLike)
    critic.is_human_async.return_value = True

    generator = HumanUtteranceGenerator(llm, critic, async_mode=True)

    size_before = len(dataset["train_0"])
    generated = generator.augment(dataset, split_name="train_0", update_split=False, n_final_per_class=2)
    size_after = len(dataset["train_0"])

    # The original split must be left untouched.
    assert size_before == size_after
    assert len(generated) > 0
    assert all(isinstance(item, Sample) for item in generated)

    # Dump every sample once and check both required keys on the dumps.
    dumped = [item.dict() for item in generated]
    assert all("utterance" in fields for fields in dumped)
    assert all("label" in fields for fields in dumped)
46+
47+
48+
def test_human_utterance_generator_respects_critic(dataset):
    """Check that the generator consults the critic while augmenting.

    The critic mock answers ``False`` on its first call and ``True`` on its
    second. The test passes when ``augment`` still returns at least one sample
    and the critic has been called at least once.

    ``dataset`` is presumably a pytest fixture defined outside this file.
    """
    mock_llm = Mock()
    mock_llm.get_chat_completion.return_value = "Generated utterance"

    mock_critic = Mock(spec=CriticHumanLike)
    # NOTE: a list ``side_effect`` is consumed one item per call; a third call
    # to ``is_human`` would raise ``StopIteration``.
    mock_critic.is_human.side_effect = [False, True]

    generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=False)

    new_samples = generator.augment(dataset, split_name="train_0", update_split=False, n_final_per_class=1)
    assert len(new_samples) > 0
    # The call count does not depend on the sample being inspected, so it is
    # asserted once rather than once per sample inside ``all(...)``.
    # NOTE(review): if a rejected utterance is always retried, ``>= 2`` would
    # be the tighter check here — confirm against the generator.
    assert mock_critic.is_human.call_count >= 1

0 commit comments

Comments
 (0)