@@ -51,9 +51,21 @@ async def _evolve_async(self, utterance: str, intent_data: Intent) -> str:
5151 chat = maker (utterance , intent_data )
5252 return await self .generator .get_chat_completion_async (chat )
5353
54- def __call__ (self , utterance : str , intent_data : Intent , n_evolutions : int = 1 ) -> list [str ]:
54+ def __call__ (
55+ self , utterance : str , intent_data : Intent , n_evolutions : int = 1 , sequential : bool = False
56+ ) -> list [str ]:
5557 """Apply evolutions multiple times (synchronously)."""
56- return [self ._evolve (utterance , intent_data ) for _ in range (n_evolutions )]
58+ current_utterance = utterance
59+ generated_utterances = []
60+
61+ for _ in range (n_evolutions ):
62+ gen_utt = self ._evolve (current_utterance , intent_data )
63+ generated_utterances .append (gen_utt )
64+
65+ if sequential :
66+ current_utterance = gen_utt
67+
68+ return generated_utterances
5769
5870 def augment (
5971 self ,
@@ -62,13 +74,18 @@ def augment(
6274 n_evolutions : int = 1 ,
6375 update_split : bool = True ,
6476 batch_size : int = 4 ,
77+ sequential : bool = False ,
6578 ) -> HFDataset :
6679 """
6780 Augment some split of dataset.
6881
6982 Note that for now it supports only single-label datasets.
7083 """
7184 if self .async_mode :
85+ if sequential :
86+ error = "Sequential and async modes are not compatible"
87+ raise ValueError (error )
88+
7289 return asyncio .run (
7390 self ._augment_async (
7491 dataset = dataset ,
@@ -85,7 +102,9 @@ def augment(
85102 utterance = sample [Dataset .utterance_feature ]
86103 label = sample [Dataset .label_feature ]
87104 intent_data = next (intent for intent in dataset .intents if intent .id == label )
88- generated_utterances = self (utterance = utterance , intent_data = intent_data , n_evolutions = n_evolutions )
105+ generated_utterances = self (
106+ utterance = utterance , intent_data = intent_data , n_evolutions = n_evolutions , sequential = sequential
107+ )
89108 new_samples .extend (
90109 [{Dataset .label_feature : intent_data .id , Dataset .utterance_feature : ut } for ut in generated_utterances ]
91110 )
0 commit comments