Skip to content

Commit c18698e

Browse files
async
1 parent 04a3ef6 commit c18698e

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

autointent/generation/utterances/_adversarial/human_utterance_generator.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import logging
33
import random
44
from collections import defaultdict
5+
from functools import partial
56

7+
import aiometer
68
from datasets import Dataset as HFDataset
79
from datasets import concatenate_datasets
810

@@ -15,8 +17,6 @@
1517
from .critic_human_like import CriticHumanLike
1618

1719
logger = logging.getLogger(__name__)
18-
19-
2020
class HumanUtteranceGenerator:
2121
"""Generator of human-like utterances.
2222
@@ -110,29 +110,38 @@ async def augment_async(
110110
for sample in original_split:
111111
class_to_samples[sample["label"]].append(sample["utterance"])
112112

113-
for intent_id, intent_name in id_to_name.items():
113+
114+
async def generate_one(intent_id: str, intent_name: str) -> list[dict]:
114115
if intent_name is None:
115116
logger.warning("Intent with id %s has no name! Skipping it...", intent_id)
116-
continue
117-
generated_count = 0
118-
attempt = 0
119-
seed_utterances = class_to_samples.get(intent_id, [])
120-
if not seed_utterances:
121-
continue
122-
123-
while generated_count < n_final_per_class and attempt < n_final_per_class * 3:
124-
attempt += 1
117+
generated = []
118+
attempts = 0
119+
seed_utterances = class_to_samples[intent_id]
120+
while len(generated) < n_final_per_class and attempts < n_final_per_class * 3:
121+
attempts += 1
125122
seed_examples = random.sample(seed_utterances, k=min(3, len(seed_utterances)))
126-
rejected: list[str] = []
123+
rejected = []
127124

128125
for _ in range(3):
129126
prompt = self._build_adversarial_prompt(intent_name, seed_examples, rejected)
130-
generated = (await self.generator.get_chat_completion_async([prompt])).strip()
131-
if await self.critic.is_human_async(generated, intent_name):
132-
new_samples.append({Dataset.label_feature: intent_id, Dataset.utterance_feature: generated})
133-
generated_count += 1
127+
utterance = (await self.generator.get_chat_completion_async([prompt])).strip()
128+
if await self.critic.is_human_async(utterance, intent_name):
129+
generated.append({Dataset.label_feature: intent_id, Dataset.utterance_feature: utterance})
134130
break
135-
rejected.append(generated)
131+
rejected.append(utterance)
132+
return generated
133+
tasks = [
134+
partial(generate_one, intent_id, intent_name)
135+
for intent_id, intent_name in id_to_name.items()
136+
if class_to_samples.get(intent_id) and intent_name is not None
137+
]
138+
139+
results = await aiometer.run_all(
140+
tasks, max_at_once=5, max_per_second=10
141+
)
142+
143+
for result in results:
144+
new_samples.extend(result)
136145

137146
if update_split:
138147
generated_split = HFDataset.from_list(new_samples)

0 commit comments

Comments
 (0)