-
Notifications
You must be signed in to change notification settings - Fork 11
russian prompt and augmentation generation #127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
voorhs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
import json
import random
import re
from argparse import ArgumentParser
from collections import defaultdict
from pathlib import Path
from typing import List

from autointent import Dataset
from autointent.generation.utterances.basic.chat_template import SynthesizerChatTemplateRussian
from autointent.generation.utterances.basic.utterance_generator import UtteranceGenerator
from autointent.generation.utterances.generator import Generator
|
|
||
def process_utterances(generated: List[str]) -> List[str]:
    """Flatten raw LLM generations into a clean list of utterances.

    Some generations come back as a single stringified Python list
    (e.g. ``"['фраза один', 'фраза два']"``) instead of one utterance
    per string. Such strings are detected by their quote-comma
    separators, stripped of list punctuation, and split into the
    individual utterances.

    :param generated: raw strings produced by the generator
    :return: flat list of whitespace-stripped, non-empty utterances
    """
    processed: List[str] = []
    for ut in generated:
        if "', '" in ut or "',\n" in ut:
            # Looks like a stringified list: drop brackets and quotes,
            # then split on commas followed by whitespace. The original
            # code split only on ", ", so items separated by ",\n"
            # (the second detected pattern) were never split apart.
            clean_ut = ut.replace("[", "").replace("]", "").replace("'", "")
            split_ut = [u.strip() for u in re.split(r",\s+", clean_ut) if u.strip()]
            processed.extend(split_ut)
        else:
            processed.append(ut.strip())
    return processed
|
|
||
def main() -> None:
    """Augment a few-shot intent dataset with generated Russian utterances.

    For each intent, keeps prompting the generator until ``--n-augment``
    valid utterances are collected (non-empty and longer than two words),
    retrying up to ``--max-attempts`` times. The raw augmentations are
    dumped to ``raw_augmented_samples.json``, then one combined dataset is
    written per value in ``--split-numbers``: the original train split plus
    ``n`` randomly sampled augmentations per class.

    Raises:
        RuntimeError: if a class cannot reach ``--n-augment`` valid
            utterances within ``--max-attempts`` attempts.
        ValueError: if any requested split size exceeds ``--n-augment``.
    """
    parser = ArgumentParser()
    parser.add_argument("--input-path", type=str, required=True, help="Path to few-shot dataset")
    parser.add_argument("--output-dir", type=str, required=True, help="Directory to save generated datasets")
    parser.add_argument("--n-augment", type=int, required=True,
                        help="Max number of augmented examples per class to generate")
    parser.add_argument("--split-numbers", type=int, nargs="+", required=True,
                        help="List of example counts to split into (e.g. 1 2 3 5)")
    parser.add_argument("--max-attempts", type=int, default=5,
                        help="Max generation attempts per class")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    random.seed(args.seed)
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    dataset = Dataset.from_json(args.input_path)
    template = SynthesizerChatTemplateRussian(dataset, split="train")
    generator = UtteranceGenerator(Generator(), template)

    augmented_samples = []
    for intent in dataset.intents:
        print(f"\nProcessing intent: {intent.name} (ID: {intent.id})")

        valid_utterances = []
        attempts = 0

        while len(valid_utterances) < args.n_augment and attempts < args.max_attempts:
            # Only ask for as many generations as are still missing.
            needed = args.n_augment - len(valid_utterances)
            generated = generator(intent_data=intent, n_generations=needed)

            processed = process_utterances(generated)
            # Keep only non-empty utterances longer than two words.
            current_valid = [ut for ut in processed if ut and len(ut.split()) > 2]
            valid_utterances.extend(current_valid)

            print(f"Attempt {attempts+1}: "
                  f"Generated {len(current_valid)} valid, "
                  f"Total {len(valid_utterances)}/{args.n_augment}")
            attempts += 1

        if len(valid_utterances) < args.n_augment:
            raise RuntimeError(
                f"Failed to generate {args.n_augment} examples for "
                f"{intent.name} after {args.max_attempts} attempts"
            )

        augmented_samples.extend([
            {"utterance": ut, "label": intent.id}
            for ut in valid_utterances[:args.n_augment]
        ])

    raw_augmented_path = Path(args.output_dir) / "raw_augmented_samples.json"
    with open(raw_augmented_path, "w", encoding="utf-8") as f:
        json.dump({
            "intents": [{"id": intent.id, "name": intent.name} for intent in dataset.intents],
            "samples": augmented_samples
        }, f, indent=4, ensure_ascii=False)

    # BUG FIX: the original compared each n against max(args.split_numbers),
    # which is vacuously true (no element exceeds the list's own maximum).
    # The real constraint is that we cannot sample more augmentations per
    # class than were generated; otherwise random.sample below raises.
    for n in args.split_numbers:
        if n > args.n_augment:
            raise ValueError(f"Requested {n} examples but max is {args.n_augment}")

    # Group augmentations by class once; every split samples from these
    # buckets (the original rebuilt this grouping for each n).
    class_to_samples = defaultdict(list)
    for sample in augmented_samples:
        class_to_samples[sample["label"]].append(sample)

    splits = {}
    for n in args.split_numbers:
        selected = []
        for class_id, samples in class_to_samples.items():
            selected.extend(random.sample(samples, k=n))
        splits[n] = selected

    original_data = dataset["train"].to_list()
    for n, aug_samples in splits.items():
        combined = original_data + aug_samples

        new_dataset = Dataset.from_dict({
            "intents": dataset.intents,
            "train": combined
        })

        output_path = Path(args.output_dir) / f"dataset_{n}_examples.json"
        new_dataset.to_json(output_path)
        print(f"Saved {len(combined)} examples to {output_path}")


if __name__ == "__main__":
    main()
voorhs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| import json | ||
| import os | ||
| from argparse import ArgumentParser | ||
| from collections import defaultdict | ||
| from random import seed, sample | ||
| from autointent import Dataset | ||
|
|
||
def main() -> None:
    """Build a few-shot subset of a multiclass intent dataset.

    Loads the dataset from the Hugging Face hub, draws ``--k-shots``
    random utterances per class from the chosen split, and writes the
    reduced dataset to ``--output-path``.

    Raises:
        ValueError: if any class has fewer than ``--k-shots`` examples.
    """
    parser = ArgumentParser(description="Create few-shot version of multiclass dataset")
    parser.add_argument("--dataset-name", type=str, required=True,
                        help="Hugging Face dataset path (e.g. 'AutoIntent/massive_ru')")
    parser.add_argument("--output-path", type=str, required=True,
                        help="Path to save few-shot dataset")
    parser.add_argument("--k-shots", type=int, required=True,
                        help="Number of examples per class")
    parser.add_argument("--split", type=str, default="train",
                        help="Dataset split to process")
    parser.add_argument("--seed", type=int, default=0,
                        help="Random seed for reproducibility")
    args = parser.parse_args()

    seed(args.seed)

    dataset = Dataset.from_hub(args.dataset_name)

    # Bucket utterances by their class label.
    buckets = defaultdict(list)
    for record in dataset[args.split]:
        buckets[record["label"]].append(record["utterance"])

    # Draw k random utterances from every class, failing fast when a
    # class is too small to supply the requested number of shots.
    few_shot = []
    for label, texts in buckets.items():
        if len(texts) < args.k_shots:
            raise ValueError(f"Class {label} has only {len(texts)} examples")

        few_shot.extend(
            {"utterance": text, "label": label}
            for text in sample(texts, args.k_shots)
        )

    Dataset.from_dict({
        "intents": dataset.intents,
        args.split: few_shot,
    }).to_json(args.output_path)


if __name__ == "__main__":
    main()
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.