@@ -52,7 +52,7 @@ async def _evolve_async(self, utterance: str, intent_data: Intent) -> str:
5252 return await self .generator .get_chat_completion_async (chat )
5353
def __call__(self, utterance: str, intent_data: Intent, n_evolutions: int = 1) -> list[str]:
    """Apply evolutions multiple times (synchronously)."""
    evolved: list[str] = []
    for _ in range(n_evolutions):
        evolved.append(self._evolve(utterance, intent_data))
    return evolved
5757
5858 async def _call_async (self , utterance : str , intent_data : Intent , n_evolutions : int = 1 ) -> list [str ]:
@@ -61,15 +61,29 @@ async def _call_async(self, utterance: str, intent_data: Intent, n_evolutions: i
6161 return await asyncio .gather (* tasks )
6262
def augment(
    self,
    dataset: Dataset,
    split_name: str = Split.TRAIN,
    n_evolutions: int = 1,
    update_split: bool = True,
    batch_size: int | None = None,
) -> list[Sample]:
    """
    Augment some split of dataset.

    Note that for now it supports only single-label datasets.

    :param dataset: dataset whose split is augmented
    :param split_name: split to read source samples from
    :param n_evolutions: number of evolutions generated per source sample
    :param update_split: if True, append the generated samples back to the split in place
    :param batch_size: only used in async mode; bounds the number of concurrent
        generation requests (``None``/0 sends all requests at once)
    :return: the newly generated samples
    """
    if self.async_mode:
        return asyncio.run(
            self._augment_async(
                dataset=dataset,
                split_name=split_name,
                n_evolutions=n_evolutions,
                update_split=update_split,
                batch_size=batch_size,
            )
        )

    original_split = dataset[split_name]
    # Build the id -> Intent map once instead of scanning dataset.intents per sample.
    intents_by_id = {intent.id: intent for intent in dataset.intents}

    new_samples: list[dict] = []
    for sample in original_split:
        utterance = sample[Dataset.utterance_feature]
        # KeyError here means the split references an intent id absent from dataset.intents.
        intent_data = intents_by_id[sample[Dataset.label_feature]]
        generated_utterances = self(utterance, intent_data, n_evolutions)
        new_samples.extend(
            {Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut}
            for ut in generated_utterances
        )

    if update_split:
        generated_split = HFDataset.from_list(new_samples)
        dataset[split_name] = concatenate_datasets([original_split, generated_split])

    return [Sample(**sample) for sample in new_samples]
87103
async def _augment_async(
    self,
    dataset: Dataset,
    split_name: str = Split.TRAIN,
    n_evolutions: int = 1,
    update_split: bool = True,
    batch_size: int | None = None,
) -> list[Sample]:
    """
    Augment some split of dataset asynchronously.

    Note that for now it supports only single-label datasets.

    :param dataset: dataset whose split is augmented
    :param split_name: split to read source samples from
    :param n_evolutions: number of evolutions generated per source sample
    :param update_split: if True, append the generated samples back to the split in place
    :param batch_size: if set, generation requests are awaited in chunks of this
        size to bound the number of concurrent LLM calls; ``None``/0 gathers
        all requests at once
    :return: the newly generated samples
    """
    original_split = dataset[split_name]
    # Columnar access: lists of utterances/labels for the whole split.
    utterances = original_split[Dataset.utterance_feature]
    labels = original_split[Dataset.label_feature]
    # Build the id -> Intent map once instead of scanning dataset.intents per sample.
    intents_by_id = {intent.id: intent for intent in dataset.intents}

    total = len(utterances)
    # "No batching" is just one chunk containing everything; `or 1` keeps the
    # range step positive when the split is empty.
    chunk_size = batch_size if batch_size else (total or 1)

    new_samples: list[dict] = []
    for start in range(0, total, chunk_size):
        chunk_utterances = utterances[start : start + chunk_size]
        chunk_labels = labels[start : start + chunk_size]
        tasks = [
            self._call_async(utterance=utterance, intent_data=intents_by_id[label], n_evolutions=n_evolutions)
            for utterance, label in zip(chunk_utterances, chunk_labels, strict=True)
        ]
        chunk_results = await asyncio.gather(*tasks)
        # gather preserves task order, so results align with chunk_labels.
        for label, generated_utterances in zip(chunk_labels, chunk_results, strict=True):
            new_samples.extend(
                {Dataset.label_feature: label, Dataset.utterance_feature: ut}
                for ut in generated_utterances
            )

    if update_split:
        generated_split = HFDataset.from_list(new_samples)
        dataset[split_name] = concatenate_datasets([original_split, generated_split])

    return [Sample(**sample) for sample in new_samples]
0 commit comments