@@ -21,11 +21,12 @@ class UtteranceGenerator:
2121 punctuation, and length of the desired generations.
2222 """
2323
24- def __init__ (self ,
25- generator : Generator ,
26- prompt_maker : Callable [[Intent , int ], list [Message ]],
27- async_mode : bool = False
28- ) -> None :
24+ def __init__ (
25+ self ,
26+ generator : Generator ,
27+ prompt_maker : Callable [[Intent , int ], list [Message ]],
28+ async_mode : bool = False
29+ ) -> None :
2930 """Initialize."""
3031 self .generator = generator
3132 self .prompt_maker = prompt_maker
@@ -49,27 +50,33 @@ def augment(
4950 split_name : str = Split .TRAIN ,
5051 n_generations : int = 5 ,
5152 update_split : bool = True ,
53+ batch_size : int | None = None
5254 ) -> list [Sample ]:
5355 """
5456 Augment some split of dataset.
5557
56- Note that for now it supports only single-label datasets.
58+ :param dataset: Dataset object
59+ :param split_name: Dataset split (default is TRAIN)
60+ :param n_generations: Number of utterances to generate per intent
61+ :param update_split: Whether to update the dataset split
62+ :param batch_size: Batch size for async generation (None means all at once)
63+ :return: List of generated samples
5764 """
5865 if self .async_mode :
59- return asyncio .run (self ._augment_async (dataset , split_name , n_generations , update_split ))
66+ return asyncio .run (self ._augment_async (dataset , split_name , n_generations , update_split , batch_size ))
67+
6068 original_split = dataset [split_name ]
6169 new_samples = []
6270 for intent in dataset .intents :
63- generated_utterances = self (
64- intent_data = intent ,
65- n_generations = n_generations ,
66- )
71+ generated_utterances = self (intent_data = intent , n_generations = n_generations )
6772 new_samples .extend (
6873 [{Dataset .label_feature : intent .id , Dataset .utterance_feature : ut } for ut in generated_utterances ]
6974 )
75+
7076 if update_split :
7177 generated_split = HFDataset .from_list (new_samples )
7278 dataset [split_name ] = concatenate_datasets ([original_split , generated_split ])
79+
7380 return [Sample (** sample ) for sample in new_samples ]
7481
7582 async def _augment_async (
@@ -78,19 +85,32 @@ async def _augment_async(
7885 split_name : str = Split .TRAIN ,
7986 n_generations : int = 5 ,
8087 update_split : bool = True ,
88+ batch_size : int | None = None
8189 ) -> list [Sample ]:
8290 """
83- Augment some split of dataset asynchronously.
84-
85- Note that for now it supports only single-label datasets.
91+ Augment some split of dataset asynchronously in batches.
92+
93+ :param dataset: Dataset object
94+ :param split_name: Dataset split (default is TRAIN)
95+ :param n_generations: Number of utterances to generate per intent
96+ :param update_split: Whether to update the dataset split
97+ :param batch_size: Batch size for async generation (None means all at once)
98+ :return: List of generated samples
8699 """
87100 original_split = dataset [split_name ]
88101 new_samples = []
89- tasks = []
90102
91- tasks = [self ._call_async (intent_data = intent , n_generations = n_generations ) for intent in dataset .intents ]
103+ if not batch_size :
104+ tasks = [self ._call_async (intent_data = intent , n_generations = n_generations ) for intent in dataset .intents ]
105+ results = await asyncio .gather (* tasks )
92106
93- results = await asyncio .gather (* tasks )
107+ else :
108+ results = []
109+ for start_idx in range (0 , len (dataset .intents ), batch_size ):
110+ batch_intents = dataset .intents [start_idx :start_idx + batch_size ]
111+ tasks = [self ._call_async (intent_data = intent , n_generations = n_generations ) for intent in batch_intents ]
112+ batch_results = await asyncio .gather (* tasks )
113+ results .extend (batch_results )
94114
95115 for i , generated_utterances in enumerate (results ):
96116 intent = dataset .intents [i ]
@@ -113,4 +133,4 @@ def _extract_utterances(response_text: str) -> list[str]:
113133 """
114134 raw_utterances = response_text .split ("\n " )
115135 # remove enumeration
116- return [ut [ut .find (" " ) + 1 :] for ut in raw_utterances ]
136+ return [ut [ut .find (" " ) + 1 :] if " " in ut else ut for ut in raw_utterances ]
0 commit comments