
Commit 40d2986

adversarial augmentation (#251)

* adversarial augmentation
* async mode, lint, typing, etc
* async update (aiometer); roughly 2-3x speedup in Google Colab
* async
* mypy fix
* mypy again and init
* fix
* move to proper directory
* run formatter
* tests for adversarial augmentation
* bug fixes
* another bug
* I won't give up
* add disclaimer

---------

Co-authored-by: voorhs <[email protected]>
1 parent 13d8c40 commit 40d2986

File tree

5 files changed: +359 −0 lines changed

src/autointent/generation/utterances/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -1,10 +1,13 @@
 """Generative methods for enriching dataset with synthetic samples."""

+from ._adversarial import CriticHumanLike, HumanUtteranceGenerator
 from ._basic import DatasetBalancer, UtteranceGenerator
 from ._evolution import IncrementalUtteranceEvolver, UtteranceEvolver

 __all__ = [
+    "CriticHumanLike",
     "DatasetBalancer",
+    "HumanUtteranceGenerator",
     "IncrementalUtteranceEvolver",
     "UtteranceEvolver",
     "UtteranceGenerator",

src/autointent/generation/utterances/_adversarial/__init__.py (new file)

Lines changed: 4 additions & 0 deletions

from .critic_human_like import CriticHumanLike
from .human_utterance_generator import HumanUtteranceGenerator

__all__ = ["CriticHumanLike", "HumanUtteranceGenerator"]

src/autointent/generation/utterances/_adversarial/critic_human_like.py (new file)

Lines changed: 83 additions & 0 deletions

"""CriticHumanLike class for distinguishing human vs generated utterances."""

from typing import Literal

from pydantic import BaseModel

from autointent.generation import Generator
from autointent.generation.chat_templates import Message, Role


class CriticResponse(BaseModel):
    """Structured answer."""

    reasoning: str
    label: Literal["human", "generated"]


class CriticHumanLike:
    """A simple critic that classifies user utterances as either 'human' or 'generated'.

    It uses an LLM-based binary classification prompt.
    """

    def __init__(self, generator: Generator, max_retries: int = 3) -> None:
        """Initialize the CriticHumanLike.

        Args:
            generator: Wrapper for the LLM API to generate classification responses.
            max_retries: Maximum number of attempts to retry classification if the response is invalid.
        """
        self.generator = generator
        self.max_retries = max_retries

    def build_classification_prompt(self, example: str, intent_name: str) -> Message:
        """Build the prompt asking the LLM to classify the utterance.

        Args:
            example: The user utterance to classify.
            intent_name: The name of the intent associated with the utterance.

        Returns:
            Message: A formatted message prompt for classification.
        """
        content = (
            "You are a critic that determines whether a user utterance was written by a human or "
            "generated by a language model.\n\n"
            f"Intent: {intent_name}\n"
            f'Utterance: "{example}"\n\n'
            "Here is an example of a human-written utterance for this intent:\n"
            '"Could you please help me find the nearest coffee shop?"\n\n'
            "Respond in **JSON format** with two keys:\n"
            "- `reasoning`: a short chain-of-thought where you explain your logic\n"
            "- `label`: must be either `human` or `generated`\n"
            "Example:\n"
            "{\n"
            '  "reasoning": "The phrasing includes casual contractions and natural hesitation. The utterance '
            'flows similarly to how a human would speak spontaneously.",\n'
            '  "label": "human"\n'
            "}"
        )
        return Message(role=Role.USER, content=content)

    def is_human(self, utterance: str, intent_name: str) -> bool:
        """Classify an utterance synchronously.

        Args:
            utterance: The utterance to evaluate.
            intent_name: The associated intent.

        Returns:
            bool: True if classified as human, False otherwise.
        """
        message = self.build_classification_prompt(utterance, intent_name)
        response = self.generator.get_structured_output_sync(
            messages=[message], output_model=CriticResponse, max_retries=self.max_retries
        )
        return response.label == "human"

    async def is_human_async(self, utterance: str, intent_name: str) -> bool:
        """Asynchronous counterpart of `is_human`."""
        message = self.build_classification_prompt(utterance, intent_name)
        response = await self.generator.get_structured_output_async(
            messages=[message], output_model=CriticResponse, max_retries=self.max_retries
        )
        return response.label == "human"
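
For orientation, here is a minimal usage sketch of the critic. The Generator construction below is an illustrative assumption; this commit only defines CriticHumanLike itself, and the real Generator setup (model name, API credentials) lives elsewhere in autointent.

    from autointent.generation import Generator
    from autointent.generation.utterances import CriticHumanLike

    # Hypothetical setup: Generator arguments are assumptions, not part of this commit.
    generator = Generator()

    critic = CriticHumanLike(generator, max_retries=3)

    # The critic prompts the LLM for a structured {"reasoning", "label"} response
    # and returns True when the label is "human".
    print(critic.is_human("i want pizza", intent_name="OrderFood"))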

src/autointent/generation/utterances/_adversarial/human_utterance_generator.py (new file)

Lines changed: 194 additions & 0 deletions

import asyncio
import logging
import random
from collections import defaultdict
from functools import partial
from typing import Any

import aiometer
from datasets import Dataset as HFDataset
from datasets import concatenate_datasets

from autointent import Dataset
from autointent.custom_types import Split
from autointent.generation import Generator
from autointent.generation.chat_templates._evolution_templates_schemas import Message, Role
from autointent.schemas import Sample

from .critic_human_like import CriticHumanLike

logger = logging.getLogger(__name__)


class HumanUtteranceGenerator:
    """Generator of human-like utterances.

    This class generates new user utterances from seed examples, aiming for natural,
    human-sounding phrasing while preserving the original intent. The generation process
    is iterative and attempts to bypass a critic that identifies machine-generated text.

    .. warning:: This method is experimental and can yield inferior data quality.

    """

    def __init__(
        self,
        generator: Generator,
        critic: CriticHumanLike,
        async_mode: bool = False,
        max_at_once: int = 5,
        max_per_second: int = 10,
    ) -> None:
        """Initialize the HumanUtteranceGenerator.

        Args:
            generator: Wrapper for the LLM API used to generate utterances.
            critic: Critic to determine whether the generated utterance sounds human-like.
            async_mode: Whether to use asynchronous mode for generation.
            max_at_once: Maximum number of concurrent async tasks.
            max_per_second: Maximum number of tasks per second.
        """
        self.generator = generator
        self.critic = critic
        self.async_mode = async_mode
        self.max_at_once = max_at_once
        self.max_per_second = max_per_second

    def augment(
        self, dataset: Dataset, split_name: str = Split.TRAIN, update_split: bool = True, n_final_per_class: int = 5
    ) -> list[Sample]:
        """Generate human-like utterances for each intent by iteratively refining machine-generated candidates.

        Args:
            dataset: The dataset to augment.
            split_name: The name of the split to augment (e.g., 'train').
            update_split: Whether to update the dataset split with the new utterances.
            n_final_per_class: Number of successful utterances to generate per intent.

        Returns:
            list[Sample]: List of newly generated samples.
        """
        if self.async_mode:
            return asyncio.run(
                self.augment_async(
                    dataset=dataset,
                    split_name=split_name,
                    update_split=update_split,
                    n_final_per_class=n_final_per_class,
                )
            )
        original_split = dataset[split_name]
        id_to_name = {intent.id: intent.name for intent in dataset.intents}
        new_samples = []

        class_to_samples = defaultdict(list)
        for sample in original_split:
            class_to_samples[sample["label"]].append(sample["utterance"])

        for intent_id, intent_name in id_to_name.items():
            if intent_name is None:
                logger.warning("Intent with id %s has no name! Skipping it...", intent_id)
                continue
            generated_count = 0
            attempt = 0

            seed_utterances = class_to_samples.get(intent_id, [])
            if not seed_utterances:
                continue

            # Allow up to three attempts per requested sample before giving up on this class.
            while generated_count < n_final_per_class and attempt < n_final_per_class * 3:
                attempt += 1
                n_seeds = min(3, len(seed_utterances))
                seed_examples = random.sample(seed_utterances, k=n_seeds)
                rejected: list[str] = []

                # Up to three adversarial rounds: each rejected candidate is fed back into the prompt.
                for _ in range(3):
                    prompt = self._build_adversarial_prompt(intent_name, seed_examples, rejected)
                    generated = self.generator.get_chat_completion([prompt]).strip()
                    if self.critic.is_human(generated, intent_name):
                        new_samples.append({Dataset.label_feature: intent_id, Dataset.utterance_feature: generated})
                        generated_count += 1
                        break
                    rejected.append(generated)
        if update_split:
            generated_split = HFDataset.from_list(new_samples)
            dataset[split_name] = concatenate_datasets([original_split, generated_split])

        return [Sample(**sample) for sample in new_samples]

    async def augment_async(
        self, dataset: Dataset, split_name: str = Split.TRAIN, update_split: bool = True, n_final_per_class: int = 5
    ) -> list[Sample]:
        """Asynchronous variant of `augment`; tasks are rate-limited via aiometer."""
        original_split = dataset[split_name]
        id_to_name = {intent.id: intent.name for intent in dataset.intents}
        new_samples = []

        class_to_samples = defaultdict(list)
        for sample in original_split:
            class_to_samples[sample["label"]].append(sample["utterance"])

        async def generate_one(intent_id: int, intent_name: str) -> list[dict[str, Any]]:
            generated: list[dict[str, Any]] = []
            attempts = 0
            seed_utterances = class_to_samples[intent_id]
            while len(generated) < n_final_per_class and attempts < n_final_per_class * 3:
                attempts += 1
                seed_examples = random.sample(seed_utterances, k=min(3, len(seed_utterances)))
                rejected: list[str] = []

                for _ in range(3):
                    prompt = self._build_adversarial_prompt(intent_name, seed_examples, rejected)
                    utterance = (await self.generator.get_chat_completion_async([prompt])).strip()
                    if await self.critic.is_human_async(utterance, intent_name):
                        generated.append({Dataset.label_feature: intent_id, Dataset.utterance_feature: utterance})
                        break
                    rejected.append(utterance)
            return generated

        # Keep intent ids as ints so lookups hit the int-keyed class_to_samples mapping.
        tasks = [
            partial(generate_one, intent_id, intent_name)
            for intent_id, intent_name in id_to_name.items()
            if class_to_samples.get(intent_id) and intent_name is not None
        ]

        results = await aiometer.run_all(
            tasks,
            max_at_once=self.max_at_once,
            max_per_second=self.max_per_second,
        )

        for result in results:
            new_samples.extend(result)
        if update_split:
            generated_split = HFDataset.from_list(new_samples)
            dataset[split_name] = concatenate_datasets([original_split, generated_split])

        return [Sample(**sample) for sample in new_samples]

    def _build_adversarial_prompt(self, intent_name: str, seed_examples: list[str], rejected: list[str]) -> Message:
        """Build a few-shot prompt.

        Build a few-shot prompt to guide the generator to create a new human-like utterance
        from scratch based on the intent name and example utterances.

        Args:
            intent_name: The intent of the utterance.
            seed_examples: List of 1-3 example utterances for the intent.
            rejected: List of previously rejected generations.

        Returns:
            Message: A formatted prompt instructing the generator to produce a new natural-sounding utterance.
        """
        rejected_block = "\n".join(f"- {r}" for r in rejected) if rejected else "None"
        examples_block = "\n".join(f'- "{ex}"' for ex in seed_examples)
        content = (
            f"Your task is to generate a new user utterance that fits the intent '{intent_name}'.\n\n"
            "Here are some examples of utterances for this intent:\n"
            f"{examples_block}\n\n"
            f"Preserve the original intent: '{intent_name}'.\n\n"
            f"The following previous attempts were classified as machine-generated and rejected:\n{rejected_block}\n\n"
            "Try to write something that would pass as written by a real human. Output a single version only.\n"
            "IMPORTANT: You must not repeat the example utterances verbatim."
        )
        return Message(role=Role.USER, content=content)
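
Putting both pieces together, a minimal end-to-end sketch. The Generator() call is an illustrative assumption (its real construction is outside this commit), and the toy dataset follows the shape of the test fixture below:

    from autointent import Dataset
    from autointent.generation import Generator
    from autointent.generation.utterances import CriticHumanLike, HumanUtteranceGenerator

    # Hypothetical setup: Generator construction is not defined by this commit.
    generator = Generator()

    dataset = Dataset.from_dict(
        {
            "intents": [{"id": 0, "name": "Greeting"}],
            "train": [
                {"utterance": "hello", "label": 0},
                {"utterance": "hi there", "label": 0},
            ],
        }
    )

    critic = CriticHumanLike(generator)
    augmenter = HumanUtteranceGenerator(generator, critic, async_mode=True, max_at_once=5, max_per_second=10)

    # Generates up to 5 critic-approved utterances per intent and appends them to the train split.
    new_samples = augmenter.augment(dataset, split_name="train", update_split=True, n_final_per_class=5)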

Tests for adversarial augmentation (new file)

Lines changed: 75 additions & 0 deletions

from unittest.mock import AsyncMock, Mock

import pytest

from autointent import Dataset
from autointent.generation.utterances import CriticHumanLike, HumanUtteranceGenerator
from autointent.schemas import Sample


@pytest.fixture
def dataset():
    return Dataset.from_dict(
        {
            "intents": [
                {"id": 0, "name": "Greeting"},
                {"id": 1, "name": "OrderFood"},
            ],
            "train": [
                {"utterance": "hello", "label": 0},
                {"utterance": "hi there", "label": 0},
                {"utterance": "i want pizza", "label": 1},
            ],
        }
    )


def test_human_utterance_generator_sync(dataset):
    mock_llm = Mock()
    mock_llm.get_chat_completion.return_value = "Human-like utterance"

    mock_critic = Mock(spec=CriticHumanLike)
    mock_critic.is_human.return_value = True

    generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=False)

    n_before = len(dataset["train"])
    new_samples = generator.augment(dataset, split_name="train", update_split=False, n_final_per_class=2)
    n_after = len(dataset["train"])

    assert n_before == n_after
    assert len(new_samples) > 0
    assert all(isinstance(sample, Sample) for sample in new_samples)
    assert all("utterance" in sample.dict() for sample in new_samples)
    assert all("label" in sample.dict() for sample in new_samples)


def test_human_utterance_generator_async(dataset):
    mock_llm = AsyncMock()
    mock_llm.get_chat_completion_async.return_value = "Human-like utterance"

    mock_critic = AsyncMock(spec=CriticHumanLike)
    mock_critic.is_human_async.return_value = True

    generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=True)

    n_before = len(dataset["train"])
    new_samples = generator.augment(dataset, split_name="train", update_split=False, n_final_per_class=2)
    n_after = len(dataset["train"])

    assert n_before == n_after
    assert len(new_samples) > 0
    assert all(isinstance(sample, Sample) for sample in new_samples)
    assert all("utterance" in sample.dict() for sample in new_samples)
    assert all("label" in sample.dict() for sample in new_samples)


def test_human_utterance_generator_respects_critic(dataset):
    mock_llm = Mock()
    mock_llm.get_chat_completion.return_value = "Generated utterance"

    mock_critic = Mock(spec=CriticHumanLike)
    mock_critic.is_human.return_value = True
    generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=False)
    new_samples = generator.augment(dataset, split_name="train", update_split=False, n_final_per_class=1)
    assert len(new_samples) > 0
    assert mock_critic.is_human.call_count >= 1
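
The tests above only exercise a critic that accepts every candidate. As a complementary sketch (not part of this commit), one could also pin down the retry bound in augment when the critic rejects everything; the call count follows from the loop structure shown in the generator file:

    def test_human_utterance_generator_critic_rejects_all(dataset):
        mock_llm = Mock()
        mock_llm.get_chat_completion.return_value = "Generated utterance"

        # The critic never accepts, so every adversarial round is exhausted.
        mock_critic = Mock(spec=CriticHumanLike)
        mock_critic.is_human.return_value = False

        generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=False)
        new_samples = generator.augment(dataset, split_name="train", update_split=False, n_final_per_class=1)

        # Nothing passes the critic, so no samples are produced.
        assert len(new_samples) == 0
        # 2 intents x (n_final_per_class * 3) attempts x 3 rounds per attempt.
        assert mock_critic.is_human.call_count == 2 * 3 * 3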
