Skip to content

Commit cc647ee

Browse files
исправление ошибок (fix errors)
1 parent 45c8fc8 commit cc647ee

File tree

4 files changed

+15
-18
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Generative methods for enriching dataset with synthetic samples."""
22

3+
from ._adversarial import CriticHumanLike, HumanUtteranceGenerator
34
from ._basic import DatasetBalancer, UtteranceGenerator
45
from ._evolution import IncrementalUtteranceEvolver, UtteranceEvolver
5-
from ._adversarial import HumanUtteranceGenerator, CriticHumanLike
66

77
__all__ = [
8+
"CriticHumanLike",
89
"DatasetBalancer",
10+
"HumanUtteranceGenerator",
911
"IncrementalUtteranceEvolver",
1012
"UtteranceEvolver",
1113
"UtteranceGenerator",
12-
"HumanUtteranceGenerator",
13-
"CriticHumanLike"
1414
]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .critic_human_like import CriticHumanLike
22
from .human_utterance_generator import HumanUtteranceGenerator
33

4-
__all__ = ["HumanUtteranceGenerator", "CriticHumanLike"]
4+
__all__ = ["CriticHumanLike", "HumanUtteranceGenerator"]

src/autointent/generation/utterances/_adversarial/human_utterance_generator.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import random
44
from collections import defaultdict
55
from functools import partial
6+
from typing import Any
67

78
import aiometer
89
from datasets import Dataset as HFDataset
@@ -123,8 +124,8 @@ async def augment_async(
123124
for sample in original_split:
124125
class_to_samples[sample["label"]].append(sample["utterance"])
125126

126-
async def generate_one(intent_id: str, intent_name: str) -> list[dict[str, str]]:
127-
generated: list[dict[str, str]] = []
127+
async def generate_one(intent_id: str, intent_name: str) -> list[dict[str, Any]]:
128+
generated: list[dict[str, Any]] = []
128129
attempts = 0
129130
seed_utterances = class_to_samples[intent_id]
130131
while len(generated) < n_final_per_class and attempts < n_final_per_class * 3:
@@ -136,7 +137,7 @@ async def generate_one(intent_id: str, intent_name: str) -> list[dict[str, str]]
136137
prompt = self._build_adversarial_prompt(intent_name, seed_examples, rejected)
137138
utterance = (await self.generator.get_chat_completion_async([prompt])).strip()
138139
if await self.critic.is_human_async(utterance, intent_name):
139-
generated.append({Dataset.label_feature: intent_id, Dataset.utterance_feature: utterance})
140+
generated.append({Dataset.label_feature: int(intent_id), Dataset.utterance_feature: utterance})
140141
break
141142
rejected.append(utterance)
142143
return generated
@@ -155,8 +156,6 @@ async def generate_one(intent_id: str, intent_name: str) -> list[dict[str, str]]
155156

156157
for result in results:
157158
new_samples.extend(result)
158-
for s in new_samples:
159-
s['label'] = int(s['label'])
160159
if update_split:
161160
generated_split = HFDataset.from_list(new_samples)
162161
dataset[split_name] = concatenate_datasets([original_split, generated_split])

tests/generation/utterances/test_adversarial.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
1-
from unittest.mock import Mock, AsyncMock
2-
import pytest
3-
from autointent.generation import Generator
1+
from unittest.mock import AsyncMock, Mock
42

5-
from autointent.generation.utterances import HumanUtteranceGenerator, CriticHumanLike
6-
from autointent import Dataset, Sample
3+
from autointent import Sample
4+
from autointent.generation.utterances import CriticHumanLike, HumanUtteranceGenerator
75

86

97
def test_human_utterance_generator_sync(dataset):
108
mock_llm = Mock()
119
mock_llm.get_chat_completion.return_value = "Human-like utterance"
12-
10+
1311
mock_critic = Mock(spec=CriticHumanLike)
1412
mock_critic.is_human.return_value = True
1513

1614
generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=False)
17-
15+
1816
n_before = len(dataset["train_0"])
1917
new_samples = generator.augment(dataset, split_name="train_0", update_split=False, n_final_per_class=2)
2018
n_after = len(dataset["train_0"])
@@ -29,7 +27,7 @@ def test_human_utterance_generator_sync(dataset):
2927
def test_human_utterance_generator_async(dataset):
3028
mock_llm = AsyncMock()
3129
mock_llm.get_chat_completion_async.return_value = "Human-like utterance"
32-
30+
3331
mock_critic = AsyncMock(spec=CriticHumanLike)
3432
mock_critic.is_human_async.return_value = True
3533
generator = HumanUtteranceGenerator(mock_llm, mock_critic, async_mode=True)
@@ -56,4 +54,4 @@ def test_human_utterance_generator_respects_critic(dataset):
5654

5755
new_samples = generator.augment(dataset, split_name="train_0", update_split=False, n_final_per_class=1)
5856
assert len(new_samples) > 0
59-
assert all(mock_critic.is_human.call_count >= 1 for _ in range(len(new_samples)))
57+
assert all(mock_critic.is_human.call_count >= 1 for _ in range(len(new_samples)))

0 commit comments

Comments (0)