Skip to content

Commit c2010c4

Browse files
authored
fix few shot split (#219)
* fix few shot split * lint
1 parent db7356d commit c2010c4

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

autointent/context/data_handler/_stratification.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,16 @@ def __call__(
7272
ValueError: If OOS samples are present but allow_oos_in_train is not specified.
7373
"""
7474
if not self._has_oos_samples(dataset):
75-
return self._split_without_oos(dataset, multilabel, self.test_size)
75+
train, test = self._split_without_oos(dataset, multilabel, self.test_size)
76+
if self.is_few_shot:
77+
train, test = create_few_shot_split(
78+
train,
79+
test,
80+
multilabel=multilabel,
81+
label_column=self.label_feature,
82+
examples_per_label=self.examples_per_label,
83+
)
84+
return train, test
7685
if allow_oos_in_train is None:
7786
msg = (
7887
"Error while splitting dataset. It contains OOS samples, "

tests/data/test_stratificaiton.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,29 @@ def test_multilabel_train_test_split_few_shot(dataset_unsplitted):
6363
assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST)
6464

6565

66-
def test_multiclass_train_test_split_few_shot(dataset_unsplitted):
66+
@pytest.mark.parametrize("allow_oos_in_train", [True, False])
67+
def test_multiclass_train_test_split_few_shot(dataset_unsplitted, allow_oos_in_train):
68+
train_num_rows = 10 if allow_oos_in_train else 8
69+
test_num_rows = 26 if allow_oos_in_train else 28
70+
examples_per_intent = 2
71+
6772
dataset = dataset_unsplitted
6873
dataset[Split.TRAIN], dataset[Split.TEST] = split_dataset(
6974
dataset,
7075
split=Split.TRAIN,
7176
test_size=0.5,
7277
random_seed=42,
73-
allow_oos_in_train=False,
78+
allow_oos_in_train=allow_oos_in_train,
7479
is_few_shot=True,
75-
examples_per_intent=2,
80+
examples_per_intent=examples_per_intent,
7681
)
7782

7883
assert Split.TRAIN in dataset
7984
assert Split.TEST in dataset
80-
assert dataset[Split.TRAIN].num_rows == 8
81-
assert dataset[Split.TEST].num_rows == 28
85+
assert dataset[Split.TRAIN].num_rows == train_num_rows
86+
assert dataset[Split.TEST].num_rows == test_num_rows
8287
assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST)
88+
89+
for class_id in range(dataset.get_n_classes(Split.TRAIN)):
90+
class_ds = dataset[Split.TRAIN].filter(lambda x: x["label"] == class_id) # noqa: B023
91+
assert len(class_ds) <= examples_per_intent

0 commit comments

Comments
 (0)