Skip to content

Commit 791669e

Browse files
Add few shot (#187)
* init few shot * Update optimizer_config.schema.json * apply few shot to all * Update optimizer_config.schema.json * fix test * lint --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 520b7f8 commit 791669e

File tree

6 files changed

+225
-13
lines changed

6 files changed

+225
-13
lines changed

autointent/configs/_optimization.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,22 @@ class DataConfig(BaseModel):
2121
validation_size: FloatFromZeroToOne = Field(
2222
0.2,
2323
description=(
24-
"Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."
24+
"Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split). "
25+
"If `is_few_shot_train` is True, this value will be ignored."
2526
),
2627
)
2728
"""Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."""
2829
separation_ratio: FloatFromZeroToOne | None = Field(
2930
0.5, description="Set to float to prevent data leak between scoring and decision nodes."
3031
)
3132
"""Set to float to prevent data leak between scoring and decision nodes."""
33+
is_few_shot_train: bool = Field(False, description="Whether to use few-shot training.")
34+
"""Whether to use few-shot training."""
35+
examples_per_intent: PositiveInt = Field(
36+
8,
37+
description="Number of examples per intent for few-shot validation. If None, all examples will be used.",
38+
)
39+
"""Number of examples per intent for few-shot validation. If None, all examples will be used."""
3240

3341

3442
class LoggingConfig(BaseModel):

autointent/context/data_handler/_data_handler.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from autointent.custom_types import FloatFromZeroToOne, ListOfGenericLabels, ListOfLabels, Split
1212
from autointent.schemas import Tag
1313

14-
from ._stratification import split_dataset
14+
from ._stratification import create_few_shot_split, split_dataset
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -48,9 +48,14 @@ def __init__(
4848
self._n_classes = self.dataset.n_classes
4949

5050
if self.config.scheme == "ho":
51-
self._split_ho(self.config.separation_ratio, self.config.validation_size)
51+
self._split_ho(
52+
self.config.separation_ratio,
53+
self.config.validation_size,
54+
self.config.is_few_shot_train,
55+
self.config.examples_per_intent,
56+
)
5257
elif self.config.scheme == "cv":
53-
self._split_cv()
58+
self._split_cv(self.config.is_few_shot_train, self.config.examples_per_intent)
5459

5560
self._logger = logger
5661

@@ -149,8 +154,8 @@ def test_labels(self) -> ListOfGenericLabels:
149154

150155
def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[str], ListOfLabels]]:
151156
"""Yield folds for cross-validation."""
152-
if self.config.scheme == "ho":
153-
msg = "Cannot call cross-validation on hold-out DataHandler"
157+
if self.config.scheme != "cv":
158+
msg = f"Cannot call cross-validation on {self.config.scheme} DataHandler"
154159
raise RuntimeError(msg)
155160

156161
for j in range(self.config.n_folds):
@@ -165,14 +170,22 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
165170
train_labels = [lab for lab in train_labels if lab is not None]
166171
yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc]
167172

168-
def _split_ho(self, separation_ratio: FloatFromZeroToOne | None, validation_size: FloatFromZeroToOne) -> None:
173+
def _split_ho(
174+
self,
175+
separation_ratio: FloatFromZeroToOne | None,
176+
validation_size: FloatFromZeroToOne,
177+
is_few_shot: bool,
178+
examples_per_intent: int,
179+
) -> None:
169180
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
170181

171182
if separation_ratio is not None and Split.TRAIN in self.dataset:
172183
self._split_train(separation_ratio)
173184

174185
if not has_validation_split:
175-
self._split_validation_from_train(validation_size)
186+
self._split_validation_from_train(validation_size, is_few_shot, examples_per_intent)
187+
elif is_few_shot:
188+
self._split_few_shot(examples_per_intent)
176189

177190
for split in self.dataset:
178191
n_classes_in_split = self.dataset.get_n_classes(split)
@@ -182,6 +195,27 @@ def _split_ho(self, separation_ratio: FloatFromZeroToOne | None, validation_size
182195
)
183196
raise ValueError(message)
184197

198+
def _split_few_shot(self, examples_per_intent: int) -> None:
    """Reduce the train split(s) to a few-shot sample and extend validation with the rest.

    When a single ``Split.TRAIN`` split is present it is paired with
    ``Split.VALIDATION``; otherwise the two indexed pairs ``train_0``/``validation_0``
    and ``train_1``/``validation_1`` are processed. Each pair is replaced in
    ``self.dataset`` by the result of ``create_few_shot_split``.

    Args:
        examples_per_intent: Passed to ``create_few_shot_split`` as ``examples_per_label``.
    """
    if Split.TRAIN in self.dataset:
        split_pairs = [(Split.TRAIN, Split.VALIDATION)]
    else:
        split_pairs = [(f"{Split.TRAIN}_{idx}", f"{Split.VALIDATION}_{idx}") for idx in range(2)]

    for train_key, validation_key in split_pairs:
        few_shot_train, extended_validation = create_few_shot_split(
            self.dataset[train_key],
            self.dataset[validation_key],
            multilabel=self.dataset.multilabel,
            label_column=self.dataset.label_feature,
            random_seed=self._seed,
            examples_per_label=examples_per_intent,
        )
        self.dataset[train_key] = few_shot_train
        self.dataset[validation_key] = extended_validation
218+
185219
def _split_train(self, ratio: FloatFromZeroToOne) -> None:
186220
"""Split on two sets.
187221
@@ -199,7 +233,7 @@ def _split_train(self, ratio: FloatFromZeroToOne) -> None:
199233
)
200234
self.dataset.pop(Split.TRAIN)
201235

202-
def _split_cv(self) -> None:
236+
def _split_cv(self, is_few_shot: bool, examples_per_intent: int) -> None:
203237
extra_splits = [split_name for split_name in self.dataset if split_name != Split.TEST]
204238
self.dataset[Split.TRAIN] = concatenate_datasets([self.dataset.pop(split_name) for split_name in extra_splits])
205239

@@ -209,17 +243,21 @@ def _split_cv(self) -> None:
209243
split=Split.TRAIN,
210244
test_size=1 / (self.config.n_folds - j),
211245
random_seed=self._seed,
246+
is_few_shot=is_few_shot,
247+
examples_per_intent=examples_per_intent,
212248
allow_oos_in_train=True,
213249
)
214250
self.dataset[f"{Split.TRAIN}_{self.config.n_folds - 1}"] = self.dataset.pop(Split.TRAIN)
215251

216-
def _split_validation_from_train(self, size: float) -> None:
252+
def _split_validation_from_train(self, size: float, is_few_shot: bool, examples_per_intent: int) -> None:
217253
if Split.TRAIN in self.dataset:
218254
self.dataset[Split.TRAIN], self.dataset[Split.VALIDATION] = split_dataset(
219255
self.dataset,
220256
split=Split.TRAIN,
221257
test_size=size,
222258
random_seed=self._seed,
259+
is_few_shot=is_few_shot,
260+
examples_per_intent=examples_per_intent,
223261
allow_oos_in_train=True,
224262
)
225263
else:
@@ -229,6 +267,8 @@ def _split_validation_from_train(self, size: float) -> None:
229267
split=f"{Split.TRAIN}_{idx}",
230268
test_size=size,
231269
random_seed=self._seed,
270+
is_few_shot=is_few_shot,
271+
examples_per_intent=examples_per_intent,
232272
allow_oos_in_train=idx == 1, # for decision node it's ok to have oos in train
233273
)
234274

autointent/context/data_handler/_stratification.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
It includes support for both single-label and multi-label stratified splitting.
55
"""
66

7+
import logging
78
from collections.abc import Sequence
89

910
import numpy as np
@@ -17,6 +18,8 @@
1718
from autointent import Dataset
1819
from autointent.custom_types import LabelType
1920

21+
logger = logging.getLogger(__name__)
22+
2023

2124
class StratifiedSplitter:
2225
"""A class for stratified splitting of datasets.
@@ -32,6 +35,8 @@ def __init__(
3235
label_feature: str,
3336
random_seed: int | None,
3437
shuffle: bool = True,
38+
is_few_shot: bool = False,
39+
examples_per_label: int = 8,
3540
) -> None:
3641
"""Initialize the StratifiedSplitter.
3742
@@ -40,11 +45,15 @@ def __init__(
4045
label_feature: Name of the feature containing labels for stratification.
4146
random_seed: Seed for random number generation to ensure reproducibility.
4247
shuffle: Whether to shuffle the data before splitting.
48+
is_few_shot: Whether the dataset is a few-shot dataset.
49+
examples_per_label: Number of examples per label for few-shot datasets.
4350
"""
4451
self.test_size = test_size
4552
self.label_feature = label_feature
4653
self.random_seed = random_seed
4754
self.shuffle = shuffle
55+
self.is_few_shot = is_few_shot
56+
self.examples_per_label = examples_per_label
4857

4958
def __call__(
5059
self, dataset: HFDataset, multilabel: bool, allow_oos_in_train: bool | None = None
@@ -71,7 +80,16 @@ def __call__(
7180
)
7281
raise ValueError(msg)
7382
splitter = self._split_allow_oos_in_train if allow_oos_in_train else self._split_disallow_oos_in_train
74-
return splitter(dataset, multilabel)
83+
train, test = splitter(dataset, multilabel)
84+
if self.is_few_shot:
85+
train, test = create_few_shot_split(
86+
train,
87+
test,
88+
multilabel=multilabel,
89+
label_column=self.label_feature,
90+
examples_per_label=self.examples_per_label,
91+
)
92+
return train, test
7593

7694
def _has_oos_samples(self, dataset: HFDataset) -> bool:
7795
"""Check if the dataset contains out-of-scope samples.
@@ -287,6 +305,8 @@ def split_dataset(
287305
split: str,
288306
test_size: float,
289307
random_seed: int | None,
308+
is_few_shot: bool = False,
309+
examples_per_intent: int = 8,
290310
allow_oos_in_train: bool | None = None,
291311
) -> tuple[HFDataset, HFDataset]:
292312
"""Split a Dataset object into training and testing subsets.
@@ -296,6 +316,8 @@ def split_dataset(
296316
split: The specific data split to divide.
297317
test_size: Proportion of the dataset to include in the test split.
298318
random_seed: Seed for random number generation.
319+
is_few_shot: Whether the dataset is a few-shot dataset.
320+
examples_per_intent: Number of examples per label for few-shot datasets.
299321
allow_oos_in_train: Whether to allow OOS samples in train split.
300322
301323
Returns:
@@ -305,5 +327,74 @@ def split_dataset(
305327
test_size=test_size,
306328
label_feature=dataset.label_feature,
307329
random_seed=random_seed,
330+
is_few_shot=is_few_shot,
331+
examples_per_label=examples_per_intent,
308332
)
309333
return splitter(dataset[split], dataset.multilabel, allow_oos_in_train=allow_oos_in_train)
334+
335+
336+
def create_few_shot_split(
    train_dataset: HFDataset,
    validation_dataset: HFDataset,
    label_column: str,
    examples_per_label: int = 8,
    multilabel: bool = False,
    random_seed: int | None = None,
) -> tuple[HFDataset, HFDataset]:
    """Create a few-shot split with a bounded number of train examples per label.

    For every distinct label found in ``train_dataset``, up to ``examples_per_label``
    examples are sampled (after a seeded shuffle) and kept for training. Every train
    example that was not sampled is appended to ``validation_dataset``.

    Args:
        train_dataset: Hugging Face dataset to sample few-shot training examples from.
        validation_dataset: Hugging Face dataset that receives the unsampled train examples.
        label_column: Name of the column containing labels.
        examples_per_label: Maximum number of examples kept in train for each label.
        multilabel: Whether labels are sequences (multi-label); each distinct label
            vector is then treated as one "label" for sampling purposes.
        random_seed: Seed passed to the per-label shuffle; ``None`` means unseeded.

    Returns:
        A tuple containing the train and validation datasets.
    """
    # Temporary column used to remember which original rows ended up in train.
    index_column = "__index__"
    train_dataset = train_dataset.add_column(index_column, list(range(len(train_dataset))))

    if multilabel:
        # Label vectors are lists (unhashable), so compare them as tuples.
        # Rows whose label is None are not treated as a label here.
        unique_labels = list(
            {tuple(labels) for labels in train_dataset[label_column] if labels is not None},
        )
    else:
        unique_labels = train_dataset.unique(label_column)

    # Create train dataset by sampling examples_per_label for each label
    train_datasets = []
    selected_indices: set[int] = set()

    for label in unique_labels:
        if multilabel:
            # Guard against None labels: tuple(None) would raise TypeError.
            label_examples = train_dataset.filter(
                lambda row: row[label_column] is not None and tuple(row[label_column]) == label  # noqa: B023
            )
        else:
            label_examples = train_dataset.filter(lambda row: row[label_column] == label)  # noqa: B023
        label_examples = label_examples.shuffle(seed=random_seed)

        num_to_select = min(examples_per_label, len(label_examples))
        selected_examples = label_examples.select(range(num_to_select))

        if num_to_select < examples_per_label:
            msg = (
                f"Warning: Only {num_to_select} examples available for label '{label}', "
                f"which is less than the requested {examples_per_label}"
            )
            logger.warning(msg)

        train_datasets.append(selected_examples)
        selected_indices.update(selected_examples[index_column])

    # Create validation split with remaining examples
    extra_validation_dataset = train_dataset.filter(
        lambda example: example[index_column] not in selected_indices
    ).remove_columns(index_column)

    validation_dataset = concatenate_datasets([validation_dataset, extra_validation_dataset])
    train_dataset = concatenate_datasets(train_datasets).remove_columns(index_column)

    return train_dataset, validation_dataset

docs/optimizer_config.schema.json

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
},
7272
"validation_size": {
7373
"default": 0.2,
74-
"description": "Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split).",
74+
"description": "Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split). If `is_few_shot_train` is True, this value will be ignored.",
7575
"maximum": 1,
7676
"minimum": 0,
7777
"title": "Validation Size",
@@ -91,6 +91,19 @@
9191
"default": 0.5,
9292
"description": "Set to float to prevent data leak between scoring and decision nodes.",
9393
"title": "Separation Ratio"
94+
},
95+
"is_few_shot_train": {
96+
"default": false,
97+
"description": "Whether to use few-shot training.",
98+
"title": "Is Few Shot Train",
99+
"type": "boolean"
100+
},
101+
"examples_per_intent": {
102+
"default": 8,
103+
"description": "Number of examples per intent for few-shot validation. If None, all examples will be used.",
104+
"exclusiveMinimum": 0,
105+
"title": "Examples Per Intent",
106+
"type": "integer"
94107
}
95108
},
96109
"title": "DataConfig",
@@ -362,7 +375,9 @@
362375
"scheme": "ho",
363376
"n_folds": 3,
364377
"validation_size": 0.2,
365-
"separation_ratio": 0.5
378+
"separation_ratio": 0.5,
379+
"is_few_shot_train": false,
380+
"examples_per_intent": 8
366381
}
367382
},
368383
"search_space": {

tests/data/test_data_handler.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import Counter
2+
13
import pytest
24

35
from autointent import Dataset
@@ -223,3 +225,20 @@ def test_cv_iterator(dataset):
223225
assert count_oos_labels(y_train) == specs["train"]["oos"]
224226
assert len(x_val) == len(y_val) == specs["val"]["total"]
225227
assert count_oos_labels(y_val) == specs["val"]["oos"]
228+
229+
230+
def test_few_shot_split(dataset):
    """Check the per-label example counts of every split after a hold-out few-shot split."""
    config = DataConfig(scheme="ho", is_few_shot_train=True, examples_per_intent=2)
    dh = DataHandler(dataset, config=config)

    # Expected label -> count mapping for each split name (None is the label of unlabeled samples).
    desired_specs = {
        "train_0": {0: 2, 1: 2, 2: 2, 3: 2},
        "train_1": {2: 2, 0: 2, None: 2, 1: 1, 3: 1},
        "validation_0": {0: 3, 1: 4, 2: 3, 3: 4},
        "validation_1": {None: 14, 3: 1, 0: 1, 1: 1, 2: 1},
        "test": {None: 4, 0: 2, 2: 2, 3: 2, 1: 2},
    }

    label_feature = dh.dataset.label_feature
    for data_split in dh.dataset:
        observed = Counter(dh.dataset[data_split][label_feature])
        assert observed == desired_specs[data_split], f"Failed for {data_split}"

0 commit comments

Comments
 (0)