Skip to content

Commit 7c5c65f

Browse files
committed
Employ a better multi-label stratifier than the one in skmultilearn
1 parent a1a9adc commit 7c5c65f

File tree

4 files changed

+16
-20
lines changed

4 files changed

+16
-20
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ requires-python = ">=3.10,<3.13"
3333
dependencies = [
3434
"sentence-transformers (>=3,<4)",
3535
"scikit-learn (>=1.5,<2.0)",
36-
"scikit-multilearn (==0.2.0)",
36+
"iterative-stratification (>=0.1.9)",
3737
"appdirs (>=1.4,<2.0)",
3838
"optuna (>=4.0.0,<5.0.0)",
3939
"pathlib (>=1.0.1,<2.0.0)",
@@ -253,7 +253,7 @@ module = [
253253
"xeger",
254254
"appdirs",
255255
"sre_yield",
256-
"skmultilearn.model_selection",
256+
"iterstrat.ml_stratifiers",
257257
"hydra",
258258
"hydra.*",
259259
"transformers",

src/autointent/context/data_handler/_stratification.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
"""
66

77
import logging
8-
import random
98
from collections.abc import Sequence
109

1110
import numpy as np
1211
from datasets import Dataset as HFDataset
1312
from datasets import concatenate_datasets
13+
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
1414
from numpy import typing as npt
1515
from sklearn.model_selection import train_test_split
16-
from skmultilearn.model_selection import IterativeStratification
1716

1817
from autointent import Dataset
1918
from autointent.custom_types import LabelType
@@ -155,13 +154,10 @@ def _split_multilabel(self, dataset: HFDataset, test_size: float) -> Sequence[np
155154
Returns:
156155
A sequence containing indices for train and test splits.
157156
"""
158-
if self.random_seed is not None:
159-
# Set all seeds for reproducibility (workaround for bugs in IterativeStratification from skmultilearn)
160-
random.seed(self.random_seed)
161-
splitter = IterativeStratification(
162-
n_splits=2,
163-
order=2,
164-
sample_distribution_per_fold=[test_size, 1.0 - test_size],
157+
splitter = MultilabelStratifiedShuffleSplit(
158+
n_splits=1,
159+
test_size=test_size,
160+
random_state=self.random_seed,
165161
)
166162
return next(splitter.split(np.arange(len(dataset)), np.array(dataset[self.label_feature])))
167163

tests/data/test_data_handler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,13 @@ def test_data_handler_multilabel_mode(sample_multilabel_data):
8989
assert handler.multilabel is True
9090
assert handler.dataset.n_classes == 2
9191
assert handler.train_utterances(0) == [
92-
"hey, how's it going?",
92+
"farewell and see you later",
93+
"good morning",
9394
"so long and take care",
94-
"hello, nice to meet you",
95-
"later, see you soon",
95+
"greetings and salutations",
9696
]
9797
assert handler.test_utterances() == ["greetings", "farewell"]
98-
assert handler.train_labels(0) == [[1, 0], [0, 1], [0, 1], [1, 0]]
98+
assert handler.train_labels(0) == [[0, 1], [1, 0], [0, 1], [1, 0]]
9999
assert handler.test_labels() == [[0, 1], [1, 0]]
100100

101101

@@ -239,6 +239,6 @@ def test_few_shot_split(dataset):
239239
}
240240

241241
for data_split in dh.dataset:
242-
assert (
243-
Counter(dh.dataset[data_split][dh.dataset.label_feature]) == desired_specs[data_split]
244-
), f"Failed for {data_split}"
242+
assert Counter(dh.dataset[data_split][dh.dataset.label_feature]) == desired_specs[data_split], (
243+
f"Failed for {data_split}"
244+
)

tests/data/test_stratificaiton.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def test_multilabel_train_test_split(dataset_unsplitted):
3838

3939
assert Split.TRAIN in dataset
4040
assert Split.TEST in dataset
41-
assert dataset[Split.TRAIN].num_rows == 18
42-
assert dataset[Split.TEST].num_rows == 18
41+
assert dataset[Split.TRAIN].num_rows == 19
42+
assert dataset[Split.TEST].num_rows == 17
4343
assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST)
4444

4545

0 commit comments

Comments
 (0)