|
2 | 2 | import logging |
3 | 3 | import random |
4 | 4 | from collections import defaultdict |
| 5 | +from functools import partial |
5 | 6 |
|
| 7 | +import aiometer |
6 | 8 | from datasets import Dataset as HFDataset |
7 | 9 | from datasets import concatenate_datasets |
8 | 10 |
|
|
15 | 17 | from .critic_human_like import CriticHumanLike |
16 | 18 |
|
17 | 19 | logger = logging.getLogger(__name__) |
18 | | - |
19 | | - |
20 | 20 | class HumanUtteranceGenerator: |
21 | 21 | """Generator of human-like utterances. |
22 | 22 |
|
@@ -110,29 +110,38 @@ async def augment_async( |
110 | 110 | for sample in original_split: |
111 | 111 | class_to_samples[sample["label"]].append(sample["utterance"]) |
112 | 112 |
|
113 | | - for intent_id, intent_name in id_to_name.items(): |
| 113 | + |
async def generate_one(intent_id: str, intent_name: str) -> list[dict]:
    """Generate critic-approved utterances for a single intent.

    Makes at most ``n_final_per_class * 3`` attempts. Each attempt draws up
    to three seed utterances for the intent and queries the generator up to
    three times, feeding the utterances the critic rejected so far back into
    the prompt. The first utterance the critic accepts ends the attempt.

    Args:
        intent_id: Label of the intent; used as the key into
            ``class_to_samples`` and stored in every produced sample.
        intent_name: Name of the intent passed to the prompt builder and
            the critic. ``None`` means the intent has no name.

    Returns:
        A list of at most ``n_final_per_class`` dicts keyed by
        ``Dataset.label_feature`` and ``Dataset.utterance_feature``.
        Empty when the intent has no name or no seed utterances.
    """
    if intent_name is None:
        logger.warning("Intent with id %s has no name! Skipping it...", intent_id)
        # Actually skip: without this return the refactored helper would go
        # on to build prompts with ``intent_name=None``.
        return []

    # ``.get`` instead of indexing: on a defaultdict, indexing would insert
    # an empty entry for an intent that has no samples.
    seed_utterances = class_to_samples.get(intent_id, [])
    if not seed_utterances:
        # Nothing to condition the generator on; the pre-refactor loop
        # skipped such intents as well.
        return []

    generated: list[dict] = []
    attempts = 0
    max_attempts = n_final_per_class * 3
    while len(generated) < n_final_per_class and attempts < max_attempts:
        attempts += 1
        seed_examples = random.sample(seed_utterances, k=min(3, len(seed_utterances)))
        # Utterances the critic turned down during this attempt; they are
        # handed back to the prompt builder on the next try.
        rejected: list[str] = []

        for _ in range(3):
            prompt = self._build_adversarial_prompt(intent_name, seed_examples, rejected)
            utterance = (await self.generator.get_chat_completion_async([prompt])).strip()
            if await self.critic.is_human_async(utterance, intent_name):
                generated.append({Dataset.label_feature: intent_id, Dataset.utterance_feature: utterance})
                break
            rejected.append(utterance)
    return generated
| 133 | + tasks = [ |
| 134 | + partial(generate_one, intent_id, intent_name) |
| 135 | + for intent_id, intent_name in id_to_name.items() |
| 136 | + if class_to_samples.get(intent_id) and intent_name is not None |
| 137 | + ] |
| 138 | + |
| 139 | + results = await aiometer.run_all( |
| 140 | + tasks, max_at_once=5, max_per_second=10 |
| 141 | + ) |
| 142 | + |
| 143 | + for result in results: |
| 144 | + new_samples.extend(result) |
136 | 145 |
|
137 | 146 | if update_split: |
138 | 147 | generated_split = HFDataset.from_list(new_samples) |
|
0 commit comments