11"""Module for balancing datasets through augmentation of underrepresented classes."""
22
33from collections import defaultdict
4- from typing import List
4+ from collections . abc import Callable
55
66from autointent import Dataset
77from autointent .custom_types import Split
8- from autointent .generation .utterances .evolution .evolver import UtteranceEvolver
9- from autointent .generation .utterances .generator import Generator
108from autointent .generation .utterances .basic .utterance_generator import UtteranceGenerator
9+ from autointent .generation .utterances .generator import Generator
10+ from autointent .generation .utterances .schemas import Message
11+ from autointent .schemas import Intent
1112
1213
1314class DatasetBalancer :
1415 """Class for balancing dataset through example augmentation."""
1516
16- class DatasetBalancer :
1717 def __init__ (
1818 self ,
1919 generator : Generator ,
20- evolutions : List ,
21- seed : int = 42 ,
20+ prompt_maker : Callable [[Intent , int ], list [Message ]],
2221 async_mode : bool = False ,
2322 max_samples_per_class : int | None = None ,
2423 ) -> None :
25- if not isinstance (generator , Generator ):
26- raise TypeError ("Generator must be an instance of autointent.generation.utterances.generator.Generator" )
27-
28- if not isinstance (evolutions , list ) or not all (callable (e ) for e in evolutions ):
29- raise TypeError ("Evolutions must be a list of callable objects" )
30-
24+ """
25+ Initialize the UtteranceBalancer.
26+
27+ Args:
28+ generator (Generator): The generator object used to create utterances.
29+ prompt_maker (Callable[[Intent, int], list[Message]]): A callable that creates prompts for the generator.
30+ seed (int, optional): The seed for random number generation. Defaults to 42.
31+ async_mode (bool, optional): Whether to run the generator in asynchronous mode. Defaults to False.
32+ max_samples_per_class (int | None, optional): The maximum number of samples per class. Must be a positive integer or None. Defaults to None.
33+ Raises:
34+ ValueError: If max_samples_per_class is not None and is less than or equal to 0.
35+ """
3136 if max_samples_per_class is not None and max_samples_per_class <= 0 :
32- raise ValueError ("max_samples_per_class must be a positive integer or None" )
33-
34- self .evolver = UtteranceGenerator (generator , evolutions , async_mode )
35- self .max_samples = max_samples_per_class
37+ msg = "max_samples_per_class must be a positive integer or None"
38+ raise ValueError (msg )
3639
40+ self .evolver = UtteranceGenerator (generator = generator , prompt_maker = prompt_maker , async_mode = async_mode )
41+ self .max_samples = max_samples_per_class
3742
3843 def balance (
39- self , dataset : Dataset , split : str = Split .TRAIN , n_evolutions : int = 3 , batch_size : int = 4
44+ self , dataset : Dataset , split : str = Split .TRAIN , batch_size : int = 4
4045 ) -> Dataset :
4146 """
4247 Balances the specified dataset split.
@@ -54,12 +59,11 @@ def balance(
5459 class_counts = self ._count_class_examples (dataset , split )
5560 max_count = max (class_counts .values ())
5661 target_count = self .max_samples if self .max_samples is not None else max_count
57- print (f"Target count per class: { target_count } " ) # Добавить логирование
58-
62+ print (f"Target count per class: { target_count } " )
5963 for class_id , current_count in class_counts .items ():
6064 if current_count < target_count :
6165 needed = target_count - current_count
62- self ._augment_class (dataset , split , class_id , needed , n_evolutions , batch_size )
66+ self ._augment_class (dataset , split , class_id , needed , batch_size )
6367
6468 return dataset
6569
@@ -71,13 +75,13 @@ def _count_class_examples(self, dataset: Dataset, split: str) -> dict[int, int]:
7175 return counts
7276
7377 def _augment_class (
74- self , dataset : Dataset , split : str , class_id : int , needed : int , n_evolutions : int , batch_size : int
78+ self , dataset : Dataset , split : str , class_id : int , needed : int , batch_size : int
7579 ) -> None :
7680 """Generate additional examples for the class."""
7781 print ("\n 📂 DATASET BEFORE AUGMENTATION:" )
7882 self ._print_dataset (dataset , split )
7983 intent = next (i for i in dataset .intents if i .id == class_id )
80- class_name = getattr (intent , ' name' , f' class_{ class_id } ' ) # Получаем имя класса, если доступно
84+ class_name = getattr (intent , " name" , f" class_{ class_id } " )
8185 print (f"\n 🚀 Starting augmentation for class { class_id } ({ class_name } )" )
8286 print (f"📊 Initial samples: { len ([s for s in dataset [split ] if s [Dataset .label_feature ] == class_id ])} " )
8387 print (f"🎯 Target needed: { needed } samples" )
@@ -92,7 +96,7 @@ def _augment_class(
9296
9397 while total_generated < needed :
9498 print (f"\n 🔄 Batch generation: { per_sample_evolutions } evolutions per sample" )
95-
99+
96100 generated = self .evolver .augment (
97101 dataset , split_name = split , n_generations = per_sample_evolutions , update_split = True , batch_size = batch_size
98102 )
@@ -101,10 +105,10 @@ def _augment_class(
101105 print (f"✅ Generated { len (generated )} examples" )
102106 if generated :
103107 print ("🔠 Example generated utterances:" )
104- for i , example in enumerate (generated [:3 ]):
108+ for i , example in enumerate (generated [:3 ]):
105109 utterance = getattr (example , Dataset .utterance_feature , str (example ))
106- print (f" { i + 1 } . { utterance [:60 ]} ..." )
107-
110+ print (f" { i + 1 } . { utterance [:60 ]} ..." )
111+
108112 total_generated += len (generated )
109113 print (f"📈 Progress: { total_generated } /{ needed } ({ min (100 , int (total_generated / needed * 100 ))} %)" )
110114
@@ -119,7 +123,6 @@ def _augment_class(
119123 print ("\n 📦 DATASET AFTER AUGMENTATION:" )
120124 self ._print_dataset (dataset , split )
121125 print ("━" * 50 )
122-
123126
124127 def _remove_extra_samples (self , dataset : Dataset , split : str , class_id : int , extra : int ) -> None :
125128 """Remove extra examples of the class."""
@@ -128,13 +131,14 @@ def _remove_extra_samples(self, dataset: Dataset, split: str, class_id: int, ext
128131
129132 new_data = [s for i , s in enumerate (dataset [split ]) if i not in indices_to_remove ]
130133 dataset [split ] = dataset [split ].from_list (new_data )
134+
131135 def _print_dataset (self , dataset : Dataset , split : str ) -> None :
132- """Helper method to print dataset in readable format"""
133- print (f"Split: { split } " )
134- print ("-" * 50 )
135- for i , sample in enumerate (dataset [split ]):
136- label = sample [Dataset .label_feature ]
137- text = sample [Dataset .utterance_feature ]
138- print (f"{ i + 1 :3d} | { label :15} | { text [:50 ]:<50} ..." )
139- print ("-" * 50 )
140- print (f"Total samples: { len (dataset [split ])} \n " )
136+ """Print the dataset in a readable format. """
137+ print (f"Split: { split } " )
138+ print ("-" * 50 )
139+ for i , sample in enumerate (dataset [split ]):
140+ label = sample [Dataset .label_feature ]
141+ text = sample [Dataset .utterance_feature ]
142+ print (f"{ i + 1 :3d} | { label :15} | { text [:50 ]:<50} ..." )
143+ print ("-" * 50 )
144+ print (f"Total samples: { len (dataset [split ])} \n " )
0 commit comments