Skip to content

Commit fb18058

Browse files
committed
fix: make CLI `main` synchronous; add optional `batch_size` to async augmentation
1 parent 9e4b905 commit fb18058

File tree

2 files changed

+72
-25
lines changed

2 files changed

+72
-25
lines changed

autointent/generation/utterances/basic/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
logger = logging.getLogger(__name__)
1414

1515

16-
async def main() -> None:
16+
def main() -> None:
1717
"""CLI endpoint."""
1818
parser = ArgumentParser()
1919
parser.add_argument(

autointent/generation/utterances/evolution/evolver.py

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ async def _evolve_async(self, utterance: str, intent_data: Intent) -> str:
5252
return await self.generator.get_chat_completion_async(chat)
5353

5454
def __call__(self, utterance: str, intent_data: Intent, n_evolutions: int = 1) -> list[str]:
55-
"""Apply evolutions multiple times."""
55+
"""Apply evolutions multiple times (synchronously)."""
5656
return [self._evolve(utterance, intent_data) for _ in range(n_evolutions)]
5757

5858
async def _call_async(self, utterance: str, intent_data: Intent, n_evolutions: int = 1) -> list[str]:
@@ -61,15 +61,29 @@ async def _call_async(self, utterance: str, intent_data: Intent, n_evolutions: i
6161
return await asyncio.gather(*tasks)
6262

6363
def augment(
64-
self, dataset: Dataset, split_name: str = Split.TRAIN, n_evolutions: int = 1, update_split: bool = True
64+
self,
65+
dataset: Dataset,
66+
split_name: str = Split.TRAIN,
67+
n_evolutions: int = 1,
68+
update_split: bool = True,
69+
batch_size: int | None = None
6570
) -> list[Sample]:
6671
"""
6772
Augment some split of dataset.
6873
6974
Note that for now it supports only single-label datasets.
7075
"""
7176
if self.async_mode:
72-
return asyncio.run(self._augment_async(dataset, split_name, n_evolutions, update_split))
77+
return asyncio.run(
78+
self._augment_async(
79+
dataset=dataset,
80+
split_name=split_name,
81+
n_evolutions=n_evolutions,
82+
update_split=update_split,
83+
batch_size=batch_size
84+
)
85+
)
86+
7387
original_split = dataset[split_name]
7488
new_samples = []
7589
for sample in original_split:
@@ -80,37 +94,70 @@ def augment(
8094
new_samples.extend(
8195
[{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
8296
)
97+
8398
if update_split:
8499
generated_split = HFDataset.from_list(new_samples)
85100
dataset[split_name] = concatenate_datasets([original_split, generated_split])
101+
86102
return [Sample(**sample) for sample in new_samples]
87103

88104
async def _augment_async(
89-
self, dataset: Dataset, split_name: str = Split.TRAIN, n_evolutions: int = 1, update_split: bool = True
105+
self,
106+
dataset: Dataset,
107+
split_name: str = Split.TRAIN,
108+
n_evolutions: int = 1,
109+
update_split: bool = True,
110+
batch_size: int | None = None
90111
) -> list[Sample]:
91-
"""
92-
Augment some split of dataset asynchronously.
93-
94-
Note that for now it supports only single-label datasets.
95-
"""
96112
original_split = dataset[split_name]
97113
new_samples = []
98-
tasks = []
99114

100-
for sample in original_split:
101-
utterance = sample[Dataset.utterance_feature]
102-
label = sample[Dataset.label_feature]
103-
intent_data = next(intent for intent in dataset.intents if intent.id == label)
104-
tasks.append(self._call_async(utterance=utterance, intent_data=intent_data, n_evolutions=n_evolutions))
105-
106-
results = await asyncio.gather(*tasks)
107-
108-
for i, generated_utterances in enumerate(results):
109-
intent_data = next(intent for intent in dataset.intents if intent.id == original_split[i][
110-
Dataset.label_feature])
111-
new_samples.extend(
112-
[{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
113-
)
115+
if not batch_size:
116+
tasks = []
117+
for sample in original_split:
118+
utterance = sample[Dataset.utterance_feature]
119+
label = sample[Dataset.label_feature]
120+
intent_data = next(intent for intent in dataset.intents if intent.id == label)
121+
tasks.append(
122+
self._call_async(utterance=utterance, intent_data=intent_data, n_evolutions=n_evolutions)
123+
)
124+
125+
results = await asyncio.gather(*tasks)
126+
127+
for i, generated_utterances in enumerate(results):
128+
intent_data = next(
129+
intent for intent in dataset.intents if intent.id == original_split[i][Dataset.label_feature]
130+
)
131+
new_samples.extend(
132+
[{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut}
133+
for ut in generated_utterances]
134+
)
135+
136+
else:
137+
total_samples = len(original_split)
138+
for start_idx in range(0, total_samples, batch_size):
139+
batch = original_split[start_idx : start_idx + batch_size]
140+
tasks = []
141+
for utterance, label in zip(
142+
batch[Dataset.utterance_feature],
143+
batch[Dataset.label_feature],
144+
strict=False
145+
):
146+
intent_data = next(intent for intent in dataset.intents if intent.id == label)
147+
tasks.append(
148+
self._call_async(utterance=utterance, intent_data=intent_data, n_evolutions=n_evolutions)
149+
)
150+
151+
batch_results = await asyncio.gather(*tasks)
152+
153+
for i, generated_utterances in enumerate(batch_results):
154+
intent_data = next(
155+
intent for intent in dataset.intents if intent.id == batch[Dataset.label_feature][i]
156+
)
157+
new_samples.extend(
158+
[{Dataset.label_feature: intent_data.id, Dataset.utterance_feature: ut}
159+
for ut in generated_utterances]
160+
)
114161

115162
if update_split:
116163
generated_split = HFDataset.from_list(new_samples)

0 commit comments

Comments (0)